Introduction
LLM API calls are expensive and slow. Caching is your first line of defense against runaway costs and latency. But caching LLM responses isn’t straightforward: the same question phrased differently should return the same cached answer. This guide covers caching strategies for LLM applications: exact-match caching for deterministic queries, semantic caching using embeddings for similar queries, cache invalidation strategies, TTL management, and building a production caching layer. These patterns can reduce your LLM costs by 40-70% while dramatically improving response times.

Exact Match Cache
import hashlib
import json
import time
from typing import Optional, Any
from dataclasses import dataclass


@dataclass
class CacheEntry:
    response: str
    created_at: float
    ttl: int
    metadata: dict


class ExactMatchCache:
    """Simple exact-match cache for LLM responses."""

    def __init__(self, default_ttl: int = 3600):
        self._cache: dict[str, CacheEntry] = {}
        self.default_ttl = default_ttl
        self.hits = 0
        self.misses = 0

    def _make_key(self, prompt: str, model: str, **kwargs) -> str:
        """Create cache key from request parameters."""
        key_data = {
            "prompt": prompt,
            "model": model,
            **kwargs
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """Get cached response if it exists and has not expired."""
        key = self._make_key(prompt, model, **kwargs)
        entry = self._cache.get(key)
        if entry is None:
            self.misses += 1
            return None
        # Check TTL
        if time.time() - entry.created_at > entry.ttl:
            del self._cache[key]
            self.misses += 1
            return None
        self.hits += 1
        return entry.response

    def set(
        self,
        prompt: str,
        model: str,
        response: str,
        ttl: Optional[int] = None,
        **kwargs
    ):
        """Cache a response."""
        key = self._make_key(prompt, model, **kwargs)
        self._cache[key] = CacheEntry(
            response=response,
            created_at=time.time(),
            ttl=ttl or self.default_ttl,
            metadata={"prompt": prompt[:100], "model": model}
        )

    def invalidate(self, prompt: str, model: str, **kwargs):
        """Invalidate a specific cache entry."""
        key = self._make_key(prompt, model, **kwargs)
        self._cache.pop(key, None)

    def clear(self):
        """Clear all cache entries."""
        self._cache.clear()

    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate."""
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0


# Usage with OpenAI
from openai import OpenAI

client = OpenAI()
cache = ExactMatchCache(default_ttl=3600)


def cached_completion(prompt: str, model: str = "gpt-4o") -> str:
    """Get completion with caching."""
    # Check cache first
    cached = cache.get(prompt, model)
    if cached is not None:
        return cached
    # Call API
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    result = response.choices[0].message.content
    # Cache the response
    cache.set(prompt, model, result)
    return result


# Same prompt returns cached response
result1 = cached_completion("What is Python?")
result2 = cached_completion("What is Python?")  # Cache hit
print(f"Hit rate: {cache.hit_rate:.1%}")
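Exact matching is brittle: "What is Python?" and "  what is python? " hash to different keys even though the call is effectively the same. A small sketch of one common refinement follows: normalize the prompt when building the cache key while still sending the original text to the model. normalize_for_key and normalized_completion are illustrative names, not part of the cache above.

import re


def normalize_for_key(prompt: str) -> str:
    """Collapse whitespace and lowercase; used only to build the cache key."""
    return re.sub(r"\s+", " ", prompt).strip().lower()


def normalized_completion(prompt: str, model: str = "gpt-4o") -> str:
    """Like cached_completion, but keyed on the normalized prompt."""
    key_prompt = normalize_for_key(prompt)
    cached = cache.get(key_prompt, model)
    if cached is not None:
        return cached
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]  # send the original prompt
    )
    result = response.choices[0].message.content
    cache.set(key_prompt, model, result)
    return result

Lowercasing can change meaning for case-sensitive prompts (code snippets, IDs), so apply this kind of normalization selectively.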
Semantic Cache with Embeddings
import numpy as np
from typing import Tuple


class SemanticCache:
    """Cache that matches semantically similar queries."""

    def __init__(
        self,
        similarity_threshold: float = 0.92,
        default_ttl: int = 3600
    ):
        self.threshold = similarity_threshold
        self.default_ttl = default_ttl
        self._entries: list[dict] = []
        self.hits = 0
        self.misses = 0

    def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        a = np.array(a)
        b = np.array(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def _find_similar(
        self,
        embedding: list[float]
    ) -> Tuple[Optional[dict], float]:
        """Find the most similar cached entry."""
        best_match = None
        best_score = 0.0
        current_time = time.time()
        for entry in self._entries:
            # Skip expired entries
            if current_time - entry["created_at"] > entry["ttl"]:
                continue
            score = self._cosine_similarity(embedding, entry["embedding"])
            if score > best_score:
                best_score = score
                best_match = entry
        return best_match, best_score

    def get(self, prompt: str) -> Optional[str]:
        """Get cached response for a semantically similar prompt."""
        embedding = self._get_embedding(prompt)
        match, score = self._find_similar(embedding)
        if match and score >= self.threshold:
            self.hits += 1
            return match["response"]
        self.misses += 1
        return None

    def set(self, prompt: str, response: str, ttl: Optional[int] = None):
        """Cache a response with its embedding."""
        embedding = self._get_embedding(prompt)
        self._entries.append({
            "prompt": prompt,
            "embedding": embedding,
            "response": response,
            "created_at": time.time(),
            "ttl": ttl or self.default_ttl
        })

    def cleanup_expired(self):
        """Remove expired entries."""
        current_time = time.time()
        self._entries = [
            e for e in self._entries
            if current_time - e["created_at"] <= e["ttl"]
        ]


# Usage
semantic_cache = SemanticCache(similarity_threshold=0.90)


def smart_completion(prompt: str, model: str = "gpt-4o") -> str:
    """Completion with semantic caching."""
    # Check semantic cache
    cached = semantic_cache.get(prompt)
    if cached is not None:
        return cached
    # Call API
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    result = response.choices[0].message.content
    semantic_cache.set(prompt, result)
    return result


# These will match semantically
result1 = smart_completion("What is the capital of France?")
result2 = smart_completion("What's France's capital city?")   # Semantic hit
result3 = smart_completion("Tell me the capital of France")   # Semantic hit
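The _find_similar loop recomputes cosine similarity one entry at a time, so lookup cost grows linearly with cache size. As a rough sketch of one way to keep that cheap, the variant below (a hypothetical VectorizedSemanticCache subclass, not part of the code above) stacks the live embeddings into a single NumPy matrix and scores them in one vectorized pass. Beyond a few thousand entries you would typically reach for a vector index such as FAISS, or a purpose-built library like GPTCache.

class VectorizedSemanticCache(SemanticCache):
    """Sketch: score all live entries in one NumPy operation."""

    def _find_similar(self, embedding: list[float]) -> Tuple[Optional[dict], float]:
        now = time.time()
        live = [e for e in self._entries if now - e["created_at"] <= e["ttl"]]
        if not live:
            return None, 0.0
        matrix = np.array([e["embedding"] for e in live])   # shape (n, d)
        query = np.array(embedding)
        # Cosine similarity against every cached embedding at once
        scores = matrix @ query / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query))
        best = int(np.argmax(scores))
        return live[best], float(scores[best])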
Redis-Backed Cache
import redis
import pickle
from typing import Optional


class RedisLLMCache:
    """Production-ready Redis-backed LLM cache."""

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        prefix: str = "llm_cache:",
        default_ttl: int = 3600
    ):
        self.redis = redis.from_url(redis_url)
        self.prefix = prefix
        self.default_ttl = default_ttl

    def _make_key(self, prompt: str, model: str, **kwargs) -> str:
        """Create Redis key."""
        key_data = json.dumps({
            "prompt": prompt,
            "model": model,
            **kwargs
        }, sort_keys=True)
        hash_val = hashlib.sha256(key_data.encode()).hexdigest()
        return f"{self.prefix}{hash_val}"

    def get(self, prompt: str, model: str, **kwargs) -> Optional[dict]:
        """Get cached response."""
        key = self._make_key(prompt, model, **kwargs)
        data = self.redis.get(key)
        if data:
            return pickle.loads(data)
        return None

    def set(
        self,
        prompt: str,
        model: str,
        response: str,
        usage: Optional[dict] = None,
        ttl: Optional[int] = None,
        **kwargs
    ):
        """Cache response with metadata."""
        key = self._make_key(prompt, model, **kwargs)
        data = {
            "response": response,
            "usage": usage,
            "cached_at": time.time(),
            "model": model
        }
        self.redis.setex(
            key,
            ttl or self.default_ttl,
            pickle.dumps(data)
        )

    def get_stats(self) -> dict:
        """Get cache statistics."""
        # Note: keyspace_hits/misses are server-wide Redis counters, not per-prefix
        info = self.redis.info("stats")
        # SCAN instead of KEYS so counting entries doesn't block the Redis server
        total_keys = sum(1 for _ in self.redis.scan_iter(match=f"{self.prefix}*"))
        return {
            "total_keys": total_keys,
            "hits": info.get("keyspace_hits", 0),
            "misses": info.get("keyspace_misses", 0),
            "memory_used": self.redis.info("memory")["used_memory_human"]
        }


# Usage
redis_cache = RedisLLMCache(redis_url="redis://localhost:6379")


def production_completion(prompt: str, model: str = "gpt-4o") -> dict:
    """Production completion with Redis caching."""
    cached = redis_cache.get(prompt, model)
    if cached:
        return {
            "response": cached["response"],
            "cached": True,
            "usage": cached.get("usage")
        }
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    result = response.choices[0].message.content
    usage = {
        "prompt_tokens": response.usage.prompt_tokens,
        "completion_tokens": response.usage.completion_tokens,
        "total_tokens": response.usage.total_tokens
    }
    redis_cache.set(prompt, model, result, usage=usage)
    return {
        "response": result,
        "cached": False,
        "usage": usage
    }
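One design choice worth flagging: pickle payloads are compact but opaque in redis-cli and unsafe to deserialize if anything else can write to the same Redis instance. Since the cached data is just strings and dicts, JSON works equally well. The sketch below shows a JSON-backed variant; JsonRedisLLMCache is an illustrative name, not part of the class above.

class JsonRedisLLMCache(RedisLLMCache):
    """Sketch: store cache entries as JSON instead of pickle."""

    def get(self, prompt: str, model: str, **kwargs) -> Optional[dict]:
        data = self.redis.get(self._make_key(prompt, model, **kwargs))
        return json.loads(data) if data else None

    def set(self, prompt: str, model: str, response: str,
            usage: Optional[dict] = None, ttl: Optional[int] = None, **kwargs):
        payload = {
            "response": response,
            "usage": usage,
            "cached_at": time.time(),
            "model": model
        }
        # Human-readable via redis-cli GET, and safe to load from a shared instance
        self.redis.setex(self._make_key(prompt, model, **kwargs),
                         ttl or self.default_ttl, json.dumps(payload))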
Tiered Caching Strategy
class TieredCache:
    """Multi-tier cache: memory -> Redis -> semantic."""

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        semantic_threshold: float = 0.90
    ):
        self.memory_cache = ExactMatchCache(default_ttl=300)  # 5 min
        self.redis_cache = RedisLLMCache(redis_url, default_ttl=3600)  # 1 hour
        self.semantic_cache = SemanticCache(
            similarity_threshold=semantic_threshold,
            default_ttl=86400  # 24 hours
        )
        self.stats = {
            "memory_hits": 0,
            "redis_hits": 0,
            "semantic_hits": 0,
            "misses": 0
        }

    def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """Check all cache tiers."""
        # Tier 1: Memory (fastest)
        result = self.memory_cache.get(prompt, model, **kwargs)
        if result:
            self.stats["memory_hits"] += 1
            return result
        # Tier 2: Redis (fast, persistent)
        cached = self.redis_cache.get(prompt, model, **kwargs)
        if cached:
            self.stats["redis_hits"] += 1
            # Promote to memory cache
            self.memory_cache.set(prompt, model, cached["response"], **kwargs)
            return cached["response"]
        # Tier 3: Semantic (slowest, but catches similar queries)
        result = self.semantic_cache.get(prompt)
        if result:
            self.stats["semantic_hits"] += 1
            # Promote to faster caches
            self.memory_cache.set(prompt, model, result, **kwargs)
            self.redis_cache.set(prompt, model, result, **kwargs)
            return result
        self.stats["misses"] += 1
        return None

    def set(self, prompt: str, model: str, response: str, **kwargs):
        """Set in all cache tiers."""
        self.memory_cache.set(prompt, model, response, **kwargs)
        self.redis_cache.set(prompt, model, response, **kwargs)
        self.semantic_cache.set(prompt, response)

    def get_stats(self) -> dict:
        """Get comprehensive cache statistics."""
        total = sum(self.stats.values())
        return {
            **self.stats,
            "total_requests": total,
            "overall_hit_rate": (total - self.stats["misses"]) / total if total > 0 else 0,
            "memory_hit_rate": self.stats["memory_hits"] / total if total > 0 else 0,
            "redis_hit_rate": self.stats["redis_hits"] / total if total > 0 else 0,
            "semantic_hit_rate": self.stats["semantic_hits"] / total if total > 0 else 0
        }


# Usage
tiered_cache = TieredCache()


def optimized_completion(prompt: str, model: str = "gpt-4o") -> str:
    """Completion with tiered caching."""
    cached = tiered_cache.get(prompt, model)
    if cached:
        return cached
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    result = response.choices[0].message.content
    tiered_cache.set(prompt, model, result)
    return result
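A tiered cache does nothing for the request that misses, and under concurrency many requests can miss on the same prompt at once and each pay for an API call (a cache stampede). The sketch below adds a per-key lock so only one caller fills the cache. deduped_completion and _key_locks are illustrative names; this assumes a single process with threads, and a multi-process deployment would need a distributed lock (for example Redis SET NX with an expiry) instead.

import threading
from collections import defaultdict

# One lock per cache key; concurrent misses for the same prompt wait for the first caller
_key_locks: defaultdict[str, threading.Lock] = defaultdict(threading.Lock)


def deduped_completion(prompt: str, model: str = "gpt-4o") -> str:
    cached = tiered_cache.get(prompt, model)
    if cached:
        return cached
    key = hashlib.sha256(f"{model}:{prompt}".encode()).hexdigest()
    with _key_locks[key]:
        # Re-check inside the lock: another thread may have filled the cache already
        cached = tiered_cache.get(prompt, model)
        if cached:
            return cached
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}]
        )
        result = response.choices[0].message.content
        tiered_cache.set(prompt, model, result)
        return result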
Cache Invalidation
from enum import Enum
from typing import Any, Optional


class InvalidationStrategy(str, Enum):
    TTL = "ttl"                # Time-based expiration
    VERSION = "version"        # Version-based invalidation
    TAG = "tag"                # Tag-based invalidation
    DEPENDENCY = "dependency"  # Dependency tracking


class SmartCache:
    """Cache with advanced invalidation strategies."""

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis = redis.from_url(redis_url)
        self.prefix = "smart_cache:"
        self.version = 1

    def _make_key(self, key: str) -> str:
        return f"{self.prefix}v{self.version}:{key}"

    def get(self, key: str) -> Optional[dict]:
        """Get with version check."""
        data = self.redis.get(self._make_key(key))
        if data:
            return pickle.loads(data)
        return None

    def set(
        self,
        key: str,
        value: Any,
        ttl: int = 3600,
        tags: Optional[list[str]] = None
    ):
        """Set with tags for group invalidation."""
        full_key = self._make_key(key)
        self.redis.setex(full_key, ttl, pickle.dumps(value))
        # Track tags
        if tags:
            for tag in tags:
                self.redis.sadd(f"{self.prefix}tag:{tag}", full_key)

    def invalidate_by_tag(self, tag: str):
        """Invalidate all entries with a specific tag."""
        tag_key = f"{self.prefix}tag:{tag}"
        keys = self.redis.smembers(tag_key)
        if keys:
            self.redis.delete(*keys)
        self.redis.delete(tag_key)

    def invalidate_all(self):
        """Invalidate all cache entries by incrementing the version.
        Old versioned keys become unreachable and age out via their TTLs."""
        self.version += 1

    def invalidate_pattern(self, pattern: str):
        """Invalidate entries matching a pattern."""
        keys = self.redis.keys(f"{self.prefix}*{pattern}*")
        if keys:
            self.redis.delete(*keys)


# Usage with tags
smart_cache = SmartCache()

# Cache with tags
smart_cache.set(
    "user_123_profile",
    {"name": "John", "preferences": {...}},
    tags=["user_123", "profiles"]
)
smart_cache.set(
    "user_123_settings",
    {"theme": "dark"},
    tags=["user_123", "settings"]
)

# Invalidate all user_123 data
smart_cache.invalidate_by_tag("user_123")

# Invalidate all profiles
smart_cache.invalidate_by_tag("profiles")
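TTL choice interacts with invalidation: if a batch of entries is written at the same moment with the same TTL, they all expire together and the backend absorbs a burst of misses. A common mitigation is to add jitter when writing. jittered_ttl below is an illustrative helper, not part of SmartCache.

import random


def jittered_ttl(base_ttl: int, jitter: float = 0.1) -> int:
    """Spread expirations by up to +/- jitter so entries don't all expire at once."""
    delta = int(base_ttl * jitter)
    return base_ttl + random.randint(-delta, delta)


# e.g. write with a TTL somewhere between ~54 and ~66 minutes instead of exactly 60
smart_cache.set("user_123_profile", {"name": "John"}, ttl=jittered_ttl(3600))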
Production Cache Service
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

app = FastAPI()


class CompletionRequest(BaseModel):
    prompt: str
    model: str = "gpt-4o"
    use_cache: bool = True
    cache_ttl: int = 3600  # reserved; each tier currently uses its own default TTL


class CompletionResponse(BaseModel):
    response: str
    cached: bool
    cache_key: Optional[str] = None
    latency_ms: float
    tokens_saved: Optional[int] = None


# Initialize tiered cache
cache = TieredCache()


@app.post("/completion", response_model=CompletionResponse)
async def completion(request: CompletionRequest):
    """Get completion with intelligent caching."""
    start_time = time.time()
    if request.use_cache:
        cached = cache.get(request.prompt, request.model)
        if cached:
            return CompletionResponse(
                response=cached,
                cached=True,
                cache_key=hashlib.sha256(request.prompt.encode()).hexdigest()[:16],
                latency_ms=(time.time() - start_time) * 1000,
                tokens_saved=len(cached.split()) * 2  # Rough estimate
            )
    # Call LLM
    response = client.chat.completions.create(
        model=request.model,
        messages=[{"role": "user", "content": request.prompt}]
    )
    result = response.choices[0].message.content
    # Cache the response
    if request.use_cache:
        cache.set(request.prompt, request.model, result)
    return CompletionResponse(
        response=result,
        cached=False,
        cache_key=hashlib.sha256(request.prompt.encode()).hexdigest()[:16],
        latency_ms=(time.time() - start_time) * 1000,
        tokens_saved=None
    )


@app.get("/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    return cache.get_stats()


@app.post("/cache/invalidate")
async def invalidate_cache(pattern: Optional[str] = None):
    """Invalidate cache entries."""
    if pattern:
        # Pattern-based invalidation isn't wired into TieredCache yet
        raise HTTPException(status_code=501, detail="Pattern invalidation not implemented")
    # Clear the in-memory tier
    cache.memory_cache.clear()
    return {"status": "invalidated"}
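To sanity-check the service end to end, run it with uvicorn and hit it twice with the same prompt; the second call should come back with cached set to True and single-digit-millisecond latency. The snippet below assumes the app is saved as main.py and served locally on port 8000.

# uvicorn main:app --port 8000
import requests

payload = {"prompt": "What is Python?", "model": "gpt-4o"}
first = requests.post("http://localhost:8000/completion", json=payload).json()
second = requests.post("http://localhost:8000/completion", json=payload).json()
print(first["cached"], second["cached"])   # False True
print(second["latency_ms"])                # typically a few milliseconds
print(requests.get("http://localhost:8000/cache/stats").json())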
References
- Redis: https://redis.io/docs/
- GPTCache: https://github.com/zilliztech/GPTCache
- LangChain Caching: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
- OpenAI Embeddings: https://platform.openai.com/docs/guides/embeddings
Conclusion
Effective caching is essential for production LLM applications. Start with exact-match caching for deterministic queries—it’s simple and catches repeated requests. Add semantic caching to handle paraphrased queries that should return the same answer. Use tiered caching (memory → Redis → semantic) for optimal performance across different access patterns. Implement smart invalidation strategies based on your data freshness requirements. Monitor cache hit rates and adjust thresholds accordingly. A well-tuned caching layer can reduce your LLM costs by 50% or more while cutting response latency from seconds to milliseconds for cached queries.