
Conversation State Management: Building Context-Aware AI Assistants

Introduction: Conversation state management is the foundation of building coherent, context-aware AI assistants. Without it, every message is processed in isolation: the assistant forgets what was discussed moments ago, loses track of user preferences, and fails to keep the thread of complex multi-turn conversations. Effective state management involves storing conversation history, extracting and persisting relevant context, managing memory across sessions, and efficiently loading state for each interaction. The challenges are significant: conversations can span thousands of tokens, users expect fast responses, and state must stay consistent across distributed systems. This guide covers practical patterns, from simple in-memory stores to persistent Redis and PostgreSQL backends, summarization, and entity-based memory.

[Figure: Conversation State Management: Load, Update, Persist]
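Every turn follows the same three steps: load the conversation's state, update it with the new messages, and persist it. A minimal, self-contained sketch of that cycle, assuming a plain dict as the store and a hypothetical `generate_reply` function standing in for the LLM call:

```python
from typing import Callable

# Hypothetical in-process store: conversation_id -> list of {"role", "content"} dicts
Store = dict[str, list[dict]]


def handle_turn(
    store: Store,
    conversation_id: str,
    user_text: str,
    generate_reply: Callable[[list[dict]], str],
) -> str:
    """Run one load -> update -> persist cycle for a single user message."""
    # 1. Load: copy existing history (empty for a new conversation)
    history = list(store.get(conversation_id, []))

    # 2. Update: append the user turn, call the model with the full history,
    #    then append the assistant turn
    history.append({"role": "user", "content": user_text})
    reply = generate_reply(history)
    history.append({"role": "assistant", "content": reply})

    # 3. Persist: write back only after the turn succeeded, so a failed
    #    model call leaves the stored history untouched
    store[conversation_id] = history
    return reply


def echo(messages: list[dict]) -> str:
    # Stub model that reports how much history it received
    return f"seen {len(messages)} messages"


store: Store = {}
handle_turn(store, "c1", "Hi, I'm Ana", echo)
print(handle_turn(store, "c1", "What's my name?", echo))  # seen 3 messages
```

Because the history is copied before it is modified, an exception from the model call cannot leave a half-updated conversation in the store.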

Core State Models

from dataclasses import dataclass, field
from typing import Any, Optional, List, Dict
from datetime import datetime
from enum import Enum
import uuid
import json

class MessageRole(Enum):
    """Role of message sender."""
    
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"

@dataclass
class Message:
    """A conversation message."""
    
    role: MessageRole
    content: str
    timestamp: datetime = field(default_factory=datetime.now)
    message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    metadata: dict = field(default_factory=dict)
    
    def to_dict(self) -> dict:
        return {
            "role": self.role.value,
            "content": self.content,
            "timestamp": self.timestamp.isoformat(),
            "message_id": self.message_id,
            "metadata": self.metadata
        }
    
    @classmethod
    def from_dict(cls, data: dict) -> "Message":
        return cls(
            role=MessageRole(data["role"]),
            content=data["content"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            message_id=data.get("message_id", str(uuid.uuid4())),
            metadata=data.get("metadata", {})
        )

@dataclass
class ConversationContext:
    """Extracted context from conversation."""
    
    user_name: Optional[str] = None
    user_preferences: dict = field(default_factory=dict)
    current_topic: Optional[str] = None
    entities: list[str] = field(default_factory=list)
    facts: list[str] = field(default_factory=list)
    pending_tasks: list[str] = field(default_factory=list)
    
    def merge(self, other: "ConversationContext"):
        """Merge another context into this one."""
        
        if other.user_name:
            self.user_name = other.user_name
        
        self.user_preferences.update(other.user_preferences)
        
        if other.current_topic:
            self.current_topic = other.current_topic
        
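        # Deduplicate while preserving first-seen order (set() would scramble it)
        self.entities = list(dict.fromkeys(self.entities + other.entities))
        self.facts = list(dict.fromkeys(self.facts + other.facts))
        self.pending_tasks = list(dict.fromkeys(self.pending_tasks + other.pending_tasks))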

@dataclass
class ConversationState:
    """Complete conversation state."""
    
    conversation_id: str
    messages: list[Message] = field(default_factory=list)
    context: ConversationContext = field(default_factory=ConversationContext)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict = field(default_factory=dict)
    
    def add_message(self, role: MessageRole, content: str, metadata: Optional[dict] = None) -> Message:
        """Add a message to the conversation."""
        
        message = Message(
            role=role,
            content=content,
            metadata=metadata or {}
        )
        
        self.messages.append(message)
        self.updated_at = datetime.now()
        
        return message
    
    def get_messages_for_llm(self, max_messages: Optional[int] = None) -> list[dict]:
        """Get messages formatted for LLM."""
        
        messages = self.messages
        
        if max_messages:
            messages = messages[-max_messages:]
        
        return [{"role": m.role.value, "content": m.content} for m in messages]
    
    def get_token_count(self) -> int:
        """Estimate token count."""
        
        total_chars = sum(len(m.content) for m in self.messages)
        return total_chars // 4  # Rough estimate
    
    def to_dict(self) -> dict:
        return {
            "conversation_id": self.conversation_id,
            "messages": [m.to_dict() for m in self.messages],
            "context": {
                "user_name": self.context.user_name,
                "user_preferences": self.context.user_preferences,
                "current_topic": self.context.current_topic,
                "entities": self.context.entities,
                "facts": self.context.facts,
                "pending_tasks": self.context.pending_tasks
            },
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "metadata": self.metadata
        }
    
    @classmethod
    def from_dict(cls, data: dict) -> "ConversationState":
        context_data = data.get("context", {})
        context = ConversationContext(
            user_name=context_data.get("user_name"),
            user_preferences=context_data.get("user_preferences", {}),
            current_topic=context_data.get("current_topic"),
            entities=context_data.get("entities", []),
            facts=context_data.get("facts", []),
            pending_tasks=context_data.get("pending_tasks", [])
        )
        
        return cls(
            conversation_id=data["conversation_id"],
            messages=[Message.from_dict(m) for m in data.get("messages", [])],
            context=context,
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            metadata=data.get("metadata", {})
        )
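`Message` and `ConversationState` default their timestamps to `datetime.now()`, which produces naive local times. If states are written by servers in different time zones, one option (an assumption, not something the models above do) is to store timezone-aware UTC timestamps: `isoformat()` then includes the offset, and `fromisoformat()` restores it. A small sketch:

```python
from datetime import datetime, timezone


def utc_now() -> datetime:
    """Timezone-aware current time in UTC (usable as a default_factory)."""
    return datetime.now(timezone.utc)


def to_iso(ts: datetime) -> str:
    return ts.isoformat()  # aware values carry an offset such as "+00:00"


def from_iso(text: str) -> datetime:
    ts = datetime.fromisoformat(text)
    # Assumption: legacy naive values were recorded in UTC; making them aware
    # avoids TypeError when comparing them with aware timestamps
    return ts if ts.tzinfo else ts.replace(tzinfo=timezone.utc)


stamp = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
print(to_iso(stamp))                     # 2025-01-01T12:00:00+00:00
print(from_iso(to_iso(stamp)) == stamp)  # True
```

With these helpers, `field(default_factory=utc_now)` would replace `field(default_factory=datetime.now)`, and `from_dict` would call `from_iso` instead of `datetime.fromisoformat`.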

State Storage Backends

from abc import ABC, abstractmethod
from typing import Optional, List
import json
import asyncio

class StateStore(ABC):
    """Abstract state store."""
    
    @abstractmethod
    async def save(self, state: ConversationState):
        """Save conversation state."""
        pass
    
    @abstractmethod
    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        """Load conversation state."""
        pass
    
    @abstractmethod
    async def delete(self, conversation_id: str):
        """Delete conversation state."""
        pass
    
    @abstractmethod
    async def list_conversations(self, user_id: str = None) -> list[str]:
        """List conversation IDs."""
        pass

class InMemoryStateStore(StateStore):
    """In-memory state store."""
    
    def __init__(self):
        self.states: dict[str, ConversationState] = {}
    
    async def save(self, state: ConversationState):
        self.states[state.conversation_id] = state
    
    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        return self.states.get(conversation_id)
    
    async def delete(self, conversation_id: str):
        if conversation_id in self.states:
            del self.states[conversation_id]
    
    async def list_conversations(self, user_id: str = None) -> list[str]:
        if user_id:
            return [
                cid for cid, state in self.states.items()
                if state.metadata.get("user_id") == user_id
            ]
        return list(self.states.keys())

class RedisStateStore(StateStore):
    """Redis-backed state store."""
    
    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis_url = redis_url
        self._client = None
    
    async def _get_client(self):
        if self._client is None:
            import redis.asyncio as redis
            self._client = redis.from_url(self.redis_url)
        return self._client
    
    async def save(self, state: ConversationState):
        client = await self._get_client()
        
        key = f"conversation:{state.conversation_id}"
        data = json.dumps(state.to_dict())
        
        await client.set(key, data)
        
        # Add to user's conversation list
        user_id = state.metadata.get("user_id")
        if user_id:
            await client.sadd(f"user:{user_id}:conversations", state.conversation_id)
    
    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        client = await self._get_client()
        
        key = f"conversation:{conversation_id}"
        data = await client.get(key)
        
        if data:
            return ConversationState.from_dict(json.loads(data))
        
        return None
    
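    async def delete(self, conversation_id: str):
        client = await self._get_client()
        
        key = f"conversation:{conversation_id}"
        
        # Also remove the ID from the owner's index set, otherwise
        # list_conversations(user_id=...) keeps returning deleted IDs
        data = await client.get(key)
        if data:
            user_id = json.loads(data).get("metadata", {}).get("user_id")
            if user_id:
                await client.srem(f"user:{user_id}:conversations", conversation_id)
        
        await client.delete(key)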
    
    async def list_conversations(self, user_id: str = None) -> list[str]:
        client = await self._get_client()
        
        if user_id:
            members = await client.smembers(f"user:{user_id}:conversations")
            return [m.decode() for m in members]
        
        # SCAN iterates incrementally; KEYS blocks Redis on large keyspaces
        keys = [k async for k in client.scan_iter(match="conversation:*")]
        return [k.decode().removeprefix("conversation:") for k in keys]

class PostgresStateStore(StateStore):
    """PostgreSQL-backed state store."""
    
    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        self._pool = None
    
    async def _get_pool(self):
        if self._pool is None:
            import asyncpg
            self._pool = await asyncpg.create_pool(self.connection_string)
        return self._pool
    
    async def save(self, state: ConversationState):
        pool = await self._get_pool()
        
        async with pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO conversations (id, data, user_id, updated_at)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (id) DO UPDATE
                SET data = $2, updated_at = $4
            """,
                state.conversation_id,
                json.dumps(state.to_dict()),
                state.metadata.get("user_id"),
                state.updated_at
            )
    
    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        pool = await self._get_pool()
        
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT data FROM conversations WHERE id = $1",
                conversation_id
            )
            
            if row:
                return ConversationState.from_dict(json.loads(row["data"]))
        
        return None
    
    async def delete(self, conversation_id: str):
        pool = await self._get_pool()
        
        async with pool.acquire() as conn:
            await conn.execute(
                "DELETE FROM conversations WHERE id = $1",
                conversation_id
            )
    
    async def list_conversations(self, user_id: str = None) -> list[str]:
        pool = await self._get_pool()
        
        async with pool.acquire() as conn:
            if user_id:
                rows = await conn.fetch(
                    "SELECT id FROM conversations WHERE user_id = $1 ORDER BY updated_at DESC",
                    user_id
                )
            else:
                rows = await conn.fetch(
                    "SELECT id FROM conversations ORDER BY updated_at DESC"
                )
            
            return [row["id"] for row in rows]
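`PostgresStateStore` assumes a `conversations` table that the article does not define. A plausible schema, inferred from the queries above (the column types and the index are assumptions): `data` as `JSONB` accepts the JSON strings the store sends, since asyncpg passes `json`/`jsonb` values as text by default, and `updated_at` is a plain `TIMESTAMP` to match the naive datetimes in `ConversationState`.

```python
# Assumed schema inferred from PostgresStateStore's queries; adjust types as needed
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS conversations (
    id         TEXT PRIMARY KEY,   -- ON CONFLICT (id) needs a unique key
    data       JSONB NOT NULL,     -- serialized ConversationState.to_dict()
    user_id    TEXT,               -- nullable: anonymous conversations
    updated_at TIMESTAMP NOT NULL  -- naive datetime from ConversationState
);

-- Serves "WHERE user_id = $1 ORDER BY updated_at DESC"
CREATE INDEX IF NOT EXISTS conversations_user_updated_idx
    ON conversations (user_id, updated_at DESC);
"""


async def ensure_schema(conn) -> None:
    """Create the table and index if missing (conn: an asyncpg Connection)."""
    # Connection.execute runs several statements at once when given no arguments
    await conn.execute(SCHEMA_SQL)
```

For example, `async with pool.acquire() as conn: await ensure_schema(conn)` once at startup, before the first `save`.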

Context Extraction

from dataclasses import dataclass
from typing import Any, Optional, List
import re

class ContextExtractor:
    """Extract context from conversations."""
    
    def __init__(self, llm_client: Any = None):
        self.llm = llm_client
    
    def extract_entities(self, text: str) -> list[str]:
        """Extract named entities from text."""
        
        # Simple pattern-based extraction
        entities = []
        
        # Capitalized words (potential names/places)
        capitalized = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', text)
        entities.extend(capitalized)
        
        # Email addresses
        emails = re.findall(r'\b[\w.-]+@[\w.-]+\.\w+\b', text)
        entities.extend(emails)
        
        # URLs
        urls = re.findall(r'https?://\S+', text)
        entities.extend(urls)
        
        return list(set(entities))
    
    def extract_user_info(self, messages: list[Message]) -> dict:
        """Extract user information from messages."""
        
        info = {}
        
        for message in messages:
            if message.role != MessageRole.USER:
                continue
            
            content = message.content.lower()
            
            # Name patterns (heuristic: "i'm"/"i am" also match phrases such as
            # "I'm looking for...", so treat the result as low-confidence)
            name_patterns = [
                r"\bmy name is (\w+)",
                r"\bi'm (\w+)",
                r"\bi am (\w+)",
                r"\bcall me (\w+)"
            ]
            
            for pattern in name_patterns:
                match = re.search(pattern, content)
                if match:
                    info["name"] = match.group(1).capitalize()
                    break
            
            # Preference patterns (word boundaries skip "likely" and "unlike")
            if re.search(r"\b(prefer\w*|like[sd]?)\b", content):
                info.setdefault("preferences", []).append(message.content)
        
        return info
    
    async def extract_with_llm(
        self,
        messages: list[Message]
    ) -> ConversationContext:
        """Use LLM to extract context."""
        
        if not self.llm:
            return ConversationContext()
        
        conversation_text = "\n".join([
            f"{m.role.value}: {m.content}"
            for m in messages[-10:]  # Last 10 messages
        ])
        
        prompt = f"""Extract key information from this conversation:

{conversation_text}

Extract and return as JSON:
- user_name: The user's name if mentioned
- current_topic: The main topic being discussed
- entities: List of important names, places, products mentioned
- facts: List of facts learned about the user or situation
- pending_tasks: Any tasks or requests that haven't been completed

JSON:"""
        
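        # Assumes an async client whose generate(prompt) returns the reply text
        response = await self.llm.generate(prompt)
        
        import json
        
        try:
            data = json.loads(response)
            
            return ConversationContext(
                user_name=data.get("user_name"),
                current_topic=data.get("current_topic"),
                entities=data.get("entities", []),
                facts=data.get("facts", []),
                pending_tasks=data.get("pending_tasks", [])
            )
        except (json.JSONDecodeError, TypeError, AttributeError):
            # Malformed or non-object JSON: fall back to an empty context
            # (a bare except would also hide genuine bugs)
            return ConversationContext()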
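`extract_with_llm` calls `json.loads` on the raw response, but models often wrap JSON in markdown fences or add a sentence around it, which sends every such reply to the empty-context fallback. A small helper (hypothetical, not part of the class above) that scans for the first parseable JSON object using the standard library's `json.JSONDecoder.raw_decode`:

```python
import json
from typing import Optional


def parse_json_object(text: str) -> Optional[dict]:
    """Return the first JSON object embedded in text, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value beginning at `start` and ignores
            # whatever follows it (closing fences, trailing prose)
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            # Not valid JSON from this brace; try the next one
            start = text.find("{", start + 1)
    return None


reply = 'Here you go:\n{"user_name": "Ana", "entities": ["Lisbon"]}\nHope that helps!'
print(parse_json_object(reply))  # {'user_name': 'Ana', 'entities': ['Lisbon']}
```

Inside `extract_with_llm`, `data = parse_json_object(response)` followed by an `if data is None` check would replace the bare `json.loads` call.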

class ConversationSummarizer:
    """Summarize conversations for memory compression."""
    
    def __init__(self, llm_client: Any):
        self.llm = llm_client
    
    async def summarize(
        self,
        messages: list[Message],
        max_length: int = 500
    ) -> str:
        """Summarize conversation."""
        
        conversation_text = "\n".join([
            f"{m.role.value}: {m.content}"
            for m in messages
        ])
        
        prompt = f"""Summarize this conversation in {max_length} characters or fewer.
Focus on key decisions, facts learned, and any pending items.

Conversation:
{conversation_text}

Summary:"""
        
        return await self.llm.generate(prompt)
    
    async def create_rolling_summary(
        self,
        previous_summary: str,
        new_messages: list[Message],
        max_length: int = 500
    ) -> str:
        """Update summary with new messages."""
        
        new_text = "\n".join([
            f"{m.role.value}: {m.content}"
            for m in new_messages
        ])
        
        prompt = f"""Update this conversation summary with new messages.

Previous summary:
{previous_summary}

New messages:
{new_text}

Updated summary (keep under {max_length} characters):"""
        
        return await self.llm.generate(prompt)
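Both summarizer prompts ask the model to stay under a character budget, but nothing enforces it, and an overlong summary eats into the context window it is meant to save. One defensive option (an assumption, not part of the original design) is to clamp the returned text, cutting at the last sentence end or, failing that, at a word boundary:

```python
def clamp_summary(text: str, max_length: int = 500) -> str:
    """Trim an LLM summary to at most max_length characters."""
    text = text.strip()
    if len(text) <= max_length:
        return text
    if max_length < 4:
        # Too short to fit an ellipsis; hard cut
        return text[:max_length]

    # Prefer ending on a full sentence if one ends in the second half of the budget
    window = text[:max_length]
    sentence_end = max(window.rfind(". "), window.rfind(".\n"))
    if sentence_end >= max_length // 2:
        return window[: sentence_end + 1]

    # Otherwise break at a word boundary, reserving 3 characters for "..."
    window = text[: max_length - 3]
    space = window.rfind(" ")
    if space > 0:
        window = window[:space]
    return window.rstrip() + "..."
```

For example, `self.summary = clamp_summary(await self.summarizer.summarize(old_messages))` keeps the stored summary within budget even when the model ignores the instruction.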

Memory Management

from dataclasses import dataclass
from typing import Any, Optional, List
from datetime import datetime, timedelta

@dataclass
class MemoryConfig:
    """Configuration for memory management."""
    
    max_messages: int = 50
    max_tokens: int = 4000
    summarize_after: int = 20
    context_window: int = 10

class ConversationMemory:
    """Manage memory for a single conversation (one instance per conversation)."""
    
    def __init__(
        self,
        config: MemoryConfig,
        summarizer: Optional[ConversationSummarizer] = None
    ):
        self.config = config
        self.summarizer = summarizer
        self.summary: str = ""
        # IDs of messages already folded into self.summary, so each message
        # is summarized exactly once (also robust to trim_messages)
        self._summarized_ids: set[str] = set()
    
    async def prepare_context(
        self,
        state: ConversationState
    ) -> list[dict]:
        """Prepare context for LLM call."""
        
        messages = []
        
        # Add summary if available
        if self.summary:
            messages.append({
                "role": "system",
                "content": f"Previous conversation summary: {self.summary}"
            })
        
        # Add recent messages
        recent = state.messages[-self.config.context_window:]
        
        for msg in recent:
            messages.append({
                "role": msg.role.value,
                "content": msg.content
            })
        
        return messages
    
    async def update_memory(self, state: ConversationState):
        """Fold messages that have left the context window into the summary."""
        
        if not self.summarizer or len(state.messages) <= self.config.summarize_after:
            return
        
        # Messages older than the context window that are not yet summarized
        cutoff = max(len(state.messages) - self.config.context_window, 0)
        pending = [
            m for m in state.messages[:cutoff]
            if m.message_id not in self._summarized_ids
        ]
        
        if not pending:
            return
        
        if self.summary:
            self.summary = await self.summarizer.create_rolling_summary(
                self.summary,
                pending
            )
        else:
            self.summary = await self.summarizer.summarize(pending)
        
        self._summarized_ids.update(m.message_id for m in pending)
    
    def trim_messages(self, state: ConversationState) -> ConversationState:
        """Trim messages to fit limits."""
        
        if len(state.messages) <= self.config.max_messages:
            return state
        
        # Keep system messages and recent messages
        system_messages = [m for m in state.messages if m.role == MessageRole.SYSTEM]
        other_messages = [m for m in state.messages if m.role != MessageRole.SYSTEM]
        
        # Keep most recent; guard the zero-budget case, because
        # other_messages[-0:] would return the whole list
        budget = max(self.config.max_messages - len(system_messages), 0)
        kept = other_messages[-budget:] if budget else []
        
        state.messages = system_messages + kept
        
        return state
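`MemoryConfig.max_tokens` is declared but never used: `trim_messages` limits the message count, while model limits are expressed in tokens. A sketch of budget-based trimming that keeps the newest messages whose estimated cost fits, reusing the rough 4-characters-per-token estimate from `get_token_count` (a real tokenizer for your model would be more accurate):

```python
def estimate_tokens(text: str) -> int:
    # Same heuristic as ConversationState.get_token_count: ~4 chars per token
    return len(text) // 4


def trim_to_token_budget(messages: list[dict], max_tokens: int) -> list[dict]:
    """Keep system messages plus the newest other messages that fit max_tokens.

    messages are {"role": ..., "content": ...} dicts in chronological order.
    """
    system = [m for m in messages if m["role"] == "system"]
    budget = max_tokens - sum(estimate_tokens(m["content"]) for m in system)

    kept: list[dict] = []
    # Walk backwards from the newest message until the budget runs out
    for m in reversed([m for m in messages if m["role"] != "system"]):
        cost = estimate_tokens(m["content"])
        if cost > budget:
            break
        kept.append(m)
        budget -= cost

    # Restore chronological order for the kept messages
    return system + list(reversed(kept))
```

Stopping at the first message that does not fit (rather than skipping it and continuing) keeps the retained history contiguous, so the model never sees a reply without the turn that prompted it.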

class WindowedMemory:
    """Sliding window memory."""
    
    def __init__(self, window_size: int = 10):
        self.window_size = window_size
    
    def get_context(self, state: ConversationState) -> list[dict]:
        """Get windowed context."""
        
        messages = state.messages[-self.window_size:]
        
        return [
            {"role": m.role.value, "content": m.content}
            for m in messages
        ]
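A fixed-size window can start mid-exchange, for example with an assistant reply or a `tool` result whose triggering request has already scrolled out; this can confuse the model, and some chat APIs reject a tool message without its preceding call. One possible refinement, assumed here rather than taken from the original, is to move the window's start forward to the first user message:

```python
def windowed_from_user_turn(messages: list[dict], window_size: int = 10) -> list[dict]:
    """Return at most window_size recent messages, starting at a user message.

    Falls back to the plain window if it contains no user message.
    """
    window = messages[-window_size:] if window_size > 0 else []
    for i, m in enumerate(window):
        if m["role"] == "user":
            return window[i:]
    return window
```

The trade-off is that the returned window may be a few messages shorter than `window_size`; in exchange, every assistant or tool message in it is preceded by the user turn that started the exchange.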

class SummaryMemory:
    """Summary-based memory (assumes each conversation's message list is append-only)."""
    
    def __init__(self, summarizer: ConversationSummarizer):
        self.summarizer = summarizer
        self.summaries: dict[str, str] = {}
        # conversation_id -> number of leading messages already in the summary
        self.summarized_counts: dict[str, int] = {}
    
    async def get_context(
        self,
        state: ConversationState,
        recent_count: int = 5
    ) -> list[dict]:
        """Get context with summary."""
        
        # Bring the summary up to date before building the context
        await self.update(state, recent_count)
        
        messages = []
        
        # Add summary
        summary = self.summaries.get(state.conversation_id)
        if summary:
            messages.append({
                "role": "system",
                "content": f"Conversation history: {summary}"
            })
        
        # Add recent messages
        for msg in state.messages[-recent_count:]:
            messages.append({
                "role": msg.role.value,
                "content": msg.content
            })
        
        return messages
    
    async def update(self, state: ConversationState, recent_count: int = 5):
        """Fold messages that fell out of the recent window into the summary."""
        
        cid = state.conversation_id
        cutoff = len(state.messages) - recent_count
        done = self.summarized_counts.get(cid, 0)
        
        # Only summarize messages not yet covered by the summary
        new_messages = state.messages[done:cutoff] if cutoff > done else []
        
        if not new_messages:
            return
        
        if cid in self.summaries:
            self.summaries[cid] = await self.summarizer.create_rolling_summary(
                self.summaries[cid],
                new_messages
            )
        else:
            self.summaries[cid] = await self.summarizer.summarize(new_messages)
        
        self.summarized_counts[cid] = cutoff

class EntityMemory:
    """Entity-based memory."""
    
    def __init__(self, extractor: ContextExtractor):
        self.extractor = extractor
        self.entities: dict[str, dict] = {}  # conversation_id -> entity info
        self.processed: dict[str, set[str]] = {}  # conversation_id -> counted message IDs
    
    async def update(self, state: ConversationState):
        """Update entity memory with messages not yet processed."""
        
        cid = state.conversation_id
        entity_info = self.entities.setdefault(cid, {})
        seen = self.processed.setdefault(cid, set())
        
        # Count each message once, so repeated update() calls
        # do not inflate mention counts
        for msg in state.messages:
            if msg.message_id in seen:
                continue
            seen.add(msg.message_id)
            
            for entity in self.extractor.extract_entities(msg.content):
                info = entity_info.setdefault(entity, {
                    "first_mentioned": msg.timestamp,
                    "mentions": 0
                })
                info["mentions"] += 1
                info["last_mentioned"] = msg.timestamp
    
    def get_relevant_entities(
        self,
        state: ConversationState,
        limit: int = 10
    ) -> list[str]:
        """Get most relevant entities."""
        
        cid = state.conversation_id
        
        if cid not in self.entities:
            return []
        
        # Sort by mention count
        sorted_entities = sorted(
            self.entities[cid].items(),
            key=lambda x: x[1]["mentions"],
            reverse=True
        )
        
        return [e[0] for e in sorted_entities[:limit]]
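`get_relevant_entities` ranks purely by mention count, so an entity discussed heavily an hour ago outranks one the user just introduced. A hedged alternative is to decay each entity's count by the time since its last mention; the 30-minute half-life is an arbitrary assumption, and the input mirrors the `mentions` and `last_mentioned` fields that `EntityMemory` stores:

```python
from datetime import datetime, timedelta


def rank_entities(
    entities: dict[str, dict],
    now: datetime,
    half_life: timedelta = timedelta(minutes=30),
    limit: int = 10,
) -> list[str]:
    """Rank entities by mentions * 0.5 ** (age / half_life)."""

    def score(info: dict) -> float:
        age = now - info["last_mentioned"]
        # timedelta / timedelta yields a float number of half-lives
        return info["mentions"] * 0.5 ** (age / half_life)

    ranked = sorted(entities, key=lambda name: score(entities[name]), reverse=True)
    return ranked[:limit]


now = datetime(2025, 1, 1, 12, 0)
entities = {
    "Paris": {"mentions": 6, "last_mentioned": now - timedelta(hours=2)},
    "Rome": {"mentions": 2, "last_mentioned": now},
}
print(rank_entities(entities, now))  # ['Rome', 'Paris']
```

Here "Paris" scores 6 × 0.5⁴ = 0.375 against 2 for "Rome", so the recent entity wins even though pure counting would rank "Paris" first.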

Production State Service

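import uuid
from typing import Dict, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Assumes the models and stores defined above (ConversationState, MessageRole,
# InMemoryStateStore, MemoryConfig) are importable in this module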

app = FastAPI()

class CreateConversationRequest(BaseModel):
    user_id: Optional[str] = None
    metadata: Optional[Dict] = None

class AddMessageRequest(BaseModel):
    role: str
    content: str
    metadata: Optional[Dict] = None

class GetContextRequest(BaseModel):
    max_messages: int = 10
    include_summary: bool = True  # accepted but not yet used by get_context

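# Initialize components (in-memory store for development; swap in
# RedisStateStore or PostgresStateStore for persistent or multi-instance deployments)
store = InMemoryStateStore()
memory_config = MemoryConfig()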

@app.post("/v1/conversations")
async def create_conversation(request: CreateConversationRequest) -> dict:
    """Create a new conversation."""
    
    conversation_id = str(uuid.uuid4())
    
    state = ConversationState(
        conversation_id=conversation_id,
        metadata={
            "user_id": request.user_id,
            **(request.metadata or {})
        }
    )
    
    await store.save(state)
    
    return {
        "conversation_id": conversation_id,
        "created_at": state.created_at.isoformat()
    }

@app.get("/v1/conversations/{conversation_id}")
async def get_conversation(conversation_id: str) -> dict:
    """Get conversation state."""
    
    state = await store.load(conversation_id)
    
    if not state:
        raise HTTPException(status_code=404, detail="Conversation not found")
    
    return state.to_dict()

@app.post("/v1/conversations/{conversation_id}/messages")
async def add_message(
    conversation_id: str,
    request: AddMessageRequest
) -> dict:
    """Add a message to conversation."""
    
    state = await store.load(conversation_id)
    
    if not state:
        raise HTTPException(status_code=404, detail="Conversation not found")
    
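    # Reject unknown roles with a client error instead of an unhandled ValueError (500)
    try:
        role = MessageRole(request.role)
    except ValueError:
        raise HTTPException(status_code=422, detail=f"Invalid role: {request.role}") from None
    
    message = state.add_message(
        role=role,
        content=request.content,
        metadata=request.metadata
    )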
    
    await store.save(state)
    
    return {
        "message_id": message.message_id,
        "timestamp": message.timestamp.isoformat()
    }

@app.post("/v1/conversations/{conversation_id}/context")
async def get_context(
    conversation_id: str,
    request: GetContextRequest
) -> dict:
    """Get context for LLM call."""
    
    state = await store.load(conversation_id)
    
    if not state:
        raise HTTPException(status_code=404, detail="Conversation not found")
    
    messages = state.get_messages_for_llm(max_messages=request.max_messages)
    
    return {
        "messages": messages,
        "context": {
            "user_name": state.context.user_name,
            "current_topic": state.context.current_topic,
            "entities": state.context.entities[:10]
        },
        # Rough estimate (~4 characters per token) for the messages actually
        # returned, not the whole conversation
        "token_estimate": sum(len(m["content"]) for m in messages) // 4
    }

@app.get("/v1/conversations/{conversation_id}/messages")
async def list_messages(
    conversation_id: str,
    limit: int = 50,
    offset: int = 0
) -> list[dict]:
    """List messages in conversation."""
    
    state = await store.load(conversation_id)
    
    if not state:
        raise HTTPException(status_code=404, detail="Conversation not found")
    
    messages = state.messages[offset:offset + limit]
    
    return [m.to_dict() for m in messages]

@app.delete("/v1/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str) -> dict:
    """Delete a conversation."""
    
    await store.delete(conversation_id)
    
    return {"deleted": True}

@app.get("/v1/users/{user_id}/conversations")
async def list_user_conversations(user_id: str) -> list[str]:
    """List conversations for a user."""
    
    return await store.list_conversations(user_id=user_id)

@app.get("/health")
async def health():
    return {"status": "healthy"}
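Each write endpoint loads state, modifies it, and saves it back. With the Redis or PostgreSQL store, two concurrent `add_message` calls for the same conversation can both load the same version, and whichever saves last silently drops the other's message (the in-memory store happens to avoid this because both requests mutate one shared object). Within a single process, one possible mitigation is a per-conversation `asyncio.Lock`; multi-instance deployments would instead need a distributed lock or optimistic versioning in the store. A self-contained sketch, with a hypothetical JSON-string store standing in for Redis:

```python
import asyncio
import json
from collections import defaultdict

# Hypothetical serialized store (like Redis): conversation_id -> JSON string
_db: dict[str, str] = {}
_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)


async def load(cid: str) -> dict:
    await asyncio.sleep(0)  # simulate network I/O, letting other tasks run
    return json.loads(_db[cid])


async def save(cid: str, state: dict) -> None:
    await asyncio.sleep(0)
    _db[cid] = json.dumps(state)


async def append_message(cid: str, content: str, use_lock: bool = True) -> None:
    async def read_modify_write() -> None:
        state = await load(cid)
        state["messages"].append(content)
        await save(cid, state)

    if use_lock:
        # Serialize read-modify-write cycles for the same conversation
        async with _locks[cid]:
            await read_modify_write()
    else:
        await read_modify_write()


async def demo(use_lock: bool) -> int:
    _locks.clear()  # fresh locks for this event loop
    _db["c1"] = json.dumps({"messages": []})
    await asyncio.gather(*(append_message("c1", f"m{i}", use_lock) for i in range(5)))
    return len(json.loads(_db["c1"])["messages"])


print(asyncio.run(demo(use_lock=True)))  # 5
```

Running `demo(use_lock=False)` instead ends with fewer than five messages, because every task reads the empty history before any of them writes.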


Conclusion

Conversation state management is essential for building AI assistants that feel coherent and contextually aware. Start by choosing a storage backend that fits your scale and requirements: in-memory for development and small single-process deployments, Redis for low-latency access with optional persistence, and PostgreSQL for durable storage and richer queries.

Memory strategy depends on your context-window budget. Sliding windows are simple but lose long-term context; summarization preserves key information but costs extra LLM calls; entity-based memory cheaply tracks which names and other entities recur. In production, combine them: a sliding window for recent turns, periodic summarization for older history, and entity extraction for persistent facts.

Run context extraction and summarization asynchronously so they do not add latency to user interactions. Keep state models serializable and version them, so that states stored by older releases can still be loaded as your application evolves, and monitor memory usage and token counts to stay within model limits. With these pieces in place, your assistant can maintain coherent conversations across sessions, remember user preferences, and give contextually relevant responses that make users feel understood.
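Versioning can be as simple as stamping each serialized state with a schema version and upgrading older payloads when they are loaded. A minimal sketch on plain dicts; the `schema_version` key and the v1-to-v2 change shown (adding a `context` section) are hypothetical examples, not fields the models above define:

```python
from typing import Callable

CURRENT_VERSION = 2


def _v1_to_v2(data: dict) -> dict:
    # Hypothetical change: v2 added a "context" section with list fields
    data.setdefault("context", {"entities": [], "facts": [], "pending_tasks": []})
    return data


# Maps a version to the function that upgrades it to version + 1
MIGRATIONS: dict[int, Callable[[dict], dict]] = {1: _v1_to_v2}


def upgrade(data: dict) -> dict:
    """Apply migrations in order until data reaches CURRENT_VERSION."""
    data = dict(data)  # shallow copy: don't mutate the caller's dict
    version = data.get("schema_version", 1)  # unversioned payloads count as v1
    if version > CURRENT_VERSION:
        raise ValueError(f"state written by a newer release (v{version})")
    while version < CURRENT_VERSION:
        data = MIGRATIONS[version](data)
        version += 1
    data["schema_version"] = version
    return data


old = {"conversation_id": "c1", "messages": []}
print(upgrade(old)["schema_version"])  # 2
```

In this scheme, `to_dict` would write `"schema_version": CURRENT_VERSION`, and `from_dict` would pass its input through `upgrade` before reading any fields.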



About the Author

I am a Cloud Architect and Developer passionate about solving complex problems with modern technology. My blog explores the intersection of Cloud Architecture, Artificial Intelligence, and Software Engineering. I share tutorials, deep dives, and insights into building scalable, intelligent systems.

Areas of Expertise

Cloud Architecture (Azure, AWS)
Artificial Intelligence & LLMs
DevOps & Kubernetes
Backend Dev (C#, .NET, Python, Node.js)
© 2025 Code, Cloud & Context | Built by Nithin Mohan TK | Powered by Passion