Introduction: Conversation state management is the foundation of building coherent, context-aware AI assistants. Without proper state management, every message is processed in isolation—the assistant forgets what was discussed moments ago, loses track of user preferences, and fails to maintain the thread of complex multi-turn conversations. Effective state management involves storing conversation history, extracting and persisting relevant context, managing memory across sessions, and efficiently loading state for each interaction. The challenges are significant: conversations can span thousands of tokens, users expect instant responses, and state must be consistent across distributed systems. This guide covers practical patterns for conversation state management: from simple in-memory stores to sophisticated persistent memory systems with summarization and retrieval.

Core State Models
from dataclasses import dataclass, field
from typing import Any, Optional, List, Dict
from datetime import datetime
from enum import Enum
import uuid
import json
class MessageRole(Enum):
    """Who authored a message; each value is the role string used on the wire."""

    USER = "user"            # end-user input
    ASSISTANT = "assistant"  # model reply
    SYSTEM = "system"        # system-level instruction
    TOOL = "tool"            # tool output
@dataclass
class Message:
    """One turn of a conversation plus its bookkeeping fields.

    ``timestamp`` and ``message_id`` are filled in automatically (current
    local time and a fresh UUID4 string) when not supplied.
    """

    role: MessageRole
    content: str
    timestamp: datetime = field(default_factory=datetime.now)
    message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return a JSON-friendly mapping of this message.

        The role is emitted as its string value and the timestamp as an
        ISO-8601 string.
        """
        payload = dict(
            role=self.role.value,
            content=self.content,
            timestamp=self.timestamp.isoformat(),
            message_id=self.message_id,
            metadata=self.metadata,
        )
        return payload

    @classmethod
    def from_dict(cls, data: dict) -> "Message":
        """Rebuild a message from a mapping shaped like ``to_dict`` output.

        ``role``, ``content`` and ``timestamp`` are required keys; a missing
        ``message_id`` gets a new UUID and missing ``metadata`` an empty dict.
        """
        if "message_id" in data:
            message_id = data["message_id"]
        else:
            message_id = str(uuid.uuid4())
        metadata = data["metadata"] if "metadata" in data else {}
        return cls(
            role=MessageRole(data["role"]),
            content=data["content"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            message_id=message_id,
            metadata=metadata,
        )
@dataclass
class ConversationContext:
    """Structured facts extracted from a conversation.

    Every field is optional; list and dict fields default to fresh empty
    containers.
    """

    user_name: Optional[str] = None
    user_preferences: dict = field(default_factory=dict)
    current_topic: Optional[str] = None
    entities: list[str] = field(default_factory=list)
    facts: list[str] = field(default_factory=list)
    pending_tasks: list[str] = field(default_factory=list)

    @staticmethod
    def _merge_unique(current, incoming) -> list:
        """Concatenate two sequences, dropping duplicates but keeping order.

        ``None`` is treated as an empty sequence. ``dict.fromkeys`` keeps the
        first occurrence of each item, so the result is deterministic, unlike
        ``list(set(...))`` whose order varies between runs.
        """
        return list(dict.fromkeys([*(current or ()), *(incoming or ())]))

    def merge(self, other: "ConversationContext"):
        """Merge another context into this one, in place.

        Non-empty scalar fields of ``other`` (``user_name``,
        ``current_topic``) overwrite this context's values. Preferences are
        updated key by key. List fields are unioned, with this context's
        items first and ``other``'s new items appended in their original
        order.

        Args:
            other: The context whose values are folded into this one.
        """
        if other.user_name:
            self.user_name = other.user_name
        if other.current_topic:
            self.current_topic = other.current_topic
        self.user_preferences.update(other.user_preferences or {})
        self.entities = self._merge_unique(self.entities, other.entities)
        self.facts = self._merge_unique(self.facts, other.facts)
        self.pending_tasks = self._merge_unique(
            self.pending_tasks, other.pending_tasks
        )
@dataclass
class ConversationState:
    """Complete state of one conversation.

    Holds the ordered message list, the extracted ``ConversationContext``,
    creation/update timestamps and free-form metadata, and converts to and
    from a JSON-friendly dict.
    """

    conversation_id: str
    messages: list[Message] = field(default_factory=list)
    context: ConversationContext = field(default_factory=ConversationContext)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict = field(default_factory=dict)

    def add_message(
        self,
        role: MessageRole,
        content: str,
        metadata: Optional[dict] = None,
    ) -> Message:
        """Append a new message and bump ``updated_at``.

        Args:
            role: Sender of the message.
            content: Message text.
            metadata: Optional extra data; a falsy value becomes ``{}``.

        Returns:
            The newly created ``Message``.
        """
        message = Message(role=role, content=content, metadata=metadata or {})
        self.messages.append(message)
        self.updated_at = datetime.now()
        return message

    def get_messages_for_llm(
        self, max_messages: Optional[int] = None
    ) -> list[dict]:
        """Return messages as ``{"role", "content"}`` dicts for an LLM call.

        Args:
            max_messages: Keep only this many most recent messages. ``None``
                or ``0`` means no limit.

        Raises:
            ValueError: If ``max_messages`` is negative. A negative value
                used to slice from the front and silently drop the oldest
                messages instead of limiting the count.
        """
        if max_messages is not None and max_messages < 0:
            raise ValueError("max_messages must be non-negative")
        selected = self.messages[-max_messages:] if max_messages else self.messages
        return [{"role": m.role.value, "content": m.content} for m in selected]

    def get_token_count(self) -> int:
        """Estimate the token count as total content characters // 4."""
        total_chars = sum(len(m.content) for m in self.messages)
        return total_chars // 4  # Rough estimate

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict.

        Containers are shallow-copied so that mutating the returned dict does
        not change this state object.
        """
        ctx = self.context
        return {
            "conversation_id": self.conversation_id,
            "messages": [m.to_dict() for m in self.messages],
            "context": {
                "user_name": ctx.user_name,
                "user_preferences": dict(ctx.user_preferences or {}),
                "current_topic": ctx.current_topic,
                "entities": list(ctx.entities or []),
                "facts": list(ctx.facts or []),
                "pending_tasks": list(ctx.pending_tasks or []),
            },
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "metadata": dict(self.metadata or {}),
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ConversationState":
        """Rebuild a state from a dict shaped like ``to_dict`` output.

        ``conversation_id``, ``created_at`` and ``updated_at`` are required.
        Missing or ``null`` optional containers fall back to empty ones, so a
        payload containing JSON ``null`` does not leave ``None`` in a list or
        dict field.
        """
        context_data = data.get("context") or {}
        context = ConversationContext(
            user_name=context_data.get("user_name"),
            user_preferences=context_data.get("user_preferences") or {},
            current_topic=context_data.get("current_topic"),
            entities=context_data.get("entities") or [],
            facts=context_data.get("facts") or [],
            pending_tasks=context_data.get("pending_tasks") or [],
        )
        return cls(
            conversation_id=data["conversation_id"],
            messages=[Message.from_dict(m) for m in data.get("messages") or []],
            context=context,
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            metadata=data.get("metadata") or {},
        )
State Storage Backends
from abc import ABC, abstractmethod
from typing import Optional, List
import json
import asyncio
class StateStore(ABC):
    """Interface every conversation-state backend implements.

    All operations are coroutines so that backends may perform I/O.
    """

    @abstractmethod
    async def save(self, state: ConversationState):
        """Persist ``state`` under its ``conversation_id``."""

    @abstractmethod
    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        """Return the stored state for ``conversation_id``, if any."""

    @abstractmethod
    async def delete(self, conversation_id: str):
        """Remove the stored state for ``conversation_id``."""

    @abstractmethod
    async def list_conversations(self, user_id: str = None) -> list[str]:
        """Return conversation ids, optionally restricted to one user."""
class InMemoryStateStore(StateStore):
    """State store backed by a plain dict held in this process."""

    def __init__(self):
        # conversation_id -> ConversationState (stored by reference, not copied)
        self.states: dict[str, ConversationState] = {}

    async def save(self, state: ConversationState):
        """Store ``state`` under its conversation id, replacing any old entry."""
        self.states[state.conversation_id] = state

    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        """Return the stored state, or ``None`` when the id is unknown."""
        return self.states.get(conversation_id)

    async def delete(self, conversation_id: str):
        """Drop the entry for ``conversation_id``; unknown ids are ignored."""
        self.states.pop(conversation_id, None)

    async def list_conversations(self, user_id: str = None) -> list[str]:
        """List stored ids; with a truthy ``user_id`` keep only that user's.

        Ownership is read from each state's ``metadata["user_id"]``.
        """
        if not user_id:
            return list(self.states)
        owned = []
        for cid, state in self.states.items():
            if state.metadata.get("user_id") == user_id:
                owned.append(cid)
        return owned
class RedisStateStore(StateStore):
    """Redis-backed state store.

    Each conversation is stored as a JSON string under
    ``conversation:<id>``. Conversations whose metadata carries a ``user_id``
    are also indexed in the set ``user:<user_id>:conversations``.
    """

    _KEY_PREFIX = "conversation:"

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis_url = redis_url
        self._client = None  # created lazily on first use

    @classmethod
    def _conversation_key(cls, conversation_id: str) -> str:
        """Return the Redis key holding one conversation's JSON."""
        return f"{cls._KEY_PREFIX}{conversation_id}"

    @staticmethod
    def _user_index_key(user_id: str) -> str:
        """Return the Redis key of a user's conversation-id set."""
        return f"user:{user_id}:conversations"

    @staticmethod
    def _to_text(value) -> str:
        """Decode bytes replies; pass through already-decoded strings.

        The client returns ``str`` instead of ``bytes`` when response
        decoding is enabled (for example through the connection URL).
        """
        return value.decode() if isinstance(value, bytes) else value

    async def _get_client(self):
        """Create the async Redis client on first use and reuse it afterwards."""
        if self._client is None:
            # Imported lazily so the module loads without redis installed.
            import redis.asyncio as redis
            self._client = redis.from_url(self.redis_url)
        return self._client

    async def save(self, state: ConversationState):
        """Write the state as JSON and index it under its user, if any."""
        client = await self._get_client()
        await client.set(
            self._conversation_key(state.conversation_id),
            json.dumps(state.to_dict()),
        )
        # Add to user's conversation list
        user_id = state.metadata.get("user_id")
        if user_id:
            await client.sadd(self._user_index_key(user_id), state.conversation_id)

    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        """Return the stored state, or ``None`` when the key does not exist."""
        client = await self._get_client()
        data = await client.get(self._conversation_key(conversation_id))
        if not data:
            return None
        return ConversationState.from_dict(json.loads(data))

    async def delete(self, conversation_id: str):
        """Delete the conversation and remove it from its user's index.

        The stored JSON is read first to find the owning ``user_id``.
        Without this the id would stay in the user's set, and
        ``list_conversations(user_id)`` would keep returning a conversation
        that no longer exists.
        """
        client = await self._get_client()
        key = self._conversation_key(conversation_id)
        raw = await client.get(key)
        await client.delete(key)
        if not raw:
            return
        try:
            user_id = (json.loads(raw).get("metadata") or {}).get("user_id")
        except (ValueError, AttributeError):
            # Unparseable or unexpectedly shaped payload: nothing to un-index.
            user_id = None
        if user_id:
            await client.srem(self._user_index_key(user_id), conversation_id)

    async def list_conversations(self, user_id: str = None) -> list[str]:
        """List conversation ids, for one user or for the whole store.

        The store-wide listing iterates with SCAN rather than KEYS so that a
        large keyspace does not block the server in a single command.
        """
        client = await self._get_client()
        if user_id:
            members = await client.smembers(self._user_index_key(user_id))
            return [self._to_text(m) for m in members]
        ids = []
        async for key in client.scan_iter(match=f"{self._KEY_PREFIX}*"):
            # removeprefix strips only the leading prefix; str.replace would
            # also mangle ids that contain "conversation:" themselves.
            ids.append(self._to_text(key).removeprefix(self._KEY_PREFIX))
        return ids
