LLM Caching Strategies: Reducing Costs and Latency with Smart Response Caching

Introduction: LLM API calls are expensive and slow. A single GPT-4 request can cost $0.03-0.12 and take 2-10 seconds. When users ask similar questions repeatedly, you pay for the same computation over and over. Caching addresses this by storing responses and returning them immediately for matching requests. LLM caching is harder than traditional caching, though: users phrase the same question in different ways, so semantic similarity matters more than exact string matching. This guide covers practical caching strategies: exact match caching for identical requests, semantic caching using embeddings, cache invalidation, and production-ready caching layers that cut both cost and latency.
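
To make the cost argument concrete, here is a back-of-the-envelope sketch. The request volume and the 40% hit rate are illustrative assumptions, not measurements; only the $0.03 per-request figure is taken from the range quoted above. The function name `estimated_spend` is hypothetical.

```python
def estimated_spend(requests: int, cost_per_request: float, hit_rate: float) -> float:
    """Estimate LLM spend when a fraction `hit_rate` of requests is served from cache.

    Assumes a cache hit costs nothing, which ignores cache infrastructure and,
    for semantic caching, the embedding call made on each lookup.
    """
    if not 0.0 <= hit_rate <= 1.0:
        raise ValueError("hit_rate must be between 0 and 1")
    return requests * (1.0 - hit_rate) * cost_per_request


# Illustrative numbers: 100,000 requests at $0.03 each, 40% cache hit rate.
baseline = estimated_spend(100_000, 0.03, hit_rate=0.0)
with_cache = estimated_spend(100_000, 0.03, hit_rate=0.4)

print(f"without cache: ${baseline:,.2f}")
print(f"with cache:    ${with_cache:,.2f}")
print(f"saved:         ${baseline - with_cache:,.2f}")
```

Under these assumptions the saving scales linearly with the hit rate, which is why the later sections focus on raising it.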

Figure: LLM Caching: Exact Match, Semantic Match, Cache Store

Exact Match Caching

from dataclasses import dataclass
from typing import Any, Optional
import hashlib
import json
import time
from datetime import datetime, timedelta

@dataclass
class CacheEntry:
    """A cached LLM response."""
    
    key: str
    response: str
    model: str
    created_at: float
    expires_at: float
    hit_count: int = 0
    tokens_saved: int = 0

class ExactMatchCache:
    """Cache LLM responses using exact prompt matching."""
    
    def __init__(self, ttl_seconds: int = 3600):
        self.cache: dict[str, CacheEntry] = {}
        self.ttl_seconds = ttl_seconds
        self.stats = {
            "hits": 0,
            "misses": 0,
            "tokens_saved": 0,
            "cost_saved": 0.0
        }
    
    def _generate_key(
        self,
        prompt: str,
        model: str,
        system_prompt: str = None,
        temperature: float = None
    ) -> str:
        """Generate cache key from request parameters."""
        
        key_data = {
            "prompt": prompt,
            "model": model,
            "system_prompt": system_prompt,
            "temperature": temperature
        }
        
        key_string = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_string.encode()).hexdigest()
    
    def get(
        self,
        prompt: str,
        model: str,
        system_prompt: str = None,
        temperature: float = None
    ) -> Optional[str]:
        """Get cached response if available."""
        
        key = self._generate_key(prompt, model, system_prompt, temperature)
        
        if key not in self.cache:
            self.stats["misses"] += 1
            return None
        
        entry = self.cache[key]
        
        # Check expiration
        if time.time() > entry.expires_at:
            del self.cache[key]
            self.stats["misses"] += 1
            return None
        
        # Update stats
        entry.hit_count += 1
        self.stats["hits"] += 1
        self.stats["tokens_saved"] += entry.tokens_saved
        
        return entry.response
    
    def set(
        self,
        prompt: str,
        model: str,
        response: str,
        system_prompt: str = None,
        temperature: float = None,
        tokens_used: int = 0
    ) -> None:
        """Cache a response."""
        
        key = self._generate_key(prompt, model, system_prompt, temperature)
        
        self.cache[key] = CacheEntry(
            key=key,
            response=response,
            model=model,
            created_at=time.time(),
            expires_at=time.time() + self.ttl_seconds,
            tokens_saved=tokens_used
        )
    
    def invalidate(self, pattern: str = None) -> int:
        """Invalidate all entries, or only those whose response text contains `pattern`."""
        
        if pattern is None:
            count = len(self.cache)
            self.cache.clear()
            return count
        
        # Pattern-based invalidation
        keys_to_remove = []
        for key, entry in self.cache.items():
            if pattern in entry.response:
                keys_to_remove.append(key)
        
        for key in keys_to_remove:
            del self.cache[key]
        
        return len(keys_to_remove)
    
    def get_stats(self) -> dict:
        """Get cache statistics."""
        
        total = self.stats["hits"] + self.stats["misses"]
        hit_rate = self.stats["hits"] / total if total > 0 else 0
        
        return {
            **self.stats,
            "total_requests": total,
            "hit_rate": hit_rate,
            "cache_size": len(self.cache)
        }

class RedisExactCache:
    """Redis-backed exact match cache for distributed systems."""
    
    def __init__(
        self,
        redis_client: Any,
        prefix: str = "llm_cache",
        ttl_seconds: int = 3600
    ):
        self.redis = redis_client
        self.prefix = prefix
        self.ttl_seconds = ttl_seconds
    
    def _generate_key(self, prompt: str, model: str, **kwargs) -> str:
        """Generate Redis key."""
        
        key_data = {"prompt": prompt, "model": model, **kwargs}
        key_hash = hashlib.sha256(
            json.dumps(key_data, sort_keys=True).encode()
        ).hexdigest()
        
        return f"{self.prefix}:{key_hash}"
    
    async def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """Get cached response from Redis."""
        
        key = self._generate_key(prompt, model, **kwargs)
        
        data = await self.redis.get(key)
        if data:
            entry = json.loads(data)
            await self.redis.hincrby(f"{self.prefix}:stats", "hits", 1)
            return entry["response"]
        
        await self.redis.hincrby(f"{self.prefix}:stats", "misses", 1)
        return None
    
    async def set(
        self,
        prompt: str,
        model: str,
        response: str,
        tokens_used: int = 0,
        **kwargs
    ) -> None:
        """Cache response in Redis."""
        
        key = self._generate_key(prompt, model, **kwargs)
        
        entry = {
            "response": response,
            "model": model,
            "tokens_used": tokens_used,
            "created_at": time.time()
        }
        
        await self.redis.setex(
            key,
            self.ttl_seconds,
            json.dumps(entry)
        )
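
One limitation of exact matching is that trivial differences such as extra whitespace or capitalization produce different keys. The sketch below mirrors the key generation above as a standalone function and adds an optional normalization step. `normalize_prompt`, `make_cache_key`, and the `normalize` flag are hypothetical additions, not part of the classes above, and lowercasing is only safe when prompts are not case-sensitive (it would be wrong for prompts containing code, for example).

```python
import hashlib
import json
from typing import Optional


def normalize_prompt(prompt: str) -> str:
    """Collapse runs of whitespace and lowercase the prompt (hypothetical helper)."""
    return " ".join(prompt.split()).lower()


def make_cache_key(
    prompt: str,
    model: str,
    system_prompt: Optional[str] = None,
    temperature: Optional[float] = None,
    normalize: bool = False,
) -> str:
    """Hash the request parameters into a stable cache key."""
    if normalize:
        prompt = normalize_prompt(prompt)
    key_data = {
        "prompt": prompt,
        "model": model,
        "system_prompt": system_prompt,
        "temperature": temperature,
    }
    # sort_keys=True makes the JSON (and therefore the hash) order-independent.
    return hashlib.sha256(json.dumps(key_data, sort_keys=True).encode()).hexdigest()


raw_a = make_cache_key("What is caching?", "gpt-4o-mini")
raw_b = make_cache_key("  what is   caching? ", "gpt-4o-mini")
norm_a = make_cache_key("What is caching?", "gpt-4o-mini", normalize=True)
norm_b = make_cache_key("  what is   caching? ", "gpt-4o-mini", normalize=True)

print(raw_a == raw_b)    # False: the raw strings differ
print(norm_a == norm_b)  # True: both normalize to "what is caching?"
```

Normalization is cheap compared with an embedding call, so it can recover some near-duplicate requests before semantic matching is needed.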

Semantic Caching

from dataclasses import dataclass
from typing import Any, Optional
import hashlib
import time
import numpy as np

@dataclass
class SemanticCacheEntry:
    """A semantically cached response."""
    
    prompt: str
    embedding: np.ndarray
    response: str
    model: str
    created_at: float
    expires_at: float
    similarity_threshold: float

class SemanticCache:
    """Cache LLM responses using semantic similarity."""
    
    def __init__(
        self,
        embedding_client: Any,
        similarity_threshold: float = 0.95,
        ttl_seconds: int = 3600,
        max_entries: int = 10000
    ):
        self.embedding_client = embedding_client
        self.similarity_threshold = similarity_threshold
        self.ttl_seconds = ttl_seconds
        self.max_entries = max_entries
        
        self.entries: list[SemanticCacheEntry] = []
        self.stats = {
            "hits": 0,
            "misses": 0,
            "semantic_hits": 0,
            "tokens_saved": 0
        }
    
    async def _get_embedding(self, text: str) -> np.ndarray:
        """Get embedding for text."""
        
        response = await self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        
        return np.array(response.data[0].embedding)
    
    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Calculate cosine similarity between two vectors."""
        
        norm = np.linalg.norm(a) * np.linalg.norm(b)
        if norm == 0:
            # A zero vector has no direction; treat it as no match
            return 0.0
        return float(np.dot(a, b) / norm)
    
    async def get(
        self,
        prompt: str,
        model: str = None
    ) -> Optional[tuple[str, float]]:
        """Get cached response if semantically similar prompt exists."""
        
        # Get embedding for query
        query_embedding = await self._get_embedding(prompt)
        
        # Remove expired entries
        current_time = time.time()
        self.entries = [e for e in self.entries if e.expires_at > current_time]
        
        # Find best match
        best_match = None
        best_similarity = 0.0
        
        for entry in self.entries:
            # Filter by model if specified
            if model and entry.model != model:
                continue
            
            similarity = self._cosine_similarity(query_embedding, entry.embedding)
            
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = entry
        
        # Check if match exceeds threshold
        if best_match and best_similarity >= self.similarity_threshold:
            self.stats["hits"] += 1
            self.stats["semantic_hits"] += 1
            return best_match.response, best_similarity
        
        self.stats["misses"] += 1
        return None
    
    async def set(
        self,
        prompt: str,
        response: str,
        model: str,
        tokens_used: int = 0
    ) -> None:
        """Cache a response with its embedding."""
        
        # Enforce max entries
        if len(self.entries) >= self.max_entries:
            # Evict the oldest half in one pass so eviction does not run on every insert
            self.entries.sort(key=lambda e: e.created_at)
            self.entries = self.entries[len(self.entries) // 2:]
        
        embedding = await self._get_embedding(prompt)
        
        entry = SemanticCacheEntry(
            prompt=prompt,
            embedding=embedding,
            response=response,
            model=model,
            created_at=time.time(),
            expires_at=time.time() + self.ttl_seconds,
            similarity_threshold=self.similarity_threshold
        )
        
        self.entries.append(entry)

class VectorDBSemanticCache:
    """Semantic cache backed by vector database for scale."""
    
    def __init__(
        self,
        vector_db: Any,
        embedding_client: Any,
        collection_name: str = "llm_cache",
        similarity_threshold: float = 0.95,
        ttl_seconds: int = 3600
    ):
        self.vector_db = vector_db
        self.embedding_client = embedding_client
        self.collection_name = collection_name
        self.similarity_threshold = similarity_threshold
        self.ttl_seconds = ttl_seconds
    
    async def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        
        response = await self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        
        return response.data[0].embedding
    
    async def get(self, prompt: str, model: str = None) -> Optional[tuple[str, float]]:
        """Search for semantically similar cached response."""
        
        query_embedding = await self._get_embedding(prompt)
        
        # Build filter
        filter_conditions = {
            "expires_at": {"$gt": time.time()}
        }
        if model:
            filter_conditions["model"] = model
        
        # Search vector DB
        results = await self.vector_db.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=1,
            filter=filter_conditions
        )
        
        if results and results[0].score >= self.similarity_threshold:
            return results[0].payload["response"], results[0].score
        
        return None
    
    async def set(
        self,
        prompt: str,
        response: str,
        model: str,
        tokens_used: int = 0
    ) -> None:
        """Store response in vector database."""
        
        embedding = await self._get_embedding(prompt)
        
        await self.vector_db.upsert(
            collection_name=self.collection_name,
            points=[{
                "id": hashlib.sha256(prompt.encode()).hexdigest(),
                "vector": embedding,
                "payload": {
                    "prompt": prompt,
                    "response": response,
                    "model": model,
                    "tokens_used": tokens_used,
                    "created_at": time.time(),
                    "expires_at": time.time() + self.ttl_seconds
                }
            }]
        )
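
The effect of the similarity threshold is easier to see on small numbers. The following is a minimal dependency-free sketch of the same best-match-then-threshold logic; the three-dimensional vectors are made-up toy values standing in for real embeddings, and `best_match` is a hypothetical standalone function rather than a method of the classes above.

```python
import math
from typing import Optional, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def best_match(
    query: Sequence[float],
    entries: list[tuple[Sequence[float], str]],
    threshold: float = 0.95,
) -> Optional[tuple[str, float]]:
    """Return (response, score) for the closest entry if it clears the threshold."""
    best_response, best_score = None, 0.0
    for embedding, response in entries:
        score = cosine_similarity(query, embedding)
        if score > best_score:
            best_response, best_score = response, score
    if best_response is not None and best_score >= threshold:
        return best_response, best_score
    return None


# Toy "embeddings" (assumed values, not produced by a real model).
entries = [
    ([0.9, 0.1, 0.0], "Use the reset link on the login page."),
    ([0.0, 0.1, 0.9], "Invoices are emailed monthly."),
]
close_query = [0.8, 0.2, 0.0]  # similarity to the first entry is about 0.99
far_query = [0.5, 0.5, 0.0]    # similarity to the first entry is about 0.78

print(best_match(close_query, entries))                 # hit on the first entry
print(best_match(far_query, entries))                   # None at the default 0.95
print(best_match(far_query, entries, threshold=0.75))   # hit once the threshold is lowered
```

Lowering the threshold raises the hit rate but also the chance of returning an answer to a different question, which is why the default here stays at 0.95.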

Hybrid Caching Strategy

from dataclasses import dataclass
from typing import Any, Optional
from enum import Enum
import time

class CacheHitType(Enum):
    """Type of cache hit."""
    
    EXACT = "exact"
    SEMANTIC = "semantic"
    MISS = "miss"

@dataclass
class CacheResult:
    """Result of cache lookup."""
    
    hit_type: CacheHitType
    response: Optional[str] = None
    similarity: Optional[float] = None
    latency_ms: float = 0

class HybridCache:
    """Hybrid cache combining exact and semantic matching."""
    
    def __init__(
        self,
        exact_cache: ExactMatchCache,
        semantic_cache: SemanticCache,
        semantic_threshold: float = 0.95
    ):
        self.exact_cache = exact_cache
        self.semantic_cache = semantic_cache
        self.semantic_threshold = semantic_threshold
        
        self.stats = {
            "exact_hits": 0,
            "semantic_hits": 0,
            "misses": 0,
            "total_latency_saved_ms": 0
        }
    
    async def get(
        self,
        prompt: str,
        model: str,
        system_prompt: str = None,
        temperature: float = None
    ) -> CacheResult:
        """Try exact match first, then semantic match."""
        
        start_time = time.time()
        
        # Try exact match first (fast)
        exact_result = self.exact_cache.get(
            prompt, model, system_prompt, temperature
        )
        
        if exact_result is not None:
            latency = (time.time() - start_time) * 1000
            self.stats["exact_hits"] += 1
            
            return CacheResult(
                hit_type=CacheHitType.EXACT,
                response=exact_result,
                similarity=1.0,
                latency_ms=latency
            )
        
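        # Try semantic match (slower but catches paraphrases).
        # Note: the semantic lookup keys on prompt and model only, so it can
        # return a response that was generated under a different system_prompt.
        semantic_result = await self.semantic_cache.get(prompt, model)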
        
        if semantic_result:
            response, similarity = semantic_result
            latency = (time.time() - start_time) * 1000
            self.stats["semantic_hits"] += 1
            
            return CacheResult(
                hit_type=CacheHitType.SEMANTIC,
                response=response,
                similarity=similarity,
                latency_ms=latency
            )
        
        latency = (time.time() - start_time) * 1000
        self.stats["misses"] += 1
        
        return CacheResult(
            hit_type=CacheHitType.MISS,
            latency_ms=latency
        )
    
    async def set(
        self,
        prompt: str,
        model: str,
        response: str,
        system_prompt: str = None,
        temperature: float = None,
        tokens_used: int = 0
    ) -> None:
        """Store in both caches."""
        
        # Store in exact cache
        self.exact_cache.set(
            prompt, model, response,
            system_prompt, temperature, tokens_used
        )
        
        # Store in semantic cache
        await self.semantic_cache.set(
            prompt, response, model, tokens_used
        )
    
    def get_stats(self) -> dict:
        """Get combined statistics."""
        
        total = (
            self.stats["exact_hits"] +
            self.stats["semantic_hits"] +
            self.stats["misses"]
        )
        
        return {
            **self.stats,
            "total_requests": total,
            "exact_hit_rate": self.stats["exact_hits"] / total if total > 0 else 0,
            "semantic_hit_rate": self.stats["semantic_hits"] / total if total > 0 else 0,
            "overall_hit_rate": (
                self.stats["exact_hits"] + self.stats["semantic_hits"]
            ) / total if total > 0 else 0
        }

class TieredCache:
    """Multi-tier cache with different TTLs and strategies."""
    
    def __init__(
        self,
        l1_cache: ExactMatchCache,  # Fast, short TTL
        l2_cache: HybridCache,       # Slower, longer TTL
        l1_ttl: int = 300,           # 5 minutes
        l2_ttl: int = 3600           # 1 hour
    ):
        self.l1_cache = l1_cache
        self.l2_cache = l2_cache
        self.l1_ttl = l1_ttl
        self.l2_ttl = l2_ttl
    
    async def get(
        self,
        prompt: str,
        model: str,
        **kwargs
    ) -> CacheResult:
        """Check L1 then L2 cache."""
        
        # Check L1 (exact match, in-memory)
        l1_result = self.l1_cache.get(prompt, model, **kwargs)
        if l1_result is not None:
            return CacheResult(
                hit_type=CacheHitType.EXACT,
                response=l1_result,
                similarity=1.0
            )
        
        # Check L2 (hybrid, may be distributed)
        l2_result = await self.l2_cache.get(prompt, model, **kwargs)
        
        if l2_result.hit_type != CacheHitType.MISS:
            # Promote to L1
            self.l1_cache.set(
                prompt, model, l2_result.response, **kwargs
            )
        
        return l2_result
    
    async def set(
        self,
        prompt: str,
        model: str,
        response: str,
        **kwargs
    ) -> None:
        """Store in both tiers."""
        
        self.l1_cache.set(prompt, model, response, **kwargs)
        await self.l2_cache.set(prompt, model, response, **kwargs)
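
The tiered lookup with promotion can be illustrated without any LLM or embedding dependency. This is a simplified sketch, not the classes above: `TtlDict` and `tiered_get` are hypothetical names, and the current time is passed in explicitly as `now` so the behaviour is deterministic. The TTL values match the 5-minute and 1-hour defaults shown above.

```python
from typing import Optional


class TtlDict:
    """Tiny TTL map; the caller supplies `now` instead of reading the clock."""

    def __init__(self, ttl: float) -> None:
        self.ttl = ttl
        self._data: dict[str, tuple[str, float]] = {}

    def get(self, key: str, now: float) -> Optional[str]:
        item = self._data.get(key)
        if item is None:
            return None
        value, expires_at = item
        if now >= expires_at:
            del self._data[key]  # lazily drop the expired entry
            return None
        return value

    def set(self, key: str, value: str, now: float) -> None:
        self._data[key] = (value, now + self.ttl)


def tiered_get(l1: TtlDict, l2: TtlDict, key: str, now: float) -> tuple[Optional[str], str]:
    """Check L1, then L2; promote an L2 hit into L1."""
    value = l1.get(key, now)
    if value is not None:
        return value, "l1"
    value = l2.get(key, now)
    if value is not None:
        l1.set(key, value, now)  # promotion restarts the L1 TTL
        return value, "l2"
    return None, "miss"


l1, l2 = TtlDict(ttl=300), TtlDict(ttl=3600)
l1.set("k", "cached answer", now=0)
l2.set("k", "cached answer", now=0)

print(tiered_get(l1, l2, "k", now=10))    # ('cached answer', 'l1')
print(tiered_get(l1, l2, "k", now=400))   # ('cached answer', 'l2'), L1 had expired
print(tiered_get(l1, l2, "k", now=450))   # ('cached answer', 'l1'), served from the promoted copy
print(tiered_get(l1, l2, "k", now=4000))  # (None, 'miss'), both tiers expired
```

One design consequence worth noting: because promotion restarts the L1 TTL, an entry promoted shortly before its L2 expiry can outlive the L2 TTL by up to one L1 TTL.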

Cache-Aware LLM Client

from dataclasses import dataclass
from typing import Any, Optional
import time

@dataclass
class CachedResponse:
    """Response with cache metadata."""
    
    content: str
    cached: bool
    cache_hit_type: Optional[CacheHitType] = None
    similarity: Optional[float] = None
    tokens_used: int = 0
    latency_ms: float = 0
    cost_saved: float = 0

class CachedLLMClient:
    """LLM client with integrated caching."""
    
    def __init__(
        self,
        client: Any,
        cache: HybridCache,
        model: str = "gpt-4o-mini",
        enable_caching: bool = True
    ):
        self.client = client
        self.cache = cache
        self.model = model
        self.enable_caching = enable_caching
        
        # Cost per 1K tokens (approximate)
        self.cost_per_1k = {
            "gpt-4o": 0.005,
            "gpt-4o-mini": 0.00015,
            "gpt-4-turbo": 0.01,
            "claude-3-5-sonnet": 0.003,
            "claude-3-haiku": 0.00025
        }
    
    def _estimate_cost(self, tokens: int, model: str) -> float:
        """Estimate cost for tokens."""
        
        rate = self.cost_per_1k.get(model, 0.001)
        return (tokens / 1000) * rate
    
    async def complete(
        self,
        prompt: str,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_tokens: int = None,
        skip_cache: bool = False
    ) -> CachedResponse:
        """Generate completion with caching."""
        
        start_time = time.time()
        
        # Check cache if enabled
        if self.enable_caching and not skip_cache and temperature == 0:
            cache_result = await self.cache.get(
                prompt, self.model, system_prompt, temperature
            )
            
            if cache_result.hit_type != CacheHitType.MISS:
                latency = (time.time() - start_time) * 1000
                
                return CachedResponse(
                    content=cache_result.response,
                    cached=True,
                    cache_hit_type=cache_result.hit_type,
                    similarity=cache_result.similarity,
                    latency_ms=latency
                )
        
        # Generate fresh response
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens
        )
        
        content = response.choices[0].message.content
        tokens_used = response.usage.total_tokens
        latency = (time.time() - start_time) * 1000
        
        # Cache the response (only for deterministic outputs)
        if self.enable_caching and temperature == 0:
            await self.cache.set(
                prompt, self.model, content,
                system_prompt, temperature, tokens_used
            )
        
        return CachedResponse(
            content=content,
            cached=False,
            tokens_used=tokens_used,
            latency_ms=latency
        )
    
    async def complete_batch(
        self,
        prompts: list[str],
        system_prompt: str = None,
        temperature: float = 0
    ) -> list[CachedResponse]:
        """Process batch with caching."""
        
        results = []
        uncached_prompts = []
        uncached_indices = []
        
        # Check cache for all prompts
        for i, prompt in enumerate(prompts):
            cache_result = await self.cache.get(
                prompt, self.model, system_prompt, temperature
            )
            
            if cache_result.hit_type != CacheHitType.MISS:
                results.append((i, CachedResponse(
                    content=cache_result.response,
                    cached=True,
                    cache_hit_type=cache_result.hit_type
                )))
            else:
                uncached_prompts.append(prompt)
                uncached_indices.append(i)
        
        # Generate uncached responses
        for prompt, idx in zip(uncached_prompts, uncached_indices):
            # skip_cache=True avoids a second lookup; complete() still stores
            # temperature=0 responses, so no extra cache.set() is needed here
            # (a second write would duplicate the semantic cache entry).
            response = await self.complete(
                prompt, system_prompt, temperature, skip_cache=True
            )
            results.append((idx, response))
        
        # Sort by original index
        results.sort(key=lambda x: x[0])
        return [r[1] for r in results]
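
The client above only reads from and writes to the cache when `temperature == 0`. The sketch below isolates that gating rule with a stub in place of a real LLM client so the effect can be observed directly. `FakeLLM` and this simplified `complete` function are hypothetical stand-ins, and the cache is a plain dict keyed by prompt.

```python
import asyncio


class FakeLLM:
    """Stand-in for a real client; counts how often it is called."""

    def __init__(self) -> None:
        self.calls = 0

    async def generate(self, prompt: str, temperature: float) -> str:
        self.calls += 1
        return f"answer #{self.calls} to: {prompt}"


async def complete(
    llm: FakeLLM, cache: dict[str, str], prompt: str, temperature: float = 0.0
) -> tuple[str, bool]:
    """Return (content, cached). Only temperature=0 requests use the cache."""
    cacheable = temperature == 0
    if cacheable and prompt in cache:
        return cache[prompt], True
    content = await llm.generate(prompt, temperature)
    if cacheable:
        cache[prompt] = content
    return content, False


async def demo() -> tuple[int, list[bool]]:
    llm, cache = FakeLLM(), {}
    flags = []
    for temperature in (0.0, 0.0, 0.7, 0.7):
        _, cached = await complete(llm, cache, "Define TTL", temperature)
        flags.append(cached)
    return llm.calls, flags


calls, flags = asyncio.run(demo())
print(calls, flags)  # 3 [False, True, False, False]
```

Only the second request is served from cache; both sampled (temperature 0.7) requests reach the model, which matches the default behaviour of `complete()` above where `temperature` defaults to 0.7 and caching is therefore opt-in.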

Production Caching Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

# Initialize caches
exact_cache = ExactMatchCache(ttl_seconds=3600)
semantic_cache = None  # Initialize with embedding client
hybrid_cache = None    # Initialize with both caches
cached_client = None   # Initialize with LLM client and cache

class CompletionRequest(BaseModel):
    prompt: str
    system_prompt: Optional[str] = None
    temperature: float = 0
    max_tokens: Optional[int] = None
    skip_cache: bool = False

class BatchRequest(BaseModel):
    prompts: list[str]
    system_prompt: Optional[str] = None
    temperature: float = 0

class CacheInvalidateRequest(BaseModel):
    pattern: Optional[str] = None

@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    """Generate completion with caching."""
    
    response = await cached_client.complete(
        prompt=request.prompt,
        system_prompt=request.system_prompt,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
        skip_cache=request.skip_cache
    )
    
    return {
        "content": response.content,
        "cached": response.cached,
        "cache_hit_type": response.cache_hit_type.value if response.cache_hit_type else None,
        "similarity": response.similarity,
        "tokens_used": response.tokens_used,
        "latency_ms": response.latency_ms
    }

@app.post("/v1/completions/batch")
async def create_batch_completion(request: BatchRequest):
    """Process batch with caching."""
    
    responses = await cached_client.complete_batch(
        prompts=request.prompts,
        system_prompt=request.system_prompt,
        temperature=request.temperature
    )
    
    cached_count = sum(1 for r in responses if r.cached)
    
    return {
        "responses": [
            {
                "content": r.content,
                "cached": r.cached,
                "cache_hit_type": r.cache_hit_type.value if r.cache_hit_type else None
            }
            for r in responses
        ],
        "cached_count": cached_count,
        "total_count": len(responses),
        # Guard against an empty batch to avoid ZeroDivisionError
        "cache_hit_rate": cached_count / len(responses) if responses else 0.0
    }

@app.get("/v1/cache/stats")
async def get_cache_stats():
    """Get cache statistics."""
    
    return {
        "exact_cache": exact_cache.get_stats(),
        "hybrid_cache": hybrid_cache.get_stats() if hybrid_cache else None
    }

@app.post("/v1/cache/invalidate")
async def invalidate_cache(request: CacheInvalidateRequest):
    """Invalidate cache entries."""
    
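    # Only the exact-match cache is invalidated here; semantic cache entries
    # remain until their TTL expires.
    count = exact_cache.invalidate(request.pattern)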
    
    return {
        "invalidated_count": count,
        "pattern": request.pattern
    }

@app.delete("/v1/cache")
async def clear_cache():
    """Clear all cache entries."""
    
    count = exact_cache.invalidate()
    
    return {
        "cleared_count": count
    }

@app.get("/health")
async def health():
    return {"status": "healthy"}

Conclusion

LLM caching is essential for production applications: it reduces costs, improves latency, and provides consistent responses for repeated queries. Exact match caching handles identical requests with minimal overhead. Semantic caching catches paraphrased questions by comparing embeddings, which raises hit rates for conversational applications at the cost of one embedding call per lookup. Hybrid approaches combine both: fast exact matching for common queries and a semantic fallback for variations. The key considerations are choosing an appropriate similarity threshold (0.95 or higher for high precision), setting TTLs based on how fresh the content must be, and building cache-aware clients that handle caching transparently. For production systems, use Redis or a vector database for distributed caching, warm the cache for predictable queries, and monitor hit rates to tune thresholds. Start with exact matching for quick wins, add semantic caching for conversational use cases, and always cache deterministic (temperature=0) requests.
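
Cache warming, mentioned above, can be sketched as a small async helper that pre-populates the cache for a known list of prompts. This is a minimal illustration under stated assumptions: the cache is a plain dict, `warm_cache` and `fake_generate` are hypothetical names, and `fake_generate` stands in for a real LLM call.

```python
import asyncio
from typing import Awaitable, Callable


async def warm_cache(
    cache: dict[str, str],
    prompts: list[str],
    generate: Callable[[str], Awaitable[str]],
    concurrency: int = 4,
) -> int:
    """Pre-populate `cache` for prompts that are not cached yet.

    Returns the number of responses that had to be generated.
    """
    semaphore = asyncio.Semaphore(concurrency)  # cap parallel LLM calls

    async def warm_one(prompt: str) -> bool:
        if prompt in cache:
            return False
        async with semaphore:
            cache[prompt] = await generate(prompt)
        return True

    # dict.fromkeys() removes duplicate prompts while keeping their order.
    unique_prompts = list(dict.fromkeys(prompts))
    results = await asyncio.gather(*(warm_one(p) for p in unique_prompts))
    return sum(results)


async def fake_generate(prompt: str) -> str:
    """Stand-in for a real LLM call."""
    await asyncio.sleep(0)
    return f"answer to: {prompt}"


cache = {"What are your opening hours?": "9am to 5pm."}
faq = [
    "How do I reset my password?",
    "What are your opening hours?",
    "How do I reset my password?",
]
generated = asyncio.run(warm_cache(cache, faq, fake_generate))
print(generated, len(cache))  # 1 2
```

Only the uncached, de-duplicated prompt triggers a generation, so the helper can be re-run on a schedule without paying again for entries that are still cached.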