Semantic Caching for LLMs: Embedding-Based Similarity and Cache Strategies

Introduction: LLM API calls are expensive and slow—semantic caching reduces both by reusing responses for similar queries. Unlike exact-match caching, semantic caching uses embeddings to find queries that are semantically similar, even if worded differently. This enables cache hits for paraphrased questions, reducing latency from seconds to milliseconds and cutting API costs significantly. This guide covers practical semantic caching: embedding-based similarity search, cache key strategies, TTL and invalidation, hybrid caching approaches, and production considerations for cache consistency.
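
The pattern throughout this guide is cache-aside: check the semantic cache first, and on a miss call the model and write the response back so similar future queries can be served from the cache. A minimal sketch of that flow, assuming an OpenAI client and a cache object with the get/set interface built in the sections below:

# Cache-aside flow (sketch): check the cache, fall back to the LLM on a miss.
# Assumes an OpenAI client and a cache with get/set like the SemanticCache below.
from openai import OpenAI

client = OpenAI()

def answer(query: str, cache) -> str:
    cached = cache.get(query)      # embedding-based similarity lookup
    if cached is not None:
        return cached              # hit: no LLM call, millisecond latency

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": query}],
    )
    result = response.choices[0].message.content
    cache.set(query, result)       # miss: store so similar queries can hit later
    return result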

Semantic Caching

A semantic cache works in three steps: embed the incoming query, run a similarity search against the embeddings of cached queries, and decide hit or miss by comparing the best match against a similarity threshold.

Basic Semantic Cache

from dataclasses import dataclass, field
from typing import Optional
from datetime import datetime, timedelta
import numpy as np

@dataclass
class CacheEntry:
    """A cached query-response pair."""
    
    query: str
    response: str
    embedding: list[float]
    created_at: datetime = field(default_factory=datetime.now)
    ttl_seconds: int = 3600
    hit_count: int = 0
    metadata: dict = field(default_factory=dict)
    
    @property
    def is_expired(self) -> bool:
        """Check if entry has expired."""
        
        age = datetime.now() - self.created_at
        return age.total_seconds() > self.ttl_seconds

class SemanticCache:
    """Cache LLM responses using semantic similarity."""
    
    def __init__(
        self,
        embedding_client,
        similarity_threshold: float = 0.95,
        max_entries: int = 10000,
        default_ttl: int = 3600
    ):
        self.embedding_client = embedding_client
        self.threshold = similarity_threshold
        self.max_entries = max_entries
        self.default_ttl = default_ttl
        
        self.entries: list[CacheEntry] = []
        self.stats = {"hits": 0, "misses": 0}
    
    def _embed(self, text: str) -> list[float]:
        """Get embedding for text."""
        
        response = self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
    
    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Calculate cosine similarity."""
        
        a_arr = np.array(a)
        b_arr = np.array(b)
        
        return np.dot(a_arr, b_arr) / (
            np.linalg.norm(a_arr) * np.linalg.norm(b_arr)
        )
    
    def get(self, query: str) -> Optional[str]:
        """Get cached response for query."""
        
        query_embedding = self._embed(query)
        
        best_match = None
        best_similarity = 0.0
        
        for entry in self.entries:
            if entry.is_expired:
                continue
            
            similarity = self._cosine_similarity(
                query_embedding,
                entry.embedding
            )
            
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = entry
        
        if best_match and best_similarity >= self.threshold:
            best_match.hit_count += 1
            self.stats["hits"] += 1
            return best_match.response
        
        self.stats["misses"] += 1
        return None
    
    def set(
        self,
        query: str,
        response: str,
        ttl: int = None,
        metadata: dict = None
    ):
        """Cache a query-response pair."""
        
        # Evict expired entries
        self._evict_expired()
        
        # Evict oldest if at capacity
        if len(self.entries) >= self.max_entries:
            self._evict_lru()
        
        embedding = self._embed(query)
        
        entry = CacheEntry(
            query=query,
            response=response,
            embedding=embedding,
            ttl_seconds=ttl or self.default_ttl,
            metadata=metadata or {}
        )
        
        self.entries.append(entry)
    
    def _evict_expired(self):
        """Remove expired entries."""
        
        self.entries = [e for e in self.entries if not e.is_expired]
    
    def _evict_lru(self):
        """Evict the least valuable entries: fewest hits first, oldest first among ties."""
        
        # Sort so entries with the fewest hits come first; among equal
        # hit counts, older entries sort first and are evicted first.
        self.entries.sort(
            key=lambda e: (e.hit_count, e.created_at.timestamp())
        )
        
        # Remove the bottom 10% (at least one entry)
        remove_count = max(1, len(self.entries) // 10)
        self.entries = self.entries[remove_count:]
    
    def invalidate(self, query: str):
        """Invalidate cache entries similar to query."""
        
        query_embedding = self._embed(query)
        
        self.entries = [
            e for e in self.entries
            if self._cosine_similarity(query_embedding, e.embedding) < self.threshold
        ]
    
    def clear(self):
        """Clear all cache entries."""
        
        self.entries = []
    
    def get_stats(self) -> dict:
        """Get cache statistics."""
        
        total = self.stats["hits"] + self.stats["misses"]
        hit_rate = self.stats["hits"] / total if total > 0 else 0
        
        return {
            "hits": self.stats["hits"],
            "misses": self.stats["misses"],
            "hit_rate": hit_rate,
            "entries": len(self.entries)
        }
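
A quick usage sketch (the queries are illustrative; whether a paraphrase hits depends on the embedding model and the threshold you choose):

# Usage sketch for SemanticCache (illustrative queries; assumes an OpenAI client).
from openai import OpenAI

client = OpenAI()
cache = SemanticCache(client, similarity_threshold=0.92)

cache.set("What is the capital of France?", "The capital of France is Paris.")

cache.get("What's the capital city of France?")   # paraphrase - likely a hit
cache.get("What is the capital of Japan?")        # different question - miss, returns None
print(cache.get_stats())                          # hits, misses, hit_rate, entries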

Vector Store Cache

from dataclasses import dataclass
from typing import Optional
import hashlib

import numpy as np

@dataclass
class VectorCacheEntry:
    """Cache entry with vector store backend."""
    
    id: str
    query: str
    response: str
    created_at: float
    ttl_seconds: int
    metadata: dict

class VectorStoreCache:
    """Semantic cache using vector database."""
    
    def __init__(
        self,
        embedding_client,
        vector_store,  # Pinecone, Weaviate, etc.
        similarity_threshold: float = 0.95,
        default_ttl: int = 3600
    ):
        self.embedding_client = embedding_client
        self.vector_store = vector_store
        self.threshold = similarity_threshold
        self.default_ttl = default_ttl
        
        self.stats = {"hits": 0, "misses": 0}
    
    def _embed(self, text: str) -> list[float]:
        """Get embedding for text."""
        
        response = self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
    
    def _generate_id(self, query: str) -> str:
        """Generate unique ID for query."""
        
        return hashlib.md5(query.encode()).hexdigest()
    
    def get(self, query: str) -> Optional[str]:
        """Get cached response."""
        
        import time
        
        query_embedding = self._embed(query)
        
        # Search vector store
        results = self.vector_store.query(
            vector=query_embedding,
            top_k=1,
            include_metadata=True
        )
        
        if results and len(results.matches) > 0:
            match = results.matches[0]
            
            if match.score >= self.threshold:
                metadata = match.metadata
                
                # Check TTL
                created_at = metadata.get("created_at", 0)
                ttl = metadata.get("ttl_seconds", self.default_ttl)
                
                if time.time() - created_at < ttl:
                    self.stats["hits"] += 1
                    return metadata.get("response")
                else:
                    # Expired - delete
                    self.vector_store.delete(ids=[match.id])
        
        self.stats["misses"] += 1
        return None
    
    def set(
        self,
        query: str,
        response: str,
        ttl: int = None,
        metadata: dict = None
    ):
        """Cache query-response pair."""
        
        import time
        
        embedding = self._embed(query)
        entry_id = self._generate_id(query)
        
        cache_metadata = {
            "query": query,
            "response": response,
            "created_at": time.time(),
            "ttl_seconds": ttl or self.default_ttl,
            **(metadata or {})
        }
        
        self.vector_store.upsert(
            vectors=[{
                "id": entry_id,
                "values": embedding,
                "metadata": cache_metadata
            }]
        )
    
    def invalidate_by_prefix(self, prefix: str):
        """Invalidate entries matching prefix."""
        
        # Query for matching entries
        # Implementation depends on vector store capabilities
        pass

# Redis-backed semantic cache
class RedisSemanticCache:
    """Semantic cache with Redis backend."""
    
    def __init__(
        self,
        embedding_client,
        redis_client,
        similarity_threshold: float = 0.95,
        default_ttl: int = 3600,
        index_name: str = "llm_cache"
    ):
        self.embedding_client = embedding_client
        self.redis = redis_client
        self.threshold = similarity_threshold
        self.default_ttl = default_ttl
        self.index_name = index_name
        
        self._ensure_index()
    
    def _ensure_index(self):
        """Create Redis vector index if not exists."""
        
        from redis.commands.search.field import VectorField, TextField
        from redis.commands.search.indexDefinition import IndexDefinition, IndexType
        
        try:
            self.redis.ft(self.index_name).info()
        except Exception:
            # Index does not exist yet - create it
            schema = [
                TextField("query"),
                TextField("response"),
                VectorField(
                    "embedding",
                    "FLAT",
                    {
                        "TYPE": "FLOAT32",
                        "DIM": 1536,
                        "DISTANCE_METRIC": "COSINE"
                    }
                )
            ]
            
            self.redis.ft(self.index_name).create_index(
                schema,
                definition=IndexDefinition(
                    prefix=["cache:"],
                    index_type=IndexType.HASH
                )
            )
    
    def _embed(self, text: str) -> list[float]:
        """Get embedding."""
        
        response = self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
    
    def get(self, query: str) -> Optional[str]:
        """Get cached response."""
        
        from redis.commands.search.query import Query
        
        embedding = self._embed(query)
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        
        q = (
            Query(f"*=>[KNN 1 @embedding $vec AS score]")
            .return_fields("query", "response", "score")
            .dialect(2)
        )
        
        results = self.redis.ft(self.index_name).search(
            q,
            query_params={"vec": embedding_bytes}
        )
        
        if results.docs:
            doc = results.docs[0]
            similarity = 1 - float(doc.score)  # Convert distance to similarity
            
            if similarity >= self.threshold:
                return doc.response
        
        return None
    
    def set(self, query: str, response: str, ttl: int = None):
        """Cache query-response."""
        
        import uuid
        
        embedding = self._embed(query)
        embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
        
        key = f"cache:{uuid.uuid4().hex}"
        
        self.redis.hset(key, mapping={
            "query": query,
            "response": response,
            "embedding": embedding_bytes
        })
        
        self.redis.expire(key, ttl or self.default_ttl)
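
This backend needs a Redis instance with the RediSearch module (for example Redis Stack) and the redis Python package. A minimal wiring sketch, with illustrative host, threshold, and queries:

# Usage sketch for RedisSemanticCache (assumes Redis Stack running locally
# so the RediSearch vector index commands are available).
import redis
from openai import OpenAI

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
client = OpenAI()

redis_cache = RedisSemanticCache(
    embedding_client=client,
    redis_client=r,
    similarity_threshold=0.93,
    default_ttl=1800,   # Redis expires entries after 30 minutes
)

redis_cache.set("How do I reset my password?", "Go to Settings > Security > Reset password.")
redis_cache.get("How can I change my password?")   # may hit, depending on the threshold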

Cache Key Strategies

from dataclasses import dataclass
from typing import Optional
import hashlib
import json

@dataclass
class CacheKey:
    """Structured cache key."""
    
    query: str
    model: str
    temperature: float
    system_prompt_hash: str
    context_hash: str
    
    def to_string(self) -> str:
        """Convert to string key."""
        
        return f"{self.model}:{self.temperature}:{self.system_prompt_hash}:{self.context_hash}:{self.query}"

class CacheKeyBuilder:
    """Build cache keys with context awareness."""
    
    def __init__(self, include_model: bool = True):
        self.include_model = include_model
    
    def build(
        self,
        query: str,
        model: str = None,
        temperature: float = None,
        system_prompt: str = None,
        context: list[dict] = None
    ) -> CacheKey:
        """Build cache key from components."""
        
        # Hash system prompt
        system_hash = ""
        if system_prompt:
            system_hash = hashlib.md5(system_prompt.encode()).hexdigest()[:8]
        
        # Hash context
        context_hash = ""
        if context:
            context_str = json.dumps(context, sort_keys=True)
            context_hash = hashlib.md5(context_str.encode()).hexdigest()[:8]
        
        return CacheKey(
            query=query,
            model=model or "default",
            temperature=temperature or 0.0,
            system_prompt_hash=system_hash,
            context_hash=context_hash
        )
    
    def normalize_query(self, query: str) -> str:
        """Normalize query for better cache hits."""
        
        # Lowercase
        normalized = query.lower()
        
        # Remove extra whitespace
        normalized = ' '.join(normalized.split())
        
        # Remove punctuation at end
        normalized = normalized.rstrip('?!.')
        
        return normalized

class ContextAwareCache:
    """Cache that considers conversation context."""
    
    def __init__(
        self,
        embedding_client,
        similarity_threshold: float = 0.95
    ):
        self.embedding_client = embedding_client
        self.threshold = similarity_threshold
        self.key_builder = CacheKeyBuilder()
        
        # Separate caches for different contexts
        self.caches: dict[str, SemanticCache] = {}
    
    def _get_context_key(self, context: list[dict]) -> str:
        """Get key for context."""
        
        if not context:
            return "no_context"
        
        # Use last few messages as context key
        recent = context[-3:] if len(context) > 3 else context
        context_str = json.dumps(recent, sort_keys=True)
        return hashlib.md5(context_str.encode()).hexdigest()[:16]
    
    def get(
        self,
        query: str,
        context: list[dict] = None
    ) -> Optional[str]:
        """Get cached response considering context."""
        
        context_key = self._get_context_key(context)
        
        if context_key not in self.caches:
            return None
        
        return self.caches[context_key].get(query)
    
    def set(
        self,
        query: str,
        response: str,
        context: list[dict] = None,
        ttl: int = None
    ):
        """Cache response with context."""
        
        context_key = self._get_context_key(context)
        
        if context_key not in self.caches:
            self.caches[context_key] = SemanticCache(
                self.embedding_client,
                similarity_threshold=self.threshold
            )
        
        self.caches[context_key].set(query, response, ttl)

# Tiered caching
class TieredCache:
    """Multi-tier cache with exact and semantic matching."""
    
    def __init__(self, embedding_client):
        self.embedding_client = embedding_client
        
        # Tier 1: Exact match on normalized text (fast; note this tier has
        # no TTL or size bound in this example)
        self.exact_cache: dict[str, str] = {}
        
        # Tier 2: Semantic match (slower but more flexible)
        self.semantic_cache = SemanticCache(
            embedding_client,
            similarity_threshold=0.95
        )
        
        self.stats = {
            "exact_hits": 0,
            "semantic_hits": 0,
            "misses": 0
        }
    
    def _normalize(self, query: str) -> str:
        """Normalize for exact matching."""
        
        return ' '.join(query.lower().split())
    
    def get(self, query: str) -> Optional[str]:
        """Get from cache, trying exact first."""
        
        normalized = self._normalize(query)
        
        # Try exact match first
        if normalized in self.exact_cache:
            self.stats["exact_hits"] += 1
            return self.exact_cache[normalized]
        
        # Try semantic match
        result = self.semantic_cache.get(query)
        
        if result:
            self.stats["semantic_hits"] += 1
            return result
        
        self.stats["misses"] += 1
        return None
    
    def set(self, query: str, response: str, ttl: int = None):
        """Set in both caches."""
        
        normalized = self._normalize(query)
        
        # Store in exact cache
        self.exact_cache[normalized] = response
        
        # Store in semantic cache
        self.semantic_cache.set(query, response, ttl)
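
Exact repeats are answered from the dictionary without an embedding call; paraphrases fall through to the semantic tier. A short sketch with illustrative queries:

# Usage sketch for TieredCache (illustrative queries; assumes an OpenAI client).
from openai import OpenAI

tiered = TieredCache(OpenAI())
tiered.set("What is a semantic cache?", "A cache keyed by embedding similarity, not exact text.")

tiered.get("what is a semantic cache?")   # exact hit after normalization (case/whitespace)
tiered.get("Explain semantic caching")    # may hit the semantic tier, depending on threshold
print(tiered.stats)                       # exact_hits, semantic_hits, misses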

Cache Invalidation

from dataclasses import dataclass
from typing import Callable, Optional
from datetime import datetime

@dataclass
class InvalidationRule:
    """Rule for cache invalidation."""
    
    name: str
    condition: Callable[[CacheEntry], bool]
    priority: int = 0

class CacheInvalidator:
    """Manage cache invalidation."""
    
    def __init__(self, cache: SemanticCache):
        self.cache = cache
        self.rules: list[InvalidationRule] = []
    
    def add_rule(self, rule: InvalidationRule):
        """Add invalidation rule."""
        
        self.rules.append(rule)
        self.rules.sort(key=lambda r: r.priority, reverse=True)
    
    def invalidate_by_rules(self):
        """Apply all invalidation rules."""
        
        for rule in self.rules:
            self.cache.entries = [
                e for e in self.cache.entries
                if not rule.condition(e)
            ]
    
    def invalidate_by_age(self, max_age_seconds: int):
        """Invalidate entries older than max age."""
        
        now = datetime.now()
        
        self.cache.entries = [
            e for e in self.cache.entries
            if (now - e.created_at).total_seconds() < max_age_seconds
        ]
    
    def invalidate_by_metadata(self, key: str, value):
        """Invalidate entries with matching metadata."""
        
        self.cache.entries = [
            e for e in self.cache.entries
            if e.metadata.get(key) != value
        ]
    
    def invalidate_similar(self, query: str, threshold: float = 0.9):
        """Invalidate entries at or above the given similarity to the query."""
        
        query_embedding = self.cache._embed(query)
        self.cache.entries = [
            e for e in self.cache.entries
            if self.cache._cosine_similarity(query_embedding, e.embedding) < threshold
        ]

# Event-based invalidation
class EventDrivenCache:
    """Cache with event-based invalidation."""
    
    def __init__(self, embedding_client):
        self.cache = SemanticCache(embedding_client)
        self.invalidation_handlers: dict[str, list[Callable]] = {}
    
    def on_event(self, event_type: str, handler: Callable):
        """Register invalidation handler for event type."""
        
        if event_type not in self.invalidation_handlers:
            self.invalidation_handlers[event_type] = []
        
        self.invalidation_handlers[event_type].append(handler)
    
    def emit_event(self, event_type: str, data: dict = None):
        """Emit event to trigger invalidation."""
        
        handlers = self.invalidation_handlers.get(event_type, [])
        
        for handler in handlers:
            handler(self.cache, data or {})
    
    def get(self, query: str) -> Optional[str]:
        return self.cache.get(query)
    
    def set(
        self,
        query: str,
        response: str,
        tags: list[str] = None
    ):
        """Set with tags for targeted invalidation."""
        
        self.cache.set(
            query,
            response,
            metadata={"tags": tags or []}
        )

# Example invalidation handlers
def invalidate_by_tag(cache: SemanticCache, data: dict):
    """Invalidate entries with specific tag."""
    
    tag = data.get("tag")
    if tag:
        cache.entries = [
            e for e in cache.entries
            if tag not in e.metadata.get("tags", [])
        ]

def invalidate_all(cache: SemanticCache, data: dict):
    """Clear entire cache."""
    
    cache.clear()
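
Wiring the handlers together looks like this; the event names and tags are illustrative:

# Usage sketch for event-driven invalidation (event names and tags are illustrative).
from openai import OpenAI

event_cache = EventDrivenCache(OpenAI())
event_cache.on_event("kb_updated", invalidate_by_tag)
event_cache.on_event("full_refresh", invalidate_all)

event_cache.set(
    "What is our refund policy?",
    "Refunds are accepted within 30 days of purchase.",
    tags=["policy"]
)

# When the underlying content changes, drop only the affected entries.
event_cache.emit_event("kb_updated", {"tag": "policy"})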

Production Cache Service

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

# Initialize components
from openai import OpenAI
client = OpenAI()

# Use tiered cache for best performance
cache = TieredCache(client)

class CacheRequest(BaseModel):
    query: str
    context: Optional[list[dict]] = None

class CacheSetRequest(BaseModel):
    query: str
    response: str
    ttl: Optional[int] = 3600
    tags: Optional[list[str]] = None

class LLMRequest(BaseModel):
    query: str
    system_prompt: str = "You are a helpful assistant."
    context: Optional[list[dict]] = None
    use_cache: bool = True

@app.post("/v1/cache/get")
async def cache_get(request: CacheRequest):
    """Get from cache."""
    
    result = cache.get(request.query)
    
    return {
        "hit": result is not None,
        "response": result
    }

@app.post("/v1/cache/set")
async def cache_set(request: CacheSetRequest):
    """Set cache entry."""
    
    cache.set(
        request.query,
        request.response,
        request.ttl
    )
    
    return {"cached": True}

@app.post("/v1/chat")
async def chat_with_cache(request: LLMRequest):
    """Chat with semantic caching."""
    
    import time
    start = time.time()
    
    # Check cache first (a semantic lookup still embeds the query,
    # so it is fast but not free)
    if request.use_cache:
        cached = cache.get(request.query)
        
        if cached:
            return {
                "response": cached,
                "cached": True,
                "latency_ms": (time.time() - start) * 1000
            }
    
    # Cache miss - call the LLM
    
    messages = [{"role": "system", "content": request.system_prompt}]
    
    if request.context:
        messages.extend(request.context)
    
    messages.append({"role": "user", "content": request.query})
    
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages
    )
    
    latency = (time.time() - start) * 1000
    result = response.choices[0].message.content
    
    # Cache the response
    if request.use_cache:
        cache.set(request.query, result)
    
    return {
        "response": result,
        "cached": False,
        "latency_ms": latency
    }

@app.delete("/v1/cache")
async def clear_cache():
    """Clear all cache entries."""
    
    cache.exact_cache.clear()
    cache.semantic_cache.clear()
    
    return {"cleared": True}

@app.get("/v1/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    
    return {
        "exact_hits": cache.stats["exact_hits"],
        "semantic_hits": cache.stats["semantic_hits"],
        "misses": cache.stats["misses"],
        "exact_entries": len(cache.exact_cache),
        "semantic_entries": len(cache.semantic_cache.entries)
    }

@app.get("/health")
async def health():
    return {"status": "healthy"}

Conclusion

Semantic caching transforms LLM economics—reducing both latency and cost for repeated or similar queries. Start with a basic embedding-based cache using cosine similarity. Choose appropriate similarity thresholds: higher (0.95+) for factual queries where precision matters, lower (0.85-0.90) for conversational queries where slight variations are acceptable. Implement tiered caching with exact match as the fast path and semantic matching as fallback. Consider context in your cache keys—the same question may have different answers in different conversation contexts. Use vector databases for production scale. Implement proper invalidation strategies based on TTL, events, or content updates. Monitor hit rates and adjust thresholds based on your specific use case. The goal is maximizing cache hits while maintaining response quality.