class PostgresStateStore(StateStore):
    """PostgreSQL-backed state store.

    Uses a ``conversations`` table with columns ``id``, ``data``,
    ``user_id`` and ``updated_at``. The table is expected to exist already;
    this class does not create it.
    """

    # The upsert refreshes user_id as well as data. Otherwise a conversation
    # whose metadata gains or changes its user_id after the first save would
    # never show up in list_conversations(user_id).
    _UPSERT_SQL = (
        "INSERT INTO conversations (id, data, user_id, updated_at) "
        "VALUES ($1, $2, $3, $4) "
        "ON CONFLICT (id) DO UPDATE "
        "SET data = $2, user_id = $3, updated_at = $4"
    )

    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        self._pool = None  # created lazily on first use

    async def _get_pool(self):
        """Create the asyncpg pool on first use and reuse it afterwards."""
        if self._pool is None:
            # Imported lazily so the module loads without asyncpg installed.
            import asyncpg
            self._pool = await asyncpg.create_pool(self.connection_string)
        return self._pool

    async def save(self, state: ConversationState):
        """Insert the state, or update the existing row with the same id.

        ``data`` receives the JSON-encoded state. ``user_id`` mirrors
        ``state.metadata["user_id"]`` and may be ``None``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                self._UPSERT_SQL,
                state.conversation_id,
                json.dumps(state.to_dict()),
                state.metadata.get("user_id"),
                state.updated_at,
            )

    async def load(self, conversation_id: str) -> Optional[ConversationState]:
        """Return the stored state, or ``None`` when no row matches."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT data FROM conversations WHERE id = $1",
                conversation_id,
            )
        if row:
            return ConversationState.from_dict(json.loads(row["data"]))
        return None

    async def delete(self, conversation_id: str):
        """Delete the row for ``conversation_id``, if present."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "DELETE FROM conversations WHERE id = $1",
                conversation_id,
            )

    async def list_conversations(self, user_id: str = None) -> list[str]:
        """List conversation ids, most recently updated first.

        With a truthy ``user_id`` only that user's conversations are
        returned.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if user_id:
                rows = await conn.fetch(
                    "SELECT id FROM conversations WHERE user_id = $1 ORDER BY updated_at DESC",
                    user_id,
                )
            else:
                rows = await conn.fetch(
                    "SELECT id FROM conversations ORDER BY updated_at DESC"
                )
        return [row["id"] for row in rows]
Context Extraction
from dataclasses import dataclass
from typing import Any, Optional, List
import re
class ContextExtractor:
    """Extract entities, user details and structured context from conversations.

    The regex-based helpers are cheap heuristics. ``extract_with_llm``
    delegates to an optional LLM client exposing an awaitable
    ``generate(prompt)``.
    """

    # Compiled once at class creation instead of on every call.
    _CAPITALIZED_RE = re.compile(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b')
    _EMAIL_RE = re.compile(r'\b[\w.-]+@[\w.-]+\.\w+\b')
    _URL_RE = re.compile(r'https?://\S+')
    # Leading \b stops matches that start inside another word, e.g. the
    # "i am" hidden in "miami am".
    _NAME_RES = (
        re.compile(r"\bmy name is (\w+)"),
        re.compile(r"\bi'm (\w+)"),
        re.compile(r"\bi am (\w+)"),
        re.compile(r"\bcall me (\w+)"),
    )

    def __init__(self, llm_client: Any = None):
        self.llm = llm_client

    def extract_entities(self, text: str) -> list[str]:
        """Extract candidate entities from text with simple patterns.

        Collects runs of capitalized words, e-mail addresses and http(s)
        URLs.

        Returns:
            Unique matches in order of first appearance within each pattern
            group (capitalized, then e-mails, then URLs). The order is
            deterministic, unlike a ``set`` round-trip.
        """
        found = [
            *self._CAPITALIZED_RE.findall(text),
            *self._EMAIL_RE.findall(text),
            *self._URL_RE.findall(text),
        ]
        return list(dict.fromkeys(found))

    def extract_user_info(self, messages: list[Message]) -> dict:
        """Extract user information from user-authored messages.

        Returns:
            A dict that may contain ``"name"`` (capitalized; a later message
            overrides an earlier one) and ``"preferences"`` (the original
            text of every user message containing "prefer" or "like").
        """
        info = {}
        for message in messages:
            if message.role != MessageRole.USER:
                continue
            content = message.content.lower()
            # First matching pattern wins for this message.
            for pattern in self._NAME_RES:
                match = pattern.search(content)
                if match:
                    info["name"] = match.group(1).capitalize()
                    break
            # Substring test only: "like" also matches words such as "unlikely".
            if "prefer" in content or "like" in content:
                info.setdefault("preferences", []).append(message.content)
        return info

    async def extract_with_llm(
        self,
        messages: list[Message]
    ) -> ConversationContext:
        """Use the LLM to extract context from the last 10 messages.

        Extraction is best-effort. An empty ``ConversationContext`` is
        returned when no LLM client is configured, when the response is not
        valid JSON, or when the JSON is not an object. Errors raised by the
        LLM call itself propagate.
        """
        if not self.llm:
            return ConversationContext()
        conversation_text = "\n".join(
            f"{m.role.value}: {m.content}"
            for m in messages[-10:]  # Last 10 messages
        )
        prompt = (
            "Extract key information from this conversation:\n"
            f"{conversation_text}\n"
            "Extract and return as JSON:\n"
            "- user_name: The user's name if mentioned\n"
            "- current_topic: The main topic being discussed\n"
            "- entities: List of important names, places, products mentioned\n"
            "- facts: List of facts learned about the user or situation\n"
            "- pending_tasks: Any tasks or requests that haven't been completed\n"
            "JSON:"
        )
        response = await self.llm.generate(prompt)
        try:
            data = json.loads(response)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError covers a
            # response that is not str/bytes.
            return ConversationContext()
        if not isinstance(data, dict):
            return ConversationContext()
        return ConversationContext(
            user_name=data.get("user_name"),
            current_topic=data.get("current_topic"),
            # "or []" also maps JSON null to an empty list.
            entities=data.get("entities") or [],
            facts=data.get("facts") or [],
            pending_tasks=data.get("pending_tasks") or [],
        )
class ConversationSummarizer:
    """Produce LLM-written summaries used to compress conversation memory.

    The LLM client must expose an awaitable ``generate(prompt)``.
    """

    def __init__(self, llm_client: Any):
        self.llm = llm_client

    @staticmethod
    def _transcript(messages) -> str:
        """Render messages as one ``role: content`` line each."""
        lines = []
        for m in messages:
            lines.append(f"{m.role.value}: {m.content}")
        return "\n".join(lines)

    async def summarize(
        self,
        messages: list[Message],
        max_length: int = 500
    ) -> str:
        """Ask the LLM for a summary of ``messages``.

        ``max_length`` is only stated in the prompt as a character budget;
        the returned text is not truncated here.
        """
        conversation_text = self._transcript(messages)
        prompt = (
            f"Summarize this conversation in {max_length} characters or less.\n"
            "Focus on key decisions, facts learned, and any pending items.\n"
            "Conversation:\n"
            f"{conversation_text}\n"
            "Summary:"
        )
        summary = await self.llm.generate(prompt)
        return summary

    async def create_rolling_summary(
        self,
        previous_summary: str,
        new_messages: list[Message]
    ) -> str:
        """Ask the LLM to fold ``new_messages`` into ``previous_summary``."""
        new_text = self._transcript(new_messages)
        prompt = (
            "Update this conversation summary with new messages.\n"
            "Previous summary:\n"
            f"{previous_summary}\n"
            "New messages:\n"
            f"{new_text}\n"
            "Updated summary (keep under 500 characters):"
        )
        updated = await self.llm.generate(prompt)
        return updated
Memory Management
from dataclasses import dataclass
from typing import Any, Optional, List
from datetime import datetime, timedelta
@dataclass
class MemoryConfig:
    """Tunable limits for conversation memory handling."""

    # Upper bound on the number of stored messages.
    max_messages: int = 50
    # Token budget.
    max_tokens: int = 4000
    # Message count above which summarization is triggered.
    summarize_after: int = 20
    # Number of most recent messages kept verbatim in the context.
    context_window: int = 10
class ConversationMemory:
    """Combine a running summary with a window of recent messages.

    The summary is a single string held on the instance; it is not keyed by
    conversation id.
    """

    def __init__(
        self,
        config: MemoryConfig,
        summarizer: ConversationSummarizer = None
    ):
        self.config = config
        self.summarizer = summarizer
        self.summary: str = ""

    async def prepare_context(
        self,
        state: ConversationState
    ) -> list[dict]:
        """Build the message list for an LLM call.

        Returns:
            An optional system message carrying the summary, followed by the
            last ``config.context_window`` messages as ``role``/``content``
            dicts.
        """
        messages = []
        # Add summary if available
        if self.summary:
            messages.append({
                "role": "system",
                "content": f"Previous conversation summary: {self.summary}"
            })
        # Add recent messages
        recent = state.messages[-self.config.context_window:]
        messages.extend(
            {"role": msg.role.value, "content": msg.content} for msg in recent
        )
        return messages

    async def update_memory(self, state: ConversationState):
        """Refresh the summary once the conversation is long enough.

        Does nothing without a summarizer or while the conversation has at
        most ``config.summarize_after`` messages. The first summary covers
        every message outside the context window. Later calls fold the last
        five of those messages into the existing summary.
        """
        if len(state.messages) <= self.config.summarize_after:
            return
        if not self.summarizer:
            return
        # Summarize older messages
        old_messages = state.messages[:-self.config.context_window]
        if self.summary:
            self.summary = await self.summarizer.create_rolling_summary(
                self.summary,
                old_messages[-5:]  # Just the new old messages
            )
        else:
            self.summary = await self.summarizer.summarize(old_messages)

    def trim_messages(self, state: ConversationState) -> ConversationState:
        """Trim ``state.messages`` in place to fit ``config.max_messages``.

        All system messages are kept and moved to the front; the remaining
        budget is filled with the most recent non-system messages.

        When the system messages alone meet or exceed the limit, no other
        message is kept. Previously the computed slice start was ``-0`` (or a
        positive index) in that case, which kept all or an arbitrary tail of
        the other messages instead of trimming.

        Returns:
            The same ``state`` object.
        """
        if len(state.messages) <= self.config.max_messages:
            return state
        # Keep system messages and recent messages
        system_messages = [m for m in state.messages if m.role == MessageRole.SYSTEM]
        other_messages = [m for m in state.messages if m.role != MessageRole.SYSTEM]
        budget = self.config.max_messages - len(system_messages)
        # Guard budget <= 0 explicitly: other_messages[-0:] is the whole list.
        kept = other_messages[-budget:] if budget > 0 else []
        state.messages = system_messages + kept
        return state
class WindowedMemory:
    """Memory strategy that exposes only the newest ``window_size`` messages."""

    def __init__(self, window_size: int = 10):
        self.window_size = window_size

    def get_context(self, state: ConversationState) -> list[dict]:
        """Return the tail of the conversation as ``role``/``content`` dicts."""
        window = []
        for turn in state.messages[-self.window_size:]:
            window.append({"role": turn.role.value, "content": turn.content})
        return window
class SummaryMemory:
    """Memory strategy pairing a cached per-conversation summary with recent turns."""

    def __init__(self, summarizer: ConversationSummarizer):
        self.summarizer = summarizer
        # conversation_id -> summary text
        self.summaries: dict[str, str] = {}

    async def get_context(
        self,
        state: ConversationState,
        recent_count: int = 5
    ) -> list[dict]:
        """Return an optional summary system message plus the latest turns.

        A summary is generated only when none is cached for the conversation
        and it has more than ``recent_count`` messages; it then covers
        everything before the last ``recent_count`` messages.
        """
        cid = state.conversation_id
        if cid not in self.summaries and len(state.messages) > recent_count:
            older = state.messages[:-recent_count]
            self.summaries[cid] = await self.summarizer.summarize(older)

        context = []
        cached = self.summaries.get(cid)
        if cached:
            context.append(
                {"role": "system", "content": f"Conversation history: {cached}"}
            )
        context.extend(
            {"role": turn.role.value, "content": turn.content}
            for turn in state.messages[-recent_count:]
        )
        return context

    async def update(self, state: ConversationState, recent_count: int = 5):
        """Fold the messages just outside the recent window into the summary.

        Only the (up to) three messages immediately preceding the last
        ``recent_count`` are sent to the summarizer, together with the cached
        summary or an empty string.
        """
        if len(state.messages) > recent_count:
            cid = state.conversation_id
            previous = self.summaries.get(cid, "")
            aged_out = state.messages[-(recent_count + 3):-recent_count]
            if aged_out:
                refreshed = await self.summarizer.create_rolling_summary(
                    previous,
                    aged_out
                )
                self.summaries[cid] = refreshed
class EntityMemory:
    """Track which entities a conversation mentions, and how often."""

    def __init__(self, extractor: ContextExtractor):
        self.extractor = extractor
        self.entities: dict[str, dict] = {}  # conversation_id -> entity info
        # conversation_id -> ids of the messages scanned by the previous
        # update(), so overlapping windows are not counted twice.
        self._scanned: dict[str, set] = {}

    async def update(self, state: ConversationState):
        """Record entities found in the last three messages.

        Each message contributes to the counts once. Messages already scanned
        by the previous call are skipped; otherwise calling ``update`` after
        every new message would count each message up to three times.

        For every entity the record holds ``first_mentioned``,
        ``last_mentioned`` (message timestamps) and ``mentions``.
        """
        cid = state.conversation_id
        tracked = self.entities.setdefault(cid, {})
        already_scanned = self._scanned.get(cid, set())
        window = state.messages[-3:]
        # Extract from recent messages
        for msg in window:
            if msg.message_id in already_scanned:
                continue
            for entity in self.extractor.extract_entities(msg.content):
                record = tracked.setdefault(
                    entity,
                    {"first_mentioned": msg.timestamp, "mentions": 0},
                )
                record["mentions"] += 1
                record["last_mentioned"] = msg.timestamp
        # Only the current window can overlap with the next call, so keeping
        # just these ids bounds the bookkeeping.
        self._scanned[cid] = {msg.message_id for msg in window}

    def get_relevant_entities(
        self,
        state: ConversationState,
        limit: int = 10
    ) -> list[str]:
        """Return up to ``limit`` entity names, most-mentioned first.

        Returns an empty list for a conversation that was never updated.
        """
        cid = state.conversation_id
        if cid not in self.entities:
            return []
        # Sort by mention count
        ranked = sorted(
            self.entities[cid].items(),
            key=lambda item: item[1]["mentions"],
            reverse=True,
        )
        return [name for name, _ in ranked[:limit]]
Production State Service
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List, Dict
# FastAPI application instance; the route decorators in this module register
# their handlers on it.
app = FastAPI()
# Request body for POST /v1/conversations.
class CreateConversationRequest(BaseModel):
    # Owner of the conversation; stored under the "user_id" metadata key.
    user_id: Optional[str] = None
    # Extra key/value pairs merged into the conversation metadata.
    metadata: Optional[Dict] = None
# Request body for POST /v1/conversations/{conversation_id}/messages.
class AddMessageRequest(BaseModel):
    # Sender role as a string; the handler converts it with MessageRole(...).
    role: str
    # Message text.
    content: str
    # Optional per-message metadata.
    metadata: Optional[Dict] = None
# Request body for POST /v1/conversations/{conversation_id}/context.
class GetContextRequest(BaseModel):
    # How many of the most recent messages to return.
    max_messages: int = 10
    # NOTE(review): not read by the get_context handler in this file.
    include_summary: bool = True
# Initialize components
# Module-level store used by every route handler below. InMemoryStateStore
# keeps conversations in a dict inside this process.
store = InMemoryStateStore()
# Default memory limits.
# NOTE(review): no handler visible in this file reads memory_config.
memory_config = MemoryConfig()
@app.post("/v1/conversations")
async def create_conversation(request: CreateConversationRequest) -> dict:
    """Create a new conversation."""
    # "user_id" goes in first; a "user_id" key inside request.metadata
    # therefore overrides it.
    metadata = {"user_id": request.user_id}
    metadata.update(request.metadata or {})

    new_id = str(uuid.uuid4())
    state = ConversationState(conversation_id=new_id, metadata=metadata)
    await store.save(state)

    created = state.created_at.isoformat()
    return {"conversation_id": new_id, "created_at": created}
@app.get("/v1/conversations/{conversation_id}")
async def get_conversation(conversation_id: str) -> dict:
    """Get conversation state."""
    state = await store.load(conversation_id)
    if state:
        # Full serialized state: messages, context, timestamps, metadata.
        return state.to_dict()
    raise HTTPException(status_code=404, detail="Conversation not found")
@app.post("/v1/conversations/{conversation_id}/messages")
async def add_message(
    conversation_id: str,
    request: AddMessageRequest
) -> dict:
    """Add a message to conversation."""
    # Responds 404 for an unknown conversation and 422 for a role string that
    # is not a MessageRole value.
    state = await store.load(conversation_id)
    if not state:
        raise HTTPException(status_code=404, detail="Conversation not found")
    try:
        role = MessageRole(request.role)
    except ValueError:
        # Without this the ValueError escapes the handler and the client sees
        # a 500 for what is a malformed request.
        allowed = ", ".join(r.value for r in MessageRole)
        raise HTTPException(
            status_code=422,
            detail=f"Invalid role {request.role!r}; expected one of: {allowed}",
        ) from None
    message = state.add_message(
        role=role,
        content=request.content,
        metadata=request.metadata
    )
    await store.save(state)
    return {
        "message_id": message.message_id,
        "timestamp": message.timestamp.isoformat()
    }
@app.post("/v1/conversations/{conversation_id}/context")
async def get_context(
    conversation_id: str,
    request: GetContextRequest
) -> dict:
    """Get context for LLM call."""
    state = await store.load(conversation_id)
    if not state:
        raise HTTPException(status_code=404, detail="Conversation not found")

    ctx = state.context
    # Only the first ten entities are exposed.
    context_view = {
        "user_name": ctx.user_name,
        "current_topic": ctx.current_topic,
        "entities": ctx.entities[:10],
    }
    recent = state.get_messages_for_llm(max_messages=request.max_messages)
    # The token estimate covers the whole conversation, not just `recent`.
    token_estimate = state.get_token_count()
    return {
        "messages": recent,
        "context": context_view,
        "token_estimate": token_estimate,
    }
@app.get("/v1/conversations/{conversation_id}/messages")
async def list_messages(
    conversation_id: str,
    limit: int = 50,
    offset: int = 0
) -> list[dict]:
    """List messages in conversation."""
    # Negative paging values are rejected with 422: as slice bounds they count
    # from the end of the list and would return an arbitrary page.
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=422,
            detail="limit and offset must be non-negative",
        )
    state = await store.load(conversation_id)
    if not state:
        raise HTTPException(status_code=404, detail="Conversation not found")
    page = state.messages[offset:offset + limit]
    return [m.to_dict() for m in page]
@app.delete("/v1/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str) -> dict:
    """Delete a conversation."""
    await store.delete(conversation_id)
    # The same body is returned whether or not the conversation existed.
    result = {"deleted": True}
    return result
@app.get("/v1/users/{user_id}/conversations")
async def list_user_conversations(user_id: str) -> list[str]:
    """List conversations for a user."""
    # Filtering by owner is delegated to the store.
    conversation_ids = await store.list_conversations(user_id=user_id)
    return conversation_ids
@app.get("/health")
async def health():
    # Static health response; it does not check the store.
    status_payload = {"status": "healthy"}
    return status_payload
References
- LangChain Memory: https://python.langchain.com/docs/modules/memory/
- OpenAI Assistants: https://platform.openai.com/docs/assistants/overview
- Redis: https://redis.io/docs/
- PostgreSQL: https://www.postgresql.org/docs/
- MemGPT: https://memgpt.ai/
Conclusion
Conversation state management is essential for building AI assistants that feel coherent and contextually aware. The key is choosing the right storage backend for your scale and requirements: in-memory for development and small deployments, Redis for high-performance caching with persistence, and PostgreSQL for durable storage with complex queries. Memory management strategies depend on your context window constraints: sliding windows are simple but lose long-term context; summarization preserves key information but requires LLM calls; and entity-based memory tracks important facts efficiently. For production systems, combine these approaches: use sliding windows for recent context, periodic summarization for older history, and entity extraction for persistent facts. Context extraction should happen asynchronously to avoid adding latency to user interactions. Design your state models to be serializable, and version them for backward compatibility as your application evolves. Monitor memory usage and token counts to stay within model limits. With proper state management, your AI assistant can maintain coherent conversations across sessions, remember user preferences, and provide contextually relevant responses that make users feel understood.
Discover more from Code, Cloud & Context
Subscribe to get the latest posts sent to your email.