Semantic Caching: Reducing LLM Costs with Meaning-Based Query Matching

Introduction: LLM API calls are expensive and slow. When users ask similar questions, you’re paying for the same computation repeatedly. Traditional caching doesn’t help because queries are rarely identical: “What’s the weather?” and “Tell me the weather” are different strings, yet they should return the same cached response. Semantic caching solves this by matching queries on meaning rather than exact text. This guide covers practical semantic caching strategies: embedding-based similarity matching, threshold tuning for cache hits, cache invalidation policies, and building caching layers that significantly reduce cost while maintaining response quality.

Semantic Caching
Figure: the semantic cache flow, from query embedding through similarity search to the hit/miss decision.
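
The core mechanism fits in a few lines before we build a full cache around it: embed both queries, compare them with cosine similarity, and treat anything above a threshold as a cache hit. A minimal sketch, assuming the official OpenAI Python SDK and an OPENAI_API_KEY in the environment; the model name and the 0.95 threshold are illustrative:

import numpy as np
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def embed(text: str) -> np.ndarray:
    """Return the embedding vector for a piece of text."""
    response = client.embeddings.create(model="text-embedding-3-small", input=text)
    return np.array(response.data[0].embedding)

a = embed("What's the weather?")
b = embed("Tell me the weather")

# Cosine similarity: 1.0 for identical direction, near 0 for unrelated text.
similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Paraphrases land far above anything exact-string matching could catch.
print(f"similarity={similarity:.3f}, cache hit={similarity >= 0.95}")

The rest of this guide wraps that decision in caches with TTLs, eviction, and threshold tuning.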

Embedding-Based Cache

from dataclasses import dataclass, field
from typing import Any, Optional
from datetime import datetime, timedelta
import hashlib
import numpy as np

@dataclass
class CacheEntry:
    """A cached query-response pair."""
    
    query: str
    response: str
    embedding: list[float]
    created_at: datetime = field(default_factory=datetime.utcnow)
    access_count: int = 0
    last_accessed: Optional[datetime] = None
    metadata: dict = field(default_factory=dict)
    
    @property
    def age_seconds(self) -> float:
        """Age of entry in seconds."""
        return (datetime.utcnow() - self.created_at).total_seconds()

class EmbeddingCache:
    """Semantic cache using embeddings."""
    
    def __init__(
        self,
        client: Any,
        model: str = "text-embedding-3-small",
        similarity_threshold: float = 0.95,
        max_entries: int = 10000,
        ttl_seconds: int = 3600
    ):
        self.client = client
        self.model = model
        self.similarity_threshold = similarity_threshold
        self.max_entries = max_entries
        self.ttl_seconds = ttl_seconds
        
        self._entries: list[CacheEntry] = []
    
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        
        response = await self.client.embeddings.create(
            model=self.model,
            input=text
        )
        return response.data[0].embedding
    
    def _compute_similarity(
        self,
        query_embedding: list[float],
        cache_embedding: list[float]
    ) -> float:
        """Compute cosine similarity."""
        
        q = np.array(query_embedding)
        c = np.array(cache_embedding)
        
        return float(np.dot(q, c) / (np.linalg.norm(q) * np.linalg.norm(c)))
    
    async def get(self, query: str) -> Optional[str]:
        """Get cached response for query."""
        
        if not self._entries:
            return None
        
        # Get query embedding
        query_embedding = await self.get_embedding(query)
        
        # Find best match
        best_match = None
        best_similarity = 0.0
        
        for entry in self._entries:
            # Skip expired entries
            if entry.age_seconds > self.ttl_seconds:
                continue
            
            similarity = self._compute_similarity(query_embedding, entry.embedding)
            
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = entry
        
        # Check threshold
        if best_match and best_similarity >= self.similarity_threshold:
            best_match.access_count += 1
            best_match.last_accessed = datetime.utcnow()
            return best_match.response
        
        return None
    
    async def set(
        self,
        query: str,
        response: str,
        embedding: Optional[list[float]] = None,
        **metadata
    ):
        """Cache a query-response pair."""
        
        # Get embedding if not provided
        if embedding is None:
            embedding = await self.get_embedding(query)
        
        entry = CacheEntry(
            query=query,
            response=response,
            embedding=embedding,
            metadata=metadata
        )
        
        self._entries.append(entry)
        
        # Evict if over capacity
        if len(self._entries) > self.max_entries:
            self._evict()
    
    def _evict(self):
        """Evict old or least-used entries."""
        
        # Remove expired entries
        self._entries = [
            e for e in self._entries
            if e.age_seconds <= self.ttl_seconds
        ]
        
        # If still over capacity, remove least recently used
        if len(self._entries) > self.max_entries:
            self._entries.sort(
                key=lambda e: (e.last_accessed or e.created_at),
                reverse=True
            )
            self._entries = self._entries[:self.max_entries]
    
    def clear(self):
        """Clear all cache entries."""
        self._entries = []
    
    @property
    def stats(self) -> dict:
        """Get cache statistics."""
        
        total_accesses = sum(e.access_count for e in self._entries)
        
        return {
            "entries": len(self._entries),
            "total_accesses": total_accesses,
            "avg_accesses": total_accesses / len(self._entries) if self._entries else 0
        }
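
A quick usage sketch for the in-memory cache above, assuming the async OpenAI SDK as the embedding client (any object with a compatible embeddings.create method works); the queries and threshold are illustrative:

import asyncio
from openai import AsyncOpenAI  # assumed embedding provider

async def main():
    cache = EmbeddingCache(client=AsyncOpenAI(), similarity_threshold=0.93)

    # Nothing cached yet, so the first lookup misses.
    assert await cache.get("What's the weather in Paris?") is None

    # Store the response once the LLM has produced it elsewhere.
    await cache.set("What's the weather in Paris?", "It is 18°C and sunny.")

    # A close paraphrase should now be served from the cache.
    print(await cache.get("Tell me the weather in Paris"))
    print(cache.stats)

asyncio.run(main())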

Vector Store Cache

from dataclasses import dataclass
from typing import Any, Optional
from datetime import datetime
import hashlib

import numpy as np

@dataclass
class CacheHit:
    """Result of a cache lookup."""
    
    hit: bool
    response: Optional[str] = None
    similarity: float = 0.0
    query_matched: Optional[str] = None

class VectorStoreCache:
    """Cache using vector database for scalability."""
    
    def __init__(
        self,
        embedding_client: Any,
        vector_store: Any,  # Pinecone, Weaviate, etc.
        embedding_model: str = "text-embedding-3-small",
        similarity_threshold: float = 0.95,
        namespace: str = "cache"
    ):
        self.embedding_client = embedding_client
        self.vector_store = vector_store
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        self.namespace = namespace
    
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        
        response = await self.embedding_client.embeddings.create(
            model=self.embedding_model,
            input=text
        )
        return response.data[0].embedding
    
    def _generate_id(self, query: str) -> str:
        """Generate unique ID for query."""
        return hashlib.sha256(query.encode()).hexdigest()[:16]
    
    async def get(self, query: str) -> CacheHit:
        """Look up query in cache."""
        
        # Get query embedding
        embedding = await self.get_embedding(query)
        
        # Search vector store
        results = await self.vector_store.query(
            vector=embedding,
            top_k=1,
            namespace=self.namespace,
            include_metadata=True
        )
        
        if not results.matches:
            return CacheHit(hit=False)
        
        best_match = results.matches[0]
        
        if best_match.score >= self.similarity_threshold:
            return CacheHit(
                hit=True,
                response=best_match.metadata.get("response"),
                similarity=best_match.score,
                query_matched=best_match.metadata.get("query")
            )
        
        return CacheHit(hit=False, similarity=best_match.score)
    
    async def set(
        self,
        query: str,
        response: str,
        ttl_seconds: int = 3600,
        **metadata
    ):
        """Store query-response in cache."""
        
        embedding = await self.get_embedding(query)
        
        cache_id = self._generate_id(query)
        
        await self.vector_store.upsert(
            vectors=[{
                "id": cache_id,
                "values": embedding,
                "metadata": {
                    "query": query,
                    "response": response,
                    "created_at": datetime.utcnow().isoformat(),
                    "ttl": ttl_seconds,
                    **metadata
                }
            }],
            namespace=self.namespace
        )
    
    async def delete(self, query: str):
        """Delete entry from cache."""
        
        cache_id = self._generate_id(query)
        
        await self.vector_store.delete(
            ids=[cache_id],
            namespace=self.namespace
        )
    
    async def clear_namespace(self):
        """Clear all entries in namespace."""
        
        await self.vector_store.delete(
            delete_all=True,
            namespace=self.namespace
        )
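
The class above only assumes a small surface from the vector store: awaitable query/upsert/delete methods, and matches exposing score and metadata. The in-memory stand-in below implements that surface, which is handy for tests before wiring up a real Pinecone or Weaviate index; this stub is illustrative, not any vendor's client:

from types import SimpleNamespace
import numpy as np

class InMemoryVectorStore:
    """Minimal async stand-in matching the interface VectorStoreCache expects."""

    def __init__(self):
        # namespace -> {id: {"id", "values", "metadata"}}
        self._vectors: dict[str, dict] = {}

    async def upsert(self, vectors: list[dict], namespace: str):
        self._vectors.setdefault(namespace, {}).update({v["id"]: v for v in vectors})

    async def query(self, vector, top_k, namespace, include_metadata=True):
        def cosine(a, b):
            a, b = np.array(a), np.array(b)
            return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

        matches = sorted(
            (
                SimpleNamespace(
                    id=v["id"],
                    score=cosine(vector, v["values"]),
                    metadata=v["metadata"],
                )
                for v in self._vectors.get(namespace, {}).values()
            ),
            key=lambda m: m.score,
            reverse=True,
        )
        return SimpleNamespace(matches=matches[:top_k])

    async def delete(self, ids=None, delete_all=False, namespace=""):
        if delete_all:
            self._vectors.pop(namespace, None)
        else:
            for id_ in ids or []:
                self._vectors.get(namespace, {}).pop(id_, None)

With an object like this in place, VectorStoreCache can be exercised end to end without any external service.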

from redis.commands.search.query import Query  # redis-py RediSearch query builder

class RedisSemanticCache:
    """Redis-based semantic cache using RediSearch vector search (Redis Stack)."""
    
    def __init__(
        self,
        redis_client: Any,
        embedding_client: Any,
        embedding_model: str = "text-embedding-3-small",
        similarity_threshold: float = 0.95,
        index_name: str = "semantic_cache"
    ):
        self.redis = redis_client
        self.embedding_client = embedding_client
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        self.index_name = index_name
    
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        
        response = await self.embedding_client.embeddings.create(
            model=self.embedding_model,
            input=text
        )
        return response.data[0].embedding
    
    async def get(self, query: str) -> CacheHit:
        """Look up query in cache."""
        
        embedding = await self.get_embedding(query)
        
        # KNN vector search against the RediSearch index. The KNN syntax requires
        # query dialect 2, and the vector bytes must match the index's FLOAT32 field.
        knn_query = (
            Query("*=>[KNN 1 @embedding $vec AS score]")
            .return_fields("query", "response", "score")
            .sort_by("score")
            .dialect(2)
        )
        
        results = await self.redis.ft(self.index_name).search(
            knn_query,
            query_params={"vec": np.array(embedding, dtype=np.float32).tobytes()}
        )
        
        if not results.docs:
            return CacheHit(hit=False)
        
        doc = results.docs[0]
        similarity = 1 - float(doc.score)  # Redis returns distance
        
        if similarity >= self.similarity_threshold:
            return CacheHit(
                hit=True,
                response=doc.response,
                similarity=similarity,
                query_matched=doc.query
            )
        
        return CacheHit(hit=False, similarity=similarity)
    
    async def set(
        self,
        query: str,
        response: str,
        ttl_seconds: int = 3600
    ):
        """Store query-response in cache."""
        
        embedding = await self.get_embedding(query)
        
        key = f"cache:{hashlib.sha256(query.encode()).hexdigest()[:16]}"
        
        await self.redis.hset(key, mapping={
            "query": query,
            "response": response,
            "embedding": np.array(embedding).tobytes(),
            "created_at": datetime.utcnow().isoformat()
        })
        
        await self.redis.expire(key, ttl_seconds)
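
RedisSemanticCache assumes a RediSearch index already exists over the cache: hashes, with a FLOAT32 vector field named embedding using cosine distance. Below is a setup sketch with redis-py against Redis Stack; the 1536 dimension matches text-embedding-3-small and is an assumption you should adjust for other embedding models:

from redis.asyncio import Redis
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

async def create_cache_index(redis_client: Redis, index_name: str = "semantic_cache"):
    """Create the vector index the semantic cache queries against."""
    schema = (
        TextField("query"),
        TextField("response"),
        VectorField(
            "embedding",
            "FLAT",  # exact nearest-neighbour search; consider "HNSW" for large caches
            {"TYPE": "FLOAT32", "DIM": 1536, "DISTANCE_METRIC": "COSINE"},
        ),
    )
    definition = IndexDefinition(prefix=["cache:"], index_type=IndexType.HASH)
    await redis_client.ft(index_name).create_index(schema, definition=definition)

Because the index stores cosine distance, the lookup above converts the returned score back into a similarity with 1 - distance.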

Cache Strategies

from dataclasses import dataclass
from typing import Any, Optional
from datetime import datetime
from enum import Enum
import hashlib

class CacheStrategy(Enum):
    """Cache lookup strategies."""
    
    EXACT = "exact"
    SEMANTIC = "semantic"
    HYBRID = "hybrid"

@dataclass
class CacheConfig:
    """Cache configuration."""
    
    strategy: CacheStrategy = CacheStrategy.SEMANTIC
    similarity_threshold: float = 0.95
    ttl_seconds: int = 3600
    max_entries: int = 10000
    enable_exact_match: bool = True

class HybridCache:
    """Hybrid cache with exact and semantic matching."""
    
    def __init__(
        self,
        embedding_client: Any,
        config: Optional[CacheConfig] = None
    ):
        self.embedding_client = embedding_client
        self.config = config or CacheConfig()
        
        # Exact match cache (hash-based)
        self._exact_cache: dict[str, CacheEntry] = {}
        
        # Semantic cache
        self._semantic_cache = EmbeddingCache(
            client=embedding_client,
            similarity_threshold=self.config.similarity_threshold,
            max_entries=self.config.max_entries,
            ttl_seconds=self.config.ttl_seconds
        )
    
    def _hash_query(self, query: str) -> str:
        """Hash query for exact matching."""
        
        # Normalize query
        normalized = query.lower().strip()
        return hashlib.sha256(normalized.encode()).hexdigest()
    
    async def get(self, query: str) -> Optional[str]:
        """Get cached response."""
        
        # Try exact match first (fast)
        if self.config.enable_exact_match:
            query_hash = self._hash_query(query)
            
            if query_hash in self._exact_cache:
                entry = self._exact_cache[query_hash]
                
                if entry.age_seconds <= self.config.ttl_seconds:
                    entry.access_count += 1
                    entry.last_accessed = datetime.utcnow()
                    return entry.response
        
        # Try semantic match
        if self.config.strategy in [CacheStrategy.SEMANTIC, CacheStrategy.HYBRID]:
            return await self._semantic_cache.get(query)
        
        return None
    
    async def set(self, query: str, response: str, **metadata):
        """Cache query-response pair."""
        
        # Store in exact cache
        if self.config.enable_exact_match:
            query_hash = self._hash_query(query)
            embedding = await self._semantic_cache.get_embedding(query)
            
            self._exact_cache[query_hash] = CacheEntry(
                query=query,
                response=response,
                embedding=embedding,
                metadata=metadata
            )
        
        # Store in semantic cache
        if self.config.strategy in [CacheStrategy.SEMANTIC, CacheStrategy.HYBRID]:
            await self._semantic_cache.set(query, response, **metadata)
    
    @property
    def stats(self) -> dict:
        """Get cache statistics."""
        
        return {
            "exact_entries": len(self._exact_cache),
            "semantic_entries": len(self._semantic_cache._entries),
            "semantic_stats": self._semantic_cache.stats
        }
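
A short usage sketch for the hybrid cache, again assuming AsyncOpenAI as the embedding client; the queries are illustrative, and whether the paraphrase hits depends on your threshold:

import asyncio
from openai import AsyncOpenAI  # assumed embedding provider

async def main():
    cache = HybridCache(
        embedding_client=AsyncOpenAI(),
        config=CacheConfig(strategy=CacheStrategy.HYBRID, similarity_threshold=0.93)
    )

    await cache.set("How do I cancel my subscription?", "Go to Billing and choose Cancel plan.")

    # Same query modulo casing/whitespace: answered by the hash-based exact cache.
    print(await cache.get("how do I cancel my subscription?  "))

    # Paraphrase: misses the exact cache and falls through to the semantic lookup.
    print(await cache.get("What's the way to cancel my plan?"))
    print(cache.stats)

asyncio.run(main())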

class AdaptiveThresholdCache:
    """Cache with adaptive similarity threshold."""
    
    def __init__(
        self,
        embedding_client: Any,
        initial_threshold: float = 0.95,
        min_threshold: float = 0.85,
        max_threshold: float = 0.99
    ):
        self.embedding_client = embedding_client
        self.threshold = initial_threshold
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold
        
        self._cache = EmbeddingCache(
            client=embedding_client,
            similarity_threshold=initial_threshold
        )
        
        # Track feedback
        self._hits = 0
        self._misses = 0
        self._false_positives = 0
        self._false_negatives = 0
    
    async def get(self, query: str) -> Optional[str]:
        """Get cached response."""
        
        # Update cache threshold
        self._cache.similarity_threshold = self.threshold
        
        response = await self._cache.get(query)
        
        if response:
            self._hits += 1
        else:
            self._misses += 1
        
        return response
    
    def report_feedback(self, was_correct: bool, was_hit: bool):
        """Report feedback on a lookup.
        
        was_hit: whether the lookup was served from the cache.
        was_correct: whether the best cached candidate was (or would have been)
        an acceptable answer, as judged by the user or an evaluator.
        """
        
        if was_hit and not was_correct:
            # False positive: we served a cached answer that didn't fit the query
            self._false_positives += 1
            # Increase threshold to be more strict
            self.threshold = min(self.threshold + 0.01, self.max_threshold)
        
        elif not was_hit and was_correct:
            # False negative: a suitable cached answer existed but the threshold rejected it
            self._false_negatives += 1
            # Decrease threshold to be more lenient
            self.threshold = max(self.threshold - 0.005, self.min_threshold)
    
    @property
    def stats(self) -> dict:
        """Get cache statistics."""
        
        total = self._hits + self._misses
        
        return {
            "current_threshold": self.threshold,
            "hit_rate": self._hits / total if total > 0 else 0,
            "false_positive_rate": self._false_positives / self._hits if self._hits > 0 else 0,
            "false_negative_rate": self._false_negatives / self._misses if self._misses > 0 else 0
        }
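
The adaptive cache only moves its threshold when it receives feedback, so it has to be wired into whatever signal you have: explicit thumbs up/down, an LLM-as-judge check, or offline review (all assumptions here, not part of the class). A sketch of the loop; note the wrapper does not expose a set method, so writes go through its inner EmbeddingCache:

import asyncio
from openai import AsyncOpenAI  # assumed embedding provider

async def answer(cache: AdaptiveThresholdCache, question: str) -> tuple[str, bool]:
    """Return (response, served_from_cache)."""
    cached = await cache.get(question)
    if cached is not None:
        return cached, True
    response = "...fresh LLM answer..."  # call your LLM here (placeholder)
    await cache._cache.set(question, response)  # write through the inner cache
    return response, False

async def main():
    cache = AdaptiveThresholdCache(embedding_client=AsyncOpenAI())

    response, was_hit = await answer(cache, "What is your refund policy?")

    # Suppose the user flags the answer as wrong for their question: a hit that
    # was incorrect tightens the threshold, a miss that should have hit loosens it.
    cache.report_feedback(was_correct=False, was_hit=was_hit)
    print(cache.stats)

asyncio.run(main())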

Caching LLM Wrapper

from dataclasses import dataclass
from typing import Any, Optional
from datetime import datetime
import time

import numpy as np

@dataclass
class CachedResponse:
    """Response with cache metadata."""
    
    content: str
    cached: bool
    similarity: Optional[float] = None
    latency_ms: float = 0
    tokens_saved: int = 0

class CachingLLMClient:
    """LLM client with semantic caching."""
    
    def __init__(
        self,
        llm_client: Any,
        cache: HybridCache,
        model: str = "gpt-4o-mini"
    ):
        self.llm_client = llm_client
        self.cache = cache
        self.model = model
        
        # Stats
        self._total_requests = 0
        self._cache_hits = 0
        self._tokens_saved = 0
        self._cost_saved = 0.0
    
    async def complete(
        self,
        prompt: str,
        system_prompt: str = None,
        use_cache: bool = True,
        cache_ttl: int = 3600,
        **kwargs
    ) -> CachedResponse:
        """Complete prompt with caching."""
        
        self._total_requests += 1
        start_time = time.time()
        
        # Build cache key
        cache_key = prompt
        if system_prompt:
            cache_key = f"{system_prompt}|||{prompt}"
        
        # Try cache
        if use_cache:
            cached_response = await self.cache.get(cache_key)
            
            if cached_response:
                self._cache_hits += 1
                latency = (time.time() - start_time) * 1000
                
                return CachedResponse(
                    content=cached_response,
                    cached=True,
                    latency_ms=latency
                )
        
        # Call LLM
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        
        response = await self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            **kwargs
        )
        
        content = response.choices[0].message.content
        latency = (time.time() - start_time) * 1000
        
        # Cache response
        if use_cache:
            await self.cache.set(
                cache_key,
                content,
                model=self.model,
                tokens=response.usage.total_tokens
            )
        
        return CachedResponse(
            content=content,
            cached=False,
            latency_ms=latency
        )
    
    @property
    def stats(self) -> dict:
        """Get client statistics."""
        
        hit_rate = self._cache_hits / self._total_requests if self._total_requests > 0 else 0
        
        return {
            "total_requests": self._total_requests,
            "cache_hits": self._cache_hits,
            "hit_rate": hit_rate,
            "cache_stats": self.cache.stats
        }
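
Wiring it together, assuming a single AsyncOpenAI client serving both chat completions and embeddings; the prompts are illustrative:

import asyncio
from openai import AsyncOpenAI

async def main():
    openai_client = AsyncOpenAI()
    client = CachingLLMClient(
        llm_client=openai_client,
        cache=HybridCache(embedding_client=openai_client),
        model="gpt-4o-mini"
    )

    first = await client.complete("Summarize the GDPR in one sentence.")
    again = await client.complete("Give me a one-sentence summary of the GDPR.")

    # The second call is a paraphrase, so with a well-tuned threshold it should
    # come back from the cache at a fraction of the latency.
    print(first.cached, again.cached)
    print(client.stats)

asyncio.run(main())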

class ContextAwareCachingClient:
    """Caching client that considers conversation context."""
    
    def __init__(
        self,
        llm_client: Any,
        embedding_client: Any,
        similarity_threshold: float = 0.90
    ):
        self.llm_client = llm_client
        self.embedding_client = embedding_client
        self.similarity_threshold = similarity_threshold
        
        # Cache entries include context
        self._cache: list[dict] = []
    
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        
        response = await self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
    
    async def complete(
        self,
        prompt: str,
        context: list[dict] = None,
        **kwargs
    ) -> CachedResponse:
        """Complete with context-aware caching."""
        
        # Build context string
        context_str = ""
        if context:
            context_str = "\n".join(
                f"{m['role']}: {m['content']}"
                for m in context[-4:]  # Last 2 turns
            )
        
        # Create cache key from prompt + context
        cache_key = f"{context_str}\n{prompt}" if context_str else prompt
        
        # Get embedding
        query_embedding = await self.get_embedding(cache_key)
        
        # Search cache
        best_match = None
        best_similarity = 0.0
        
        for entry in self._cache:
            similarity = self._compute_similarity(query_embedding, entry["embedding"])
            
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = entry
        
        if best_match and best_similarity >= self.similarity_threshold:
            return CachedResponse(
                content=best_match["response"],
                cached=True,
                similarity=best_similarity
            )
        
        # Call LLM (copy the context so we don't mutate the caller's list)
        messages = list(context) if context else []
        messages.append({"role": "user", "content": prompt})
        
        response = await self.llm_client.chat.completions.create(
            messages=messages,
            **kwargs
        )
        
        content = response.choices[0].message.content
        
        # Cache
        self._cache.append({
            "key": cache_key,
            "embedding": query_embedding,
            "response": content,
            "created_at": datetime.utcnow()
        })
        
        return CachedResponse(content=content, cached=False)
    
    def _compute_similarity(self, emb1: list[float], emb2: list[float]) -> float:
        """Compute cosine similarity."""
        
        a = np.array(emb1)
        b = np.array(emb2)
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
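
The point of including context in the key is that the same follow-up question means different things in different conversations. An illustrative usage sketch, assuming AsyncOpenAI; the model is passed through kwargs to the underlying chat call:

import asyncio
from openai import AsyncOpenAI

async def main():
    openai_client = AsyncOpenAI()
    client = ContextAwareCachingClient(
        llm_client=openai_client,
        embedding_client=openai_client
    )

    about_plans = [
        {"role": "user", "content": "Tell me about your subscription plans."},
        {"role": "assistant", "content": "We offer Basic and Pro plans."},
    ]
    about_api = [
        {"role": "user", "content": "Tell me about your public API."},
        {"role": "assistant", "content": "We expose a REST API with per-key rate limits."},
    ]

    # Same prompt, different contexts: ideally these do NOT share a cache entry.
    # If the second call comes back cached, the threshold is too lenient.
    r1 = await client.complete("What about pricing?", context=about_plans, model="gpt-4o-mini")
    r2 = await client.complete("What about pricing?", context=about_api, model="gpt-4o-mini")
    print(r1.cached, r2.cached)

asyncio.run(main())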

Production Caching Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

# Global caching client, wired up at application startup (see the sketch after the endpoints)
caching_client: Optional[CachingLLMClient] = None

class CompleteRequest(BaseModel):
    prompt: str
    system_prompt: Optional[str] = None
    use_cache: bool = True
    cache_ttl: int = 3600
    model: str = "gpt-4o-mini"

class CacheLookupRequest(BaseModel):
    query: str

class CacheStoreRequest(BaseModel):
    query: str
    response: str
    ttl_seconds: int = 3600

@app.post("/v1/complete")
async def complete_with_cache(request: CompleteRequest):
    """Complete prompt with caching."""
    
    result = await caching_client.complete(
        prompt=request.prompt,
        system_prompt=request.system_prompt,
        use_cache=request.use_cache,
        cache_ttl=request.cache_ttl
    )
    
    return {
        "content": result.content,
        "cached": result.cached,
        "similarity": result.similarity,
        "latency_ms": result.latency_ms
    }

@app.post("/v1/cache/lookup")
async def cache_lookup(request: CacheLookupRequest):
    """Look up query in cache."""
    
    response = await caching_client.cache.get(request.query)
    
    return {
        "hit": response is not None,
        "response": response
    }

@app.post("/v1/cache/store")
async def cache_store(request: CacheStoreRequest):
    """Store entry in cache."""
    
    # Note: HybridCache applies its own configured ttl_seconds; the value passed
    # here is stored as entry metadata rather than as a per-entry TTL.
    await caching_client.cache.set(
        request.query,
        request.response,
        ttl=request.ttl_seconds
    )
    
    return {"stored": True}

@app.delete("/v1/cache")
async def clear_cache():
    """Clear all cache entries."""
    
    caching_client.cache._semantic_cache.clear()
    caching_client.cache._exact_cache.clear()
    
    return {"cleared": True}

@app.get("/v1/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    
    return caching_client.stats

@app.get("/health")
async def health():
    return {"status": "healthy"}

Conclusion

Semantic caching can dramatically reduce LLM costs and latency for applications with repetitive queries. Start with a hybrid approach that combines exact matching for speed with semantic matching for flexibility. Tune your similarity threshold carefully—too low and you’ll return irrelevant cached responses, too high and you’ll miss valid cache hits. Use vector databases like Pinecone or Redis for production scale, as in-memory caches don’t persist across restarts. Implement proper TTL policies to ensure cached responses don’t become stale. Consider context when caching conversational applications—the same question in different contexts may need different answers. Track cache hit rates and false positive rates to continuously optimize your threshold. The key insight is that semantic caching trades embedding computation cost for LLM inference cost, which is usually a significant win since embeddings are much cheaper than completions.

