Introduction
LLM API calls are expensive and slow. A single GPT-4 request can cost cents and take seconds; multiply that by thousands of users asking similar questions, and costs spiral quickly. Semantic caching solves this by recognizing that “What’s the weather in NYC?” and “Tell me NYC weather” are essentially the same query. Instead of exact string matching, semantic caching uses embeddings to find similar past queries and returns their cached responses. This guide covers implementing semantic caching from scratch, using libraries like GPTCache, and production patterns for cache invalidation and quality control.
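As a quick illustration of the idea, the snippet below embeds two phrasings of the same question and compares them with cosine similarity. This is a minimal sketch assuming the OpenAI Python SDK (v1) and the text-embedding-3-small model, which are also used throughout the rest of this guide; whether two phrasings clear a given threshold depends on the embedding model and the threshold you pick.

import numpy as np
from openai import OpenAI

client = OpenAI()

def embed(text: str) -> np.ndarray:
    # text-embedding-3-small returns a 1536-dimensional vector
    resp = client.embeddings.create(input=text, model="text-embedding-3-small")
    return np.array(resp.data[0].embedding)

a = embed("What's the weather in NYC?")
b = embed("Tell me NYC weather")
similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(similarity)  # if this clears the cache's similarity threshold,
                   # the second query can reuse the first query's cached response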

Basic Semantic Cache Implementation
import numpy as np
from openai import OpenAI
from datetime import datetime, timedelta
from typing import Optional

client = OpenAI()


class SemanticCache:
    """Simple semantic cache using in-memory storage."""

    def __init__(
        self,
        similarity_threshold: float = 0.95,
        ttl_hours: int = 24,
        embedding_model: str = "text-embedding-3-small"
    ):
        self.threshold = similarity_threshold
        self.ttl = timedelta(hours=ttl_hours)
        self.embedding_model = embedding_model
        # Storage (use Redis/Pinecone in production)
        self.embeddings = []  # List of (embedding, query, response, timestamp)
        self.exact_cache = {}  # Exact match fallback

    def _get_embedding(self, text: str) -> np.ndarray:
        """Get embedding for text."""
        response = client.embeddings.create(
            input=text,
            model=self.embedding_model
        )
        return np.array(response.data[0].embedding)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Calculate cosine similarity between two vectors."""
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def _normalize_query(self, query: str) -> str:
        """Normalize query for exact matching."""
        return query.lower().strip()

    def get(self, query: str) -> Optional[str]:
        """Try to get cached response for query."""
        # Try exact match first (fast)
        normalized = self._normalize_query(query)
        if normalized in self.exact_cache:
            entry = self.exact_cache[normalized]
            if datetime.now() - entry["timestamp"] < self.ttl:
                return entry["response"]
        # Semantic search
        if not self.embeddings:
            return None
        query_embedding = self._get_embedding(query)
        best_match = None
        best_similarity = 0.0
        for embedding, cached_query, response, timestamp in self.embeddings:
            # Check TTL
            if datetime.now() - timestamp > self.ttl:
                continue
            similarity = self._cosine_similarity(query_embedding, embedding)
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = (cached_query, response)
        if best_match and best_similarity >= self.threshold:
            return best_match[1]
        return None

    def set(self, query: str, response: str):
        """Cache a query-response pair."""
        # Exact cache
        normalized = self._normalize_query(query)
        self.exact_cache[normalized] = {
            "response": response,
            "timestamp": datetime.now()
        }
        # Semantic cache
        embedding = self._get_embedding(query)
        self.embeddings.append((
            embedding,
            query,
            response,
            datetime.now()
        ))

    def clear_expired(self):
        """Remove expired entries."""
        now = datetime.now()
        # Clear exact cache
        self.exact_cache = {
            k: v for k, v in self.exact_cache.items()
            if now - v["timestamp"] < self.ttl
        }
        # Clear semantic cache
        self.embeddings = [
            entry for entry in self.embeddings
            if now - entry[3] < self.ttl
        ]


# Usage
cache = SemanticCache(similarity_threshold=0.92)

def cached_completion(query: str) -> str:
    """Get completion with semantic caching."""
    # Check cache
    cached = cache.get(query)
    if cached:
        print("Cache hit!")
        return cached
    # Call LLM
    print("Cache miss - calling LLM")
    response = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[{"role": "user", "content": query}]
    )
    result = response.choices[0].message.content
    # Cache the result
    cache.set(query, result)
    return result

# Test
print(cached_completion("What is machine learning?"))
print(cached_completion("Explain machine learning"))  # Should hit the cache if similarity clears the threshold
print(cached_completion("What is ML?"))  # May also hit, depending on embedding similarity
Redis-Based Semantic Cache
import redis
import uuid
import numpy as np
from redis.commands.search.field import VectorField, TextField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from typing import Optional
from openai import OpenAI


class RedisSemanticCache:
    """Production semantic cache using Redis with vector search."""

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        index_name: str = "llm_cache",
        similarity_threshold: float = 0.92,
        ttl_seconds: int = 86400,
        vector_dim: int = 1536
    ):
        self.redis = redis.from_url(redis_url)
        self.index_name = index_name
        self.threshold = similarity_threshold
        self.ttl = ttl_seconds
        self.vector_dim = vector_dim
        self._create_index()

    def _create_index(self):
        """Create Redis search index for vectors."""
        try:
            self.redis.ft(self.index_name).info()
        except redis.exceptions.ResponseError:
            # Index doesn't exist yet - create it
            schema = (
                TextField("query"),
                TextField("response"),
                VectorField(
                    "embedding",
                    "HNSW",
                    {
                        "TYPE": "FLOAT32",
                        "DIM": self.vector_dim,
                        "DISTANCE_METRIC": "COSINE"
                    }
                )
            )
            self.redis.ft(self.index_name).create_index(
                schema,
                definition=IndexDefinition(
                    prefix=["cache:"],
                    index_type=IndexType.HASH
                )
            )

    def get(self, query: str, query_embedding: np.ndarray) -> Optional[str]:
        """Search for similar cached query."""
        # Convert to bytes for Redis
        embedding_bytes = query_embedding.astype(np.float32).tobytes()
        # Vector similarity search
        q = (
            Query("*=>[KNN 1 @embedding $vec AS score]")
            .return_fields("query", "response", "score")
            .sort_by("score")
            .dialect(2)
        )
        results = self.redis.ft(self.index_name).search(
            q,
            query_params={"vec": embedding_bytes}
        )
        if results.docs:
            doc = results.docs[0]
            similarity = 1 - float(doc.score)  # Convert cosine distance to similarity
            if similarity >= self.threshold:
                return doc.response
        return None

    def set(self, query: str, response: str, embedding: np.ndarray):
        """Cache query-response pair."""
        key = f"cache:{uuid.uuid4()}"
        self.redis.hset(
            key,
            mapping={
                "query": query,
                "response": response,
                "embedding": embedding.astype(np.float32).tobytes()
            }
        )
        # Set TTL
        self.redis.expire(key, self.ttl)

    def stats(self) -> dict:
        """Get cache statistics."""
        info = self.redis.ft(self.index_name).info()
        return {
            "num_docs": info["num_docs"],
            "index_size": info.get("inverted_sz_mb", 0)
        }


# Production wrapper
class CachedLLM:
    """LLM client with semantic caching."""

    def __init__(self, cache: RedisSemanticCache):
        self.cache = cache
        self.client = OpenAI()
        self.stats = {"hits": 0, "misses": 0}

    def _get_embedding(self, text: str) -> np.ndarray:
        response = self.client.embeddings.create(
            input=text,
            model="text-embedding-3-small"
        )
        return np.array(response.data[0].embedding)

    def complete(self, query: str, **kwargs) -> str:
        """Get completion with caching."""
        embedding = self._get_embedding(query)
        # Check cache
        cached = self.cache.get(query, embedding)
        if cached:
            self.stats["hits"] += 1
            return cached
        # Call LLM
        self.stats["misses"] += 1
        response = self.client.chat.completions.create(
            model=kwargs.get("model", "gpt-4-turbo-preview"),
            messages=[{"role": "user", "content": query}],
            **{k: v for k, v in kwargs.items() if k != "model"}
        )
        result = response.choices[0].message.content
        # Cache
        self.cache.set(query, result, embedding)
        return result

    def hit_rate(self) -> float:
        total = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / total if total > 0 else 0.0
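A brief usage sketch for the Redis-backed cache. This assumes a local Redis Stack instance with the RediSearch module loaded (for example, started from Docker's redis/redis-stack image) and an OPENAI_API_KEY in the environment; the queries are only illustrative.

# Assumes Redis Stack is reachable at localhost:6379
redis_cache = RedisSemanticCache(redis_url="redis://localhost:6379")
llm = CachedLLM(redis_cache)

print(llm.complete("What is semantic caching?"))   # miss: calls the API and stores the result
print(llm.complete("Explain semantic caching"))    # hit, if similarity clears the 0.92 threshold
print(f"Hit rate: {llm.hit_rate():.1%}")
print(redis_cache.stats())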
Using GPTCache Library
# pip install gptcache
from gptcache import cache
from gptcache.adapter import openai
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

# Initialize embedding model
onnx = Onnx()

# Setup cache with SQLite + FAISS
data_manager = get_data_manager(
    CacheBase("sqlite"),
    VectorBase("faiss", dimension=onnx.dimension)
)

# Initialize cache
cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation()
)

# Configure the OpenAI API key (read from the OPENAI_API_KEY environment variable)
cache.set_openai_key()

# Use cached OpenAI client
response = openai.ChatCompletion.create(
    model="gpt-4-turbo-preview",
    messages=[{"role": "user", "content": "What is Python?"}]
)
print(response["choices"][0]["message"]["content"])

# Second call with a similar query - hits the cache if the similarity evaluation passes
response = openai.ChatCompletion.create(
    model="gpt-4-turbo-preview",
    messages=[{"role": "user", "content": "Explain Python programming"}]
)

# Advanced configuration
from gptcache.processor.pre import last_content
from gptcache.processor.post import first

cache.init(
    pre_embedding_func=last_content,  # Use the last message for embedding
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    post_process_messages_func=first  # Return the first cached result
)
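GPTCache can also take an explicit similarity threshold via a Config object passed to cache.init. The sketch below assumes gptcache.Config exposes a similarity_threshold parameter; verify the import path and parameter name against your installed GPTCache version before relying on it.

# Assumed API: gptcache.Config(similarity_threshold=...) controls how strict a
# match must be before a cached answer is returned (higher = fewer, more precise hits).
from gptcache import Config

cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    config=Config(similarity_threshold=0.8)  # example value; tune per workload
)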
Cache Invalidation Strategies
from datetime import datetime
from typing import Optional, Callable
import hashlib
import json
import re


class SmartCache:
    """Cache with intelligent invalidation strategies."""

    def __init__(self, base_cache):
        self.cache = base_cache
        self.invalidation_rules = []

    def add_invalidation_rule(self, rule: Callable[[str, str], bool]):
        """Add custom invalidation rule."""
        self.invalidation_rules.append(rule)

    def should_invalidate(self, query: str, cached_response: str) -> bool:
        """Check if cached response should be invalidated."""
        for rule in self.invalidation_rules:
            if rule(query, cached_response):
                return True
        return False

    def get(self, query: str, embedding) -> Optional[str]:
        """Get with invalidation check."""
        cached = self.cache.get(query, embedding)
        if cached and self.should_invalidate(query, cached):
            return None  # Force cache miss
        return cached


# Common invalidation rules
def time_sensitive_rule(query: str, response: str) -> bool:
    """Invalidate for time-sensitive queries."""
    time_keywords = ["today", "now", "current", "latest", "recent", "this week"]
    return any(kw in query.lower() for kw in time_keywords)

def stale_data_rule(query: str, response: str) -> bool:
    """Invalidate if response mentions outdated dates."""
    # Find years in the response
    years = re.findall(r'\b(20\d{2})\b', response)
    current_year = datetime.now().year
    # If the response mentions years more than 1 year old, it might be stale
    for year in years:
        if int(year) < current_year - 1:
            return True
    return False

def confidence_rule(query: str, response: str) -> bool:
    """Invalidate low-confidence responses."""
    uncertainty_phrases = [
        "i'm not sure",
        "i don't know",
        "i cannot",
        "as of my knowledge",
        "i don't have access"
    ]
    return any(phrase in response.lower() for phrase in uncertainty_phrases)


# Usage (wrapping the RedisSemanticCache from the previous section)
smart_cache = SmartCache(redis_cache)
smart_cache.add_invalidation_rule(time_sensitive_rule)
smart_cache.add_invalidation_rule(stale_data_rule)
smart_cache.add_invalidation_rule(confidence_rule)


# Context-aware caching
class ContextAwareCache:
    """Cache that considers conversation context."""

    def __init__(self, base_cache):
        self.cache = base_cache

    def _context_hash(self, messages: list[dict]) -> str:
        """Create hash of conversation context."""
        # Use the last N messages for context
        context_messages = messages[-3:]
        context_str = json.dumps(context_messages, sort_keys=True)
        return hashlib.md5(context_str.encode()).hexdigest()[:8]

    def get(self, messages: list[dict], embedding) -> Optional[str]:
        """Get with context awareness."""
        # Combine query with context hash
        query = messages[-1]["content"]
        context_hash = self._context_hash(messages[:-1])
        cache_key = f"{context_hash}:{query}"
        return self.cache.get(cache_key, embedding)

    def set(self, messages: list[dict], response: str, embedding):
        """Set with context."""
        query = messages[-1]["content"]
        context_hash = self._context_hash(messages[:-1])
        cache_key = f"{context_hash}:{query}"
        self.cache.set(cache_key, response, embedding)
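Because the invalidation rules are plain predicates on the query and response strings, they can be sanity-checked without any cache backend. A quick illustrative check (the example strings are made up):

# Time-sensitive wording in the query forces a refresh
print(time_sensitive_rule("What's the latest Python release?", ""))  # True -> treat as a miss
print(time_sensitive_rule("What is a binary search tree?", ""))      # False -> cached answer is fine

# Hedged or outdated responses are not worth reusing
print(confidence_rule("any query", "I'm not sure, but it might be..."))      # True
print(stale_data_rule("any query", "The framework was released in 2019."))   # True (mentions a year more than one year old)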
Monitoring and Optimization
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Optional
import time


@dataclass
class CacheMetrics:
    """Track cache performance metrics."""
    hits: int = 0
    misses: int = 0
    total_latency_saved_ms: float = 0
    total_cost_saved: float = 0
    queries_by_similarity: dict = field(default_factory=lambda: defaultdict(int))


class MonitoredCache:
    """Cache with comprehensive monitoring."""

    def __init__(self, base_cache, avg_llm_latency_ms: float = 2000, cost_per_1k_tokens: float = 0.01):
        # base_cache is expected to expose get_with_similarity(query, embedding),
        # returning (response or None, best similarity score)
        self.cache = base_cache
        self.metrics = CacheMetrics()
        self.avg_latency = avg_llm_latency_ms
        self.cost_per_1k = cost_per_1k_tokens

    def get(self, query: str, embedding) -> tuple[Optional[str], dict]:
        """Get with metrics tracking."""
        start = time.time()
        result, similarity = self.cache.get_with_similarity(query, embedding)
        latency = (time.time() - start) * 1000
        if result:
            self.metrics.hits += 1
            self.metrics.total_latency_saved_ms += self.avg_latency - latency
            # Estimate tokens saved
            estimated_tokens = len(query.split()) * 4 + len(result.split()) * 4
            self.metrics.total_cost_saved += (estimated_tokens / 1000) * self.cost_per_1k
            # Track similarity distribution
            bucket = round(similarity, 1)
            self.metrics.queries_by_similarity[bucket] += 1
        else:
            self.metrics.misses += 1
        return result, {
            "hit": result is not None,
            "similarity": similarity,
            "latency_ms": latency
        }

    def get_stats(self) -> dict:
        """Get comprehensive statistics."""
        total = self.metrics.hits + self.metrics.misses
        return {
            "hit_rate": self.metrics.hits / total if total > 0 else 0,
            "total_queries": total,
            "hits": self.metrics.hits,
            "misses": self.metrics.misses,
            "latency_saved_ms": self.metrics.total_latency_saved_ms,
            "cost_saved_usd": self.metrics.total_cost_saved,
            "similarity_distribution": dict(self.metrics.queries_by_similarity)
        }

    def optimize_threshold(self, target_hit_rate: float = 0.3) -> float:
        """Suggest optimal similarity threshold."""
        # Analyze similarity distribution
        dist = self.metrics.queries_by_similarity
        if not dist:
            return 0.92  # Default
        # Find the highest threshold that still achieves the target hit rate
        total = sum(dist.values())
        cumulative = 0
        for threshold in sorted(dist.keys(), reverse=True):
            cumulative += dist[threshold]
            if cumulative / total >= target_hit_rate:
                return threshold
        return min(dist.keys())


# Dashboard data
def generate_cache_report(cache: MonitoredCache) -> str:
    """Generate cache performance report."""
    stats = cache.get_stats()
    report = f"""
## Cache Performance Report

### Overview
- **Hit Rate**: {stats['hit_rate']:.1%}
- **Total Queries**: {stats['total_queries']:,}
- **Cache Hits**: {stats['hits']:,}
- **Cache Misses**: {stats['misses']:,}

### Savings
- **Latency Saved**: {stats['latency_saved_ms']/1000:.1f} seconds
- **Cost Saved**: ${stats['cost_saved_usd']:.2f}

### Similarity Distribution
"""
    for threshold, count in sorted(stats['similarity_distribution'].items(), reverse=True):
        report += f"- {threshold:.1f}: {count} queries\n"
    return report
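One thing to note: MonitoredCache assumes the underlying cache exposes a get_with_similarity method that returns both the response (or None) and the best similarity score, and neither cache class above defines one. Below is a possible sketch for the Redis-backed cache that reuses its existing query logic; the subclass name is mine, not part of the original code, and it assumes the imports from the Redis section are in scope.

class ScoringRedisSemanticCache(RedisSemanticCache):
    """RedisSemanticCache extended with the (response, similarity) interface
    that MonitoredCache expects."""

    def get_with_similarity(self, query: str, query_embedding: np.ndarray) -> tuple[Optional[str], float]:
        embedding_bytes = query_embedding.astype(np.float32).tobytes()
        q = (
            Query("*=>[KNN 1 @embedding $vec AS score]")
            .return_fields("response", "score")
            .sort_by("score")
            .dialect(2)
        )
        results = self.redis.ft(self.index_name).search(q, query_params={"vec": embedding_bytes})
        if not results.docs:
            return None, 0.0
        doc = results.docs[0]
        similarity = 1 - float(doc.score)  # cosine distance -> similarity
        if similarity >= self.threshold:
            return doc.response, similarity
        return None, similarity

# Hypothetical wiring
monitored = MonitoredCache(ScoringRedisSemanticCache())
print(generate_cache_report(monitored))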
References
- GPTCache: https://github.com/zilliztech/GPTCache
- Redis Vector Search: https://redis.io/docs/stack/search/reference/vectors/
- LangChain Caching: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
- OpenAI Embeddings: https://platform.openai.com/docs/guides/embeddings
Conclusion
Semantic caching is one of the highest-impact optimizations for LLM applications. By recognizing semantically similar queries, you can dramatically reduce API costs and latency while maintaining response quality. Start with a simple in-memory cache for development, then move to Redis with vector search for production. The key decisions are similarity threshold (too high misses opportunities, too low returns irrelevant results) and TTL (balance freshness vs. hit rate). Implement smart invalidation for time-sensitive queries, and monitor your cache performance to continuously optimize. A well-tuned semantic cache can achieve 30-50% hit rates for many applications, translating directly to cost savings and faster user experiences.