LLM Caching Strategies: From Exact Match to Semantic Similarity

Introduction: LLM API calls are expensive and slow. Caching is your first line of defense against runaway costs and latency. But caching LLM responses isn’t straightforward—the same question phrased differently should return the same cached answer. This guide covers caching strategies for LLM applications: exact match caching for deterministic queries, semantic caching using embeddings for similar queries, cache invalidation strategies, TTL management, and building a production caching layer. These patterns can reduce your LLM costs by 40-70% while dramatically improving response times.

LLM Caching: Exact Match and Semantic Cache Layers

Exact Match Cache

import hashlib
import json
import time
from typing import Optional, Any
from dataclasses import dataclass

@dataclass
class CacheEntry:
    response: str
    created_at: float
    ttl: int
    metadata: dict

class ExactMatchCache:
    """Simple exact-match cache for LLM responses."""
    
    def __init__(self, default_ttl: int = 3600):
        self._cache: dict[str, CacheEntry] = {}
        self.default_ttl = default_ttl
        self.hits = 0
        self.misses = 0
    
    def _make_key(self, prompt: str, model: str, **kwargs) -> str:
        """Create cache key from request parameters."""
        key_data = {
            "prompt": prompt,
            "model": model,
            **kwargs
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()
    
    def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """Get cached response if exists and not expired."""
        key = self._make_key(prompt, model, **kwargs)
        
        entry = self._cache.get(key)
        if entry is None:
            self.misses += 1
            return None
        
        # Check TTL
        if time.time() - entry.created_at > entry.ttl:
            del self._cache[key]
            self.misses += 1
            return None
        
        self.hits += 1
        return entry.response
    
    def set(
        self,
        prompt: str,
        model: str,
        response: str,
        ttl: Optional[int] = None,
        **kwargs
    ):
        """Cache a response."""
        key = self._make_key(prompt, model, **kwargs)
        
        self._cache[key] = CacheEntry(
            response=response,
            created_at=time.time(),
            ttl=ttl or self.default_ttl,
            metadata={"prompt": prompt[:100], "model": model}
        )
    
    def invalidate(self, prompt: str, model: str, **kwargs):
        """Invalidate a specific cache entry."""
        key = self._make_key(prompt, model, **kwargs)
        self._cache.pop(key, None)
    
    def clear(self):
        """Clear all cache entries."""
        self._cache.clear()
    
    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate."""
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0

# Usage with OpenAI
from openai import OpenAI

client = OpenAI()
cache = ExactMatchCache(default_ttl=3600)

def cached_completion(prompt: str, model: str = "gpt-4o") -> str:
    """Get completion with caching."""
    
    # Check cache first
    cached = cache.get(prompt, model)
    if cached:
        return cached
    
    # Call API
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    
    result = response.choices[0].message.content
    
    # Cache the response
    cache.set(prompt, model, result)
    
    return result

# Same prompt returns cached response
result1 = cached_completion("What is Python?")
result2 = cached_completion("What is Python?")  # Cache hit

print(f"Hit rate: {cache.hit_rate:.1%}")

Semantic Cache with Embeddings

import numpy as np
from typing import Tuple

class SemanticCache:
    """Cache that matches semantically similar queries."""
    
    def __init__(
        self,
        similarity_threshold: float = 0.92,
        default_ttl: int = 3600
    ):
        self.threshold = similarity_threshold
        self.default_ttl = default_ttl
        self._entries: list[dict] = []
        self.hits = 0
        self.misses = 0
    
    def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
    
    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        a = np.array(a)
        b = np.array(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    
    def _find_similar(
        self,
        embedding: list[float]
    ) -> Tuple[Optional[dict], float]:
        """Find most similar cached entry."""
        best_match = None
        best_score = 0.0
        
        current_time = time.time()
        
        for entry in self._entries:
            # Skip expired entries
            if current_time - entry["created_at"] > entry["ttl"]:
                continue
            
            score = self._cosine_similarity(embedding, entry["embedding"])
            
            if score > best_score:
                best_score = score
                best_match = entry
        
        return best_match, best_score
    
    def get(self, prompt: str) -> Optional[str]:
        """Get cached response for semantically similar prompt."""
        embedding = self._get_embedding(prompt)
        
        match, score = self._find_similar(embedding)
        
        if match and score >= self.threshold:
            self.hits += 1
            return match["response"]
        
        self.misses += 1
        return None
    
    def set(self, prompt: str, response: str, ttl: Optional[int] = None):
        """Cache a response with its embedding."""
        embedding = self._get_embedding(prompt)
        
        self._entries.append({
            "prompt": prompt,
            "embedding": embedding,
            "response": response,
            "created_at": time.time(),
            "ttl": ttl or self.default_ttl
        })
    
    def cleanup_expired(self):
        """Remove expired entries."""
        current_time = time.time()
        self._entries = [
            e for e in self._entries
            if current_time - e["created_at"] <= e["ttl"]
        ]

# Usage
semantic_cache = SemanticCache(similarity_threshold=0.90)

def smart_completion(prompt: str, model: str = "gpt-4o") -> str:
    """Completion with semantic caching."""
    
    # Check semantic cache
    cached = semantic_cache.get(prompt)
    if cached:
        return cached
    
    # Call API
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    
    result = response.choices[0].message.content
    semantic_cache.set(prompt, result)
    
    return result

# These will match semantically
result1 = smart_completion("What is the capital of France?")
result2 = smart_completion("What's France's capital city?")  # Semantic hit
result3 = smart_completion("Tell me the capital of France")  # Semantic hit
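
The similarity threshold is the main tuning knob: set it too low and unrelated questions can return the wrong cached answer; set it too high and legitimate paraphrases miss. One practical way to calibrate it is to look at raw scores for prompt pairs you consider equivalent versus merely related. A minimal sketch, reusing the cache's own helper methods on the instance above:

# Inspect raw similarity scores to calibrate the threshold
pairs = [
    ("What is the capital of France?", "What's France's capital city?"),    # should match
    ("What is the capital of France?", "What is the capital of Germany?"),  # should not match
]

for a, b in pairs:
    score = semantic_cache._cosine_similarity(
        semantic_cache._get_embedding(a),
        semantic_cache._get_embedding(b),
    )
    print(f"{score:.3f}  {a!r} vs {b!r}")

Pick a value that separates the two groups with some margin, and expect to revisit it if you switch embedding models.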

Redis-Backed Cache

import redis
import pickle
from typing import Optional

class RedisLLMCache:
    """Production-ready Redis-backed LLM cache."""
    
    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        prefix: str = "llm_cache:",
        default_ttl: int = 3600
    ):
        self.redis = redis.from_url(redis_url)
        self.prefix = prefix
        self.default_ttl = default_ttl
    
    def _make_key(self, prompt: str, model: str, **kwargs) -> str:
        """Create Redis key."""
        key_data = json.dumps({
            "prompt": prompt,
            "model": model,
            **kwargs
        }, sort_keys=True)
        hash_val = hashlib.sha256(key_data.encode()).hexdigest()
        return f"{self.prefix}{hash_val}"
    
    def get(self, prompt: str, model: str, **kwargs) -> Optional[dict]:
        """Get cached response."""
        key = self._make_key(prompt, model, **kwargs)
        
        data = self.redis.get(key)
        if data:
            return pickle.loads(data)
        return None
    
    def set(
        self,
        prompt: str,
        model: str,
        response: str,
        usage: Optional[dict] = None,
        ttl: Optional[int] = None,
        **kwargs
    ):
        """Cache response with metadata."""
        key = self._make_key(prompt, model, **kwargs)
        
        data = {
            "response": response,
            "usage": usage,
            "cached_at": time.time(),
            "model": model
        }
        
        self.redis.setex(
            key,
            ttl or self.default_ttl,
            pickle.dumps(data)
        )
    
    def get_stats(self) -> dict:
        """Get cache statistics.

        Note: keyspace_hits/keyspace_misses are server-wide Redis counters,
        not specific to this prefix. SCAN is used instead of KEYS so counting
        entries doesn't block the server on large keyspaces.
        """
        info = self.redis.info("stats")
        total_keys = sum(1 for _ in self.redis.scan_iter(match=f"{self.prefix}*"))
        
        return {
            "total_keys": total_keys,
            "hits": info.get("keyspace_hits", 0),
            "misses": info.get("keyspace_misses", 0),
            "memory_used": self.redis.info("memory")["used_memory_human"]
        }

# Usage
redis_cache = RedisLLMCache(redis_url="redis://localhost:6379")

def production_completion(prompt: str, model: str = "gpt-4o") -> dict:
    """Production completion with Redis caching."""
    
    cached = redis_cache.get(prompt, model)
    if cached:
        return {
            "response": cached["response"],
            "cached": True,
            "usage": cached.get("usage")
        }
    
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    
    result = response.choices[0].message.content
    usage = {
        "prompt_tokens": response.usage.prompt_tokens,
        "completion_tokens": response.usage.completion_tokens,
        "total_tokens": response.usage.total_tokens
    }
    
    redis_cache.set(prompt, model, result, usage=usage)
    
    return {
        "response": result,
        "cached": False,
        "usage": usage
    }
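
A note on serialization: the sketch above stores pickled dicts, which is convenient but means the cache should only be read and written by trusted code, since unpickling untrusted bytes can execute arbitrary code. If the Redis instance is shared, a JSON-based variant is a safer drop-in, because everything cached here is a plain dict of strings and numbers. The subclass below is a hypothetical sketch that simply overrides get/set:

class JSONRedisLLMCache(RedisLLMCache):
    """Variant of RedisLLMCache that stores JSON instead of pickle."""

    def get(self, prompt: str, model: str, **kwargs) -> Optional[dict]:
        key = self._make_key(prompt, model, **kwargs)
        data = self.redis.get(key)
        # redis-py returns bytes; json.loads accepts bytes directly
        return json.loads(data) if data else None

    def set(self, prompt: str, model: str, response: str,
            usage: Optional[dict] = None, ttl: Optional[int] = None, **kwargs):
        key = self._make_key(prompt, model, **kwargs)
        payload = {
            "response": response,
            "usage": usage,
            "cached_at": time.time(),
            "model": model
        }
        self.redis.setex(key, ttl or self.default_ttl, json.dumps(payload))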

Tiered Caching Strategy

class TieredCache:
    """Multi-tier cache: memory -> Redis -> semantic."""
    
    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        semantic_threshold: float = 0.90
    ):
        self.memory_cache = ExactMatchCache(default_ttl=300)  # 5 min
        self.redis_cache = RedisLLMCache(redis_url, default_ttl=3600)  # 1 hour
        self.semantic_cache = SemanticCache(
            similarity_threshold=semantic_threshold,
            default_ttl=86400  # 24 hours
        )
        
        self.stats = {
            "memory_hits": 0,
            "redis_hits": 0,
            "semantic_hits": 0,
            "misses": 0
        }
    
    def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """Check all cache tiers."""
        
        # Tier 1: Memory (fastest)
        result = self.memory_cache.get(prompt, model, **kwargs)
        if result:
            self.stats["memory_hits"] += 1
            return result
        
        # Tier 2: Redis (fast, persistent)
        cached = self.redis_cache.get(prompt, model, **kwargs)
        if cached:
            self.stats["redis_hits"] += 1
            # Promote to memory cache
            self.memory_cache.set(prompt, model, cached["response"], **kwargs)
            return cached["response"]
        
        # Tier 3: Semantic (slowest, but catches similar queries)
        result = self.semantic_cache.get(prompt)
        if result:
            self.stats["semantic_hits"] += 1
            # Promote to faster caches
            self.memory_cache.set(prompt, model, result, **kwargs)
            self.redis_cache.set(prompt, model, result, **kwargs)
            return result
        
        self.stats["misses"] += 1
        return None
    
    def set(self, prompt: str, model: str, response: str, **kwargs):
        """Set in all cache tiers."""
        self.memory_cache.set(prompt, model, response, **kwargs)
        self.redis_cache.set(prompt, model, response, **kwargs)
        self.semantic_cache.set(prompt, response)
    
    def get_stats(self) -> dict:
        """Get comprehensive cache statistics."""
        total = sum(self.stats.values())
        
        return {
            **self.stats,
            "total_requests": total,
            "overall_hit_rate": (total - self.stats["misses"]) / total if total > 0 else 0,
            "memory_hit_rate": self.stats["memory_hits"] / total if total > 0 else 0,
            "redis_hit_rate": self.stats["redis_hits"] / total if total > 0 else 0,
            "semantic_hit_rate": self.stats["semantic_hits"] / total if total > 0 else 0
        }

# Usage
tiered_cache = TieredCache()

def optimized_completion(prompt: str, model: str = "gpt-4o") -> str:
    """Completion with tiered caching."""
    
    cached = tiered_cache.get(prompt, model)
    if cached:
        return cached
    
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    
    result = response.choices[0].message.content
    tiered_cache.set(prompt, model, result)
    
    return result
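
Because TieredCache counts hits per tier, you can see where requests are actually being served from and whether each tier is earning its keep:

# Inspect which tier is doing the work
stats = tiered_cache.get_stats()
print(f"Memory:   {stats['memory_hit_rate']:.1%}")
print(f"Redis:    {stats['redis_hit_rate']:.1%}")
print(f"Semantic: {stats['semantic_hit_rate']:.1%}")
print(f"Overall:  {stats['overall_hit_rate']:.1%} over {stats['total_requests']} requests")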

Cache Invalidation

from enum import Enum

class InvalidationStrategy(str, Enum):
    TTL = "ttl"  # Time-based expiration
    VERSION = "version"  # Version-based invalidation
    TAG = "tag"  # Tag-based invalidation
    DEPENDENCY = "dependency"  # Dependency tracking

class SmartCache:
    """Cache with advanced invalidation strategies."""
    
    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis = redis.from_url(redis_url)
        self.prefix = "smart_cache:"
        self.version = 1
    
    def _make_key(self, key: str) -> str:
        return f"{self.prefix}v{self.version}:{key}"
    
    def get(self, key: str) -> Optional[dict]:
        """Get with version check."""
        data = self.redis.get(self._make_key(key))
        if data:
            return pickle.loads(data)
        return None
    
    def set(
        self,
        key: str,
        value: Any,
        ttl: int = 3600,
        tags: Optional[list[str]] = None
    ):
        """Set with tags for group invalidation."""
        full_key = self._make_key(key)
        
        self.redis.setex(full_key, ttl, pickle.dumps(value))
        
        # Track tags
        if tags:
            for tag in tags:
                self.redis.sadd(f"{self.prefix}tag:{tag}", full_key)
    
    def invalidate_by_tag(self, tag: str):
        """Invalidate all entries with a specific tag."""
        tag_key = f"{self.prefix}tag:{tag}"
        keys = self.redis.smembers(tag_key)
        
        if keys:
            self.redis.delete(*keys)
            self.redis.delete(tag_key)
    
    def invalidate_all(self):
        """Invalidate all cache entries by incrementing the version.

        Old keys become unreachable immediately and age out via their TTLs.
        Note: the counter lives on this instance; for multi-process
        deployments, store the version in Redis instead.
        """
        self.version += 1
    
    def invalidate_pattern(self, pattern: str):
        """Invalidate entries matching a pattern (SCAN avoids blocking Redis)."""
        keys = list(self.redis.scan_iter(match=f"{self.prefix}*{pattern}*"))
        if keys:
            self.redis.delete(*keys)

# Usage with tags
smart_cache = SmartCache()

# Cache with tags
smart_cache.set(
    "user_123_profile",
    {"name": "John", "preferences": {...}},
    tags=["user_123", "profiles"]
)

smart_cache.set(
    "user_123_settings",
    {"theme": "dark"},
    tags=["user_123", "settings"]
)

# Invalidate all user_123 data
smart_cache.invalidate_by_tag("user_123")

# Invalidate all profiles
smart_cache.invalidate_by_tag("profiles")

Production Cache Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

class CompletionRequest(BaseModel):
    prompt: str
    model: str = "gpt-4o"
    use_cache: bool = True
    cache_ttl: int = 3600

class CompletionResponse(BaseModel):
    response: str
    cached: bool
    cache_key: Optional[str]
    latency_ms: float
    tokens_saved: Optional[int]

# Initialize tiered cache
cache = TieredCache()

@app.post("/completion", response_model=CompletionResponse)
async def completion(request: CompletionRequest):
    """Get completion with intelligent caching."""
    
    start_time = time.time()
    
    if request.use_cache:
        cached = cache.get(request.prompt, request.model)
        if cached:
            return CompletionResponse(
                response=cached,
                cached=True,
                cache_key=hashlib.sha256(request.prompt.encode()).hexdigest()[:16],
                latency_ms=(time.time() - start_time) * 1000,
                tokens_saved=len(cached.split()) * 2  # Rough estimate
            )
    
    # Call LLM
    response = client.chat.completions.create(
        model=request.model,
        messages=[{"role": "user", "content": request.prompt}]
    )
    
    result = response.choices[0].message.content
    
    # Cache the response
    if request.use_cache:
        cache.set(request.prompt, request.model, result)
    
    return CompletionResponse(
        response=result,
        cached=False,
        cache_key=hashlib.sha256(request.prompt.encode()).hexdigest()[:16],
        latency_ms=(time.time() - start_time) * 1000,
        tokens_saved=None
    )

@app.get("/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    return cache.get_stats()

@app.post("/cache/invalidate")
async def invalidate_cache(pattern: str = None):
    """Invalidate cache entries."""
    if pattern:
        # Invalidate by pattern
        pass
    else:
        # Clear all
        cache.memory_cache.clear()
    
    return {"status": "invalidated"}


Conclusion

Effective caching is essential for production LLM applications. Start with exact-match caching for deterministic queries—it’s simple and catches repeated requests. Add semantic caching to handle paraphrased queries that should return the same answer. Use tiered caching (memory → Redis → semantic) for optimal performance across different access patterns. Implement smart invalidation strategies based on your data freshness requirements. Monitor cache hit rates and adjust thresholds accordingly. A well-tuned caching layer can reduce your LLM costs by 50% or more while cutting response latency from seconds to milliseconds for cached queries.

