
Semantic Caching for LLM Applications: Cut Costs and Latency by 50%

Introduction

LLM API calls are expensive and slow. A single GPT-4 request can cost cents and take seconds—multiply that by thousands of users asking similar questions, and costs spiral quickly. Semantic caching solves this by recognizing that “What’s the weather in NYC?” and “Tell me NYC weather” are essentially the same query. Instead of exact string matching, semantic caching uses embeddings to find similar past queries and return cached responses. This guide covers implementing semantic caching from scratch, using libraries like GPTCache, and production patterns for cache invalidation and quality control.
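
The core mechanism fits in a few lines: embed the incoming query, compare it to embeddings of previously answered queries, and reuse the stored answer when the similarity clears a threshold. Here is a minimal sketch, assuming OpenAI's text-embedding-3-small model and an OPENAI_API_KEY in the environment (the 0.9 threshold is illustrative):

import numpy as np
from openai import OpenAI

client = OpenAI()

def embed(text: str) -> np.ndarray:
    """Embed text with OpenAI's embeddings endpoint."""
    resp = client.embeddings.create(input=text, model="text-embedding-3-small")
    return np.array(resp.data[0].embedding)

# Two phrasings of the same question
a = embed("What's the weather in NYC?")
b = embed("Tell me NYC weather")

# Cosine similarity close to 1.0 means "treat these as the same query"
similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"similarity = {similarity:.3f}")

if similarity >= 0.9:  # illustrative threshold
    print("Cache hit: reuse the stored answer")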


Basic Semantic Cache Implementation

import numpy as np
from openai import OpenAI
from datetime import datetime, timedelta
from typing import Optional

client = OpenAI()

class SemanticCache:
    """Simple semantic cache using in-memory storage."""
    
    def __init__(
        self,
        similarity_threshold: float = 0.95,
        ttl_hours: int = 24,
        embedding_model: str = "text-embedding-3-small"
    ):
        self.threshold = similarity_threshold
        self.ttl = timedelta(hours=ttl_hours)
        self.embedding_model = embedding_model
        
        # Storage (use Redis/Pinecone in production)
        self.embeddings = []  # List of (embedding, query, response, timestamp)
        self.exact_cache = {}  # Exact match fallback
    
    def _get_embedding(self, text: str) -> np.ndarray:
        """Get embedding for text."""
        response = client.embeddings.create(
            input=text,
            model=self.embedding_model
        )
        return np.array(response.data[0].embedding)
    
    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Calculate cosine similarity between two vectors."""
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    
    def _normalize_query(self, query: str) -> str:
        """Normalize query for exact matching."""
        return query.lower().strip()
    
    def get(self, query: str) -> Optional[str]:
        """Try to get cached response for query."""
        
        # Try exact match first (fast)
        normalized = self._normalize_query(query)
        if normalized in self.exact_cache:
            entry = self.exact_cache[normalized]
            if datetime.now() - entry["timestamp"] < self.ttl:
                return entry["response"]
        
        # Semantic search
        if not self.embeddings:
            return None
        
        query_embedding = self._get_embedding(query)
        
        best_match = None
        best_similarity = 0
        
        for embedding, cached_query, response, timestamp in self.embeddings:
            # Check TTL
            if datetime.now() - timestamp > self.ttl:
                continue
            
            similarity = self._cosine_similarity(query_embedding, embedding)
            
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = (cached_query, response)
        
        if best_similarity >= self.threshold:
            return best_match[1]
        
        return None
    
    def set(self, query: str, response: str):
        """Cache a query-response pair."""
        
        # Exact cache
        normalized = self._normalize_query(query)
        self.exact_cache[normalized] = {
            "response": response,
            "timestamp": datetime.now()
        }
        
        # Semantic cache
        embedding = self._get_embedding(query)
        self.embeddings.append((
            embedding,
            query,
            response,
            datetime.now()
        ))
    
    def clear_expired(self):
        """Remove expired entries."""
        now = datetime.now()
        
        # Clear exact cache
        self.exact_cache = {
            k: v for k, v in self.exact_cache.items()
            if now - v["timestamp"] < self.ttl
        }
        
        # Clear semantic cache
        self.embeddings = [
            entry for entry in self.embeddings
            if now - entry[3] < self.ttl
        ]

# Usage
cache = SemanticCache(similarity_threshold=0.92)

def cached_completion(query: str) -> str:
    """Get completion with semantic caching."""
    
    # Check cache
    cached = cache.get(query)
    if cached:
        print("Cache hit!")
        return cached
    
    # Call LLM
    print("Cache miss - calling LLM")
    response = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[{"role": "user", "content": query}]
    )
    
    result = response.choices[0].message.content
    
    # Cache the result
    cache.set(query, result)
    
    return result

# Test
print(cached_completion("What is machine learning?"))
print(cached_completion("Explain machine learning"))  # Should hit cache
print(cached_completion("What is ML?"))  # Should hit cache

Redis-Based Semantic Cache

import redis
import numpy as np
from redis.commands.search.field import VectorField, TextField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from typing import Optional

from openai import OpenAI

class RedisSemanticCache:
    """Production semantic cache using Redis with vector search."""
    
    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        index_name: str = "llm_cache",
        similarity_threshold: float = 0.92,
        ttl_seconds: int = 86400,
        vector_dim: int = 1536
    ):
        self.redis = redis.from_url(redis_url)
        self.index_name = index_name
        self.threshold = similarity_threshold
        self.ttl = ttl_seconds
        self.vector_dim = vector_dim
        
        self._create_index()
    
    def _create_index(self):
        """Create Redis search index for vectors."""
        try:
            self.redis.ft(self.index_name).info()
        except redis.exceptions.ResponseError:
            # Index does not exist yet - create it
            schema = (
                TextField("query"),
                TextField("response"),
                VectorField(
                    "embedding",
                    "HNSW",
                    {
                        "TYPE": "FLOAT32",
                        "DIM": self.vector_dim,
                        "DISTANCE_METRIC": "COSINE"
                    }
                )
            )
            
            self.redis.ft(self.index_name).create_index(
                schema,
                definition=IndexDefinition(
                    prefix=["cache:"],
                    index_type=IndexType.HASH
                )
            )
    
    def get(self, query: str, query_embedding: np.ndarray) -> Optional[str]:
        """Search for similar cached query."""
        
        # Convert to bytes for Redis
        embedding_bytes = query_embedding.astype(np.float32).tobytes()
        
        # Vector similarity search
        q = (
            Query("*=>[KNN 1 @embedding $vec AS score]")
            .return_fields("query", "response", "score")
            .sort_by("score")
            .dialect(2)
        )
        
        results = self.redis.ft(self.index_name).search(
            q,
            query_params={"vec": embedding_bytes}
        )
        
        if results.docs:
            doc = results.docs[0]
            similarity = 1 - float(doc.score)  # Convert distance to similarity
            
            if similarity >= self.threshold:
                return doc.response
        
        return None
    
    def set(self, query: str, response: str, embedding: np.ndarray):
        """Cache query-response pair."""
        import uuid
        
        key = f"cache:{uuid.uuid4()}"
        
        self.redis.hset(
            key,
            mapping={
                "query": query,
                "response": response,
                "embedding": embedding.astype(np.float32).tobytes()
            }
        )
        
        # Set TTL
        self.redis.expire(key, self.ttl)
    
    def stats(self) -> dict:
        """Get cache statistics."""
        info = self.redis.ft(self.index_name).info()
        return {
            "num_docs": info["num_docs"],
            "index_size": info.get("inverted_sz_mb", 0)
        }

# Production wrapper
class CachedLLM:
    """LLM client with semantic caching."""
    
    def __init__(self, cache: RedisSemanticCache):
        self.cache = cache
        self.client = OpenAI()
        self.stats = {"hits": 0, "misses": 0}
    
    def _get_embedding(self, text: str) -> np.ndarray:
        response = self.client.embeddings.create(
            input=text,
            model="text-embedding-3-small"
        )
        return np.array(response.data[0].embedding)
    
    def complete(self, query: str, **kwargs) -> str:
        """Get completion with caching."""
        
        embedding = self._get_embedding(query)
        
        # Check cache
        cached = self.cache.get(query, embedding)
        if cached:
            self.stats["hits"] += 1
            return cached
        
        # Call LLM
        self.stats["misses"] += 1
        response = self.client.chat.completions.create(
            model=kwargs.get("model", "gpt-4-turbo-preview"),
            messages=[{"role": "user", "content": query}],
            **{k: v for k, v in kwargs.items() if k != "model"}
        )
        
        result = response.choices[0].message.content
        
        # Cache
        self.cache.set(query, result, embedding)
        
        return result
    
    def hit_rate(self) -> float:
        total = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / total if total > 0 else 0
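
A short usage sketch for the wrapper above, assuming a local Redis instance with vector search available (e.g. a redis-stack container); the queries and threshold are illustrative:

# Wire the Redis-backed cache into the LLM wrapper
redis_cache = RedisSemanticCache(
    redis_url="redis://localhost:6379",
    similarity_threshold=0.92
)
llm = CachedLLM(redis_cache)

print(llm.complete("What is semantic caching?"))   # miss - calls the API and caches
print(llm.complete("Explain semantic caching"))    # likely served from the cache
print(f"Hit rate: {llm.hit_rate():.1%}")
print(redis_cache.stats())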

Using GPTCache Library

# pip install gptcache

from gptcache import cache
from gptcache.adapter import openai
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

# Initialize embedding model
onnx = Onnx()

# Setup cache with SQLite + FAISS
data_manager = get_data_manager(
    CacheBase("sqlite"),
    VectorBase("faiss", dimension=onnx.dimension)
)

# Initialize cache
cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation()
)

# Set the OpenAI API key for the cached adapter
cache.set_openai_key()

# Use cached OpenAI client
response = openai.ChatCompletion.create(
    model="gpt-4-turbo-preview",
    messages=[{"role": "user", "content": "What is Python?"}]
)

print(response["choices"][0]["message"]["content"])  # dict-style access works for both live and cached responses

# Second call with similar query - should hit cache
response = openai.ChatCompletion.create(
    model="gpt-4-turbo-preview",
    messages=[{"role": "user", "content": "Explain Python programming"}]
)

# Advanced configuration
from gptcache.processor.pre import last_content
from gptcache.processor.post import first

cache.init(
    pre_embedding_func=last_content,  # Use last message for embedding
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    post_process_messages_func=first  # Return first cached result
)

Cache Invalidation Strategies

from datetime import datetime
from typing import Optional, Callable
import hashlib
import json

class SmartCache:
    """Cache with intelligent invalidation strategies."""
    
    def __init__(self, base_cache):
        self.cache = base_cache
        self.invalidation_rules = []
    
    def add_invalidation_rule(self, rule: Callable[[str, str], bool]):
        """Add custom invalidation rule."""
        self.invalidation_rules.append(rule)
    
    def should_invalidate(self, query: str, cached_response: str) -> bool:
        """Check if cached response should be invalidated."""
        for rule in self.invalidation_rules:
            if rule(query, cached_response):
                return True
        return False
    
    def get(self, query: str, embedding) -> Optional[str]:
        """Get with invalidation check."""
        cached = self.cache.get(query, embedding)
        
        if cached and self.should_invalidate(query, cached):
            return None  # Force cache miss
        
        return cached

# Common invalidation rules

def time_sensitive_rule(query: str, response: str) -> bool:
    """Invalidate for time-sensitive queries."""
    time_keywords = ["today", "now", "current", "latest", "recent", "this week"]
    return any(kw in query.lower() for kw in time_keywords)

def stale_data_rule(query: str, response: str) -> bool:
    """Invalidate if response mentions outdated dates."""
    import re
    
    # Find years in response
    years = re.findall(r'\b(20\d{2})\b', response)
    current_year = datetime.now().year
    
    # If response mentions years more than 1 year old, might be stale
    for year in years:
        if int(year) < current_year - 1:
            return True
    
    return False

def confidence_rule(query: str, response: str) -> bool:
    """Invalidate low-confidence responses."""
    uncertainty_phrases = [
        "i'm not sure",
        "i don't know",
        "i cannot",
        "as of my knowledge",
        "i don't have access"
    ]
    return any(phrase in response.lower() for phrase in uncertainty_phrases)

# Usage (assumes `redis_cache` is a RedisSemanticCache instance from the earlier example)
smart_cache = SmartCache(redis_cache)
smart_cache.add_invalidation_rule(time_sensitive_rule)
smart_cache.add_invalidation_rule(stale_data_rule)
smart_cache.add_invalidation_rule(confidence_rule)

# Context-aware caching
class ContextAwareCache:
    """Cache that considers conversation context."""
    
    def __init__(self, base_cache):
        self.cache = base_cache
    
    def _context_hash(self, messages: list[dict]) -> str:
        """Create hash of conversation context."""
        # Use last N messages for context
        context_messages = messages[-3:]
        context_str = json.dumps(context_messages, sort_keys=True)
        return hashlib.md5(context_str.encode()).hexdigest()[:8]
    
    def get(self, messages: list[dict], embedding) -> Optional[str]:
        """Get with context awareness."""
        # Combine query with context hash
        query = messages[-1]["content"]
        context_hash = self._context_hash(messages[:-1])
        
        cache_key = f"{context_hash}:{query}"
        return self.cache.get(cache_key, embedding)
    
    def set(self, messages: list[dict], response: str, embedding):
        """Set with context."""
        query = messages[-1]["content"]
        context_hash = self._context_hash(messages[:-1])
        
        cache_key = f"{context_hash}:{query}"
        self.cache.set(cache_key, response, embedding)
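
A brief usage sketch of the context-aware wrapper, reusing the redis_cache and llm objects from the Redis example above; the conversation content is illustrative:

context_cache = ContextAwareCache(redis_cache)

conversation = [
    {"role": "user", "content": "Tell me about Python."},
    {"role": "assistant", "content": "Python is a general-purpose programming language..."},
    {"role": "user", "content": "What about its performance?"},
]

# Embed only the latest user message; the context hash scopes the cache key
query_embedding = llm._get_embedding(conversation[-1]["content"])

cached = context_cache.get(conversation, query_embedding)
if cached is None:
    answer = llm.complete(conversation[-1]["content"])
    context_cache.set(conversation, answer, query_embedding)

One design caveat: with the purely vector-based RedisSemanticCache.get() shown earlier, the context hash only changes the stored key, not the lookup, so cross-context hits are still possible. To enforce separation, embed the combined cache_key instead of the bare query, or filter the vector search on a stored context-hash field.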

Monitoring and Optimization

from dataclasses import dataclass, field
from collections import defaultdict
from typing import Optional
import time

@dataclass
class CacheMetrics:
    """Track cache performance metrics."""
    hits: int = 0
    misses: int = 0
    total_latency_saved_ms: float = 0
    total_cost_saved: float = 0
    queries_by_similarity: dict = field(default_factory=lambda: defaultdict(int))

class MonitoredCache:
    """Cache with comprehensive monitoring."""
    
    def __init__(self, base_cache, avg_llm_latency_ms: float = 2000, cost_per_1k_tokens: float = 0.01):
        self.cache = base_cache
        self.metrics = CacheMetrics()
        self.avg_latency = avg_llm_latency_ms
        self.cost_per_1k = cost_per_1k_tokens
    
    def get(self, query: str, embedding) -> tuple[Optional[str], dict]:
        """Get with metrics tracking."""
        start = time.time()
        
        # Assumes the wrapped cache exposes get_with_similarity(), which returns
        # (response_or_None, best_similarity); see the adapter sketch below.
        result, similarity = self.cache.get_with_similarity(query, embedding)
        
        latency = (time.time() - start) * 1000
        
        if result:
            self.metrics.hits += 1
            self.metrics.total_latency_saved_ms += self.avg_latency - latency
            
            # Rough token estimate (~4 characters per token)
            estimated_tokens = (len(query) + len(result)) / 4
            self.metrics.total_cost_saved += (estimated_tokens / 1000) * self.cost_per_1k
            
            # Track similarity distribution
            bucket = round(similarity, 1)
            self.metrics.queries_by_similarity[bucket] += 1
        else:
            self.metrics.misses += 1
        
        return result, {
            "hit": result is not None,
            "similarity": similarity,
            "latency_ms": latency
        }
    
    def get_stats(self) -> dict:
        """Get comprehensive statistics."""
        total = self.metrics.hits + self.metrics.misses
        
        return {
            "hit_rate": self.metrics.hits / total if total > 0 else 0,
            "total_queries": total,
            "hits": self.metrics.hits,
            "misses": self.metrics.misses,
            "latency_saved_ms": self.metrics.total_latency_saved_ms,
            "cost_saved_usd": self.metrics.total_cost_saved,
            "similarity_distribution": dict(self.metrics.queries_by_similarity)
        }
    
    def optimize_threshold(self, target_hit_rate: float = 0.3) -> float:
        """Suggest optimal similarity threshold."""
        # Analyze similarity distribution
        dist = self.metrics.queries_by_similarity
        
        if not dist:
            return 0.92  # Default
        
        # Find threshold that achieves target hit rate
        total = sum(dist.values())
        cumulative = 0
        
        for threshold in sorted(dist.keys(), reverse=True):
            cumulative += dist[threshold]
            if cumulative / total >= target_hit_rate:
                return threshold
        
        return min(dist.keys())

# Dashboard data
def generate_cache_report(cache: MonitoredCache) -> str:
    """Generate cache performance report."""
    stats = cache.get_stats()
    
    report = f"""
## Cache Performance Report

### Overview
- **Hit Rate**: {stats['hit_rate']:.1%}
- **Total Queries**: {stats['total_queries']:,}
- **Cache Hits**: {stats['hits']:,}
- **Cache Misses**: {stats['misses']:,}

### Savings
- **Latency Saved**: {stats['latency_saved_ms']/1000:.1f} seconds
- **Cost Saved**: ${stats['cost_saved_usd']:.2f}

### Similarity Distribution
"""
    
    for threshold, count in sorted(stats['similarity_distribution'].items(), reverse=True):
        report += f"- {threshold:.1f}: {count} queries\n"
    
    return report
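
MonitoredCache assumes the wrapped cache exposes get_with_similarity(), returning both the response (or None) and the best similarity score. The RedisSemanticCache shown earlier only returns the response, so here is a minimal adapter sketch built on the same KNN query (the class name is illustrative; it relies on the earlier Redis code):

class RedisSemanticCacheWithScore(RedisSemanticCache):
    """Adapter exposing the (response, similarity) interface MonitoredCache expects."""

    def get_with_similarity(self, query: str, query_embedding: np.ndarray) -> tuple[Optional[str], float]:
        embedding_bytes = query_embedding.astype(np.float32).tobytes()

        q = (
            Query("*=>[KNN 1 @embedding $vec AS score]")
            .return_fields("response", "score")
            .sort_by("score")
            .dialect(2)
        )
        results = self.redis.ft(self.index_name).search(q, query_params={"vec": embedding_bytes})

        if not results.docs:
            return None, 0.0

        doc = results.docs[0]
        similarity = 1 - float(doc.score)  # cosine distance -> similarity
        if similarity >= self.threshold:
            return doc.response, similarity
        return None, similarity

# Usage
monitored = MonitoredCache(RedisSemanticCacheWithScore(), avg_llm_latency_ms=2000)
print(generate_cache_report(monitored))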

Conclusion

Semantic caching is one of the highest-impact optimizations for LLM applications. By recognizing semantically similar queries, you can dramatically reduce API costs and latency while maintaining response quality. Start with a simple in-memory cache for development, then move to Redis with vector search for production. The key decisions are similarity threshold (too high misses opportunities, too low returns irrelevant results) and TTL (balance freshness vs. hit rate). Implement smart invalidation for time-sensitive queries, and monitor your cache performance to continuously optimize. A well-tuned semantic cache can achieve 30-50% hit rates for many applications, translating directly to cost savings and faster user experiences.

