
LLM Inference Optimization: Caching, Batching, and Smart Routing

Introduction: LLM inference can be slow and expensive, especially at scale. Optimizing inference is crucial for production applications, where latency and cost directly impact user experience and business viability. This guide covers practical optimization techniques: semantic caching to avoid redundant API calls, request batching for throughput, streaming to reduce perceived latency, smart model routing that matches task complexity to model cost, and parallel processing patterns for batch workloads. Combined, these techniques can reduce costs by 50-90% and cut latency significantly without sacrificing output quality.
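
As a rough back-of-envelope illustration of where those savings come from, the short sketch below combines an assumed cache hit rate with an assumed routing split. Every number in it is illustrative, not a measurement.

# Back-of-envelope cost model; all inputs below are illustrative assumptions
baseline_cost_per_request = 0.01   # assumed average cost with a large model, USD
cache_hit_rate = 0.30              # assumed fraction of requests served from cache
routed_to_small_model = 0.50       # assumed fraction of cache misses sent to a cheaper model
small_model_discount = 0.90        # cheaper model assumed ~90% less expensive

miss_rate = 1 - cache_hit_rate
blended_cost = baseline_cost_per_request * miss_rate * (
    routed_to_small_model * (1 - small_model_discount)
    + (1 - routed_to_small_model)
)

savings = 1 - blended_cost / baseline_cost_per_request
print(f"Estimated cost per request: ${blended_cost:.4f} ({savings:.0%} savings)")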


Semantic Caching

from openai import OpenAI
import hashlib
import json
import numpy as np
from typing import Optional
from datetime import datetime, timedelta

client = OpenAI()

class SemanticCache:
    """Cache LLM responses based on semantic similarity."""
    
    def __init__(
        self,
        similarity_threshold: float = 0.95,
        ttl_hours: int = 24
    ):
        self.cache: dict[str, dict] = {}
        self.embeddings: dict[str, list[float]] = {}
        self.similarity_threshold = similarity_threshold
        self.ttl = timedelta(hours=ttl_hours)
    
    def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
    
    def _compute_similarity(self, emb1: list[float], emb2: list[float]) -> float:
        """Compute cosine similarity."""
        return np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
    
    def _is_expired(self, entry: dict) -> bool:
        """Check if cache entry is expired."""
        created = datetime.fromisoformat(entry["created_at"])
        return datetime.now() - created > self.ttl
    
    def get(self, prompt: str) -> Optional[str]:
        """Get cached response if similar prompt exists."""
        
        prompt_embedding = self._get_embedding(prompt)
        
        best_match = None
        best_similarity = 0
        
        for key, embedding in self.embeddings.items():
            similarity = self._compute_similarity(prompt_embedding, embedding)
            
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                entry = self.cache.get(key)
                if entry and not self._is_expired(entry):
                    best_match = entry
                    best_similarity = similarity
        
        if best_match:
            return best_match["response"]
        
        return None
    
    def set(self, prompt: str, response: str):
        """Cache a response."""
        
        key = hashlib.md5(prompt.encode()).hexdigest()
        embedding = self._get_embedding(prompt)
        
        self.cache[key] = {
            "prompt": prompt,
            "response": response,
            "created_at": datetime.now().isoformat()
        }
        self.embeddings[key] = embedding
    
    def get_stats(self) -> dict:
        """Get cache statistics."""
        
        valid_entries = sum(
            1 for entry in self.cache.values()
            if not self._is_expired(entry)
        )
        
        return {
            "total_entries": len(self.cache),
            "valid_entries": valid_entries,
            "expired_entries": len(self.cache) - valid_entries
        }

class CachedLLM:
    """LLM client with semantic caching."""
    
    def __init__(self, cache: Optional[SemanticCache] = None):
        self.cache = cache or SemanticCache()
        self.stats = {"hits": 0, "misses": 0}
    
    def complete(
        self,
        prompt: str,
        model: str = "gpt-4o-mini",
        use_cache: bool = True
    ) -> str:
        """Get completion with caching."""
        
        if use_cache:
            cached = self.cache.get(prompt)
            if cached:
                self.stats["hits"] += 1
                return cached
        
        self.stats["misses"] += 1
        
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}]
        )
        
        result = response.choices[0].message.content
        
        if use_cache:
            self.cache.set(prompt, result)
        
        return result

# Usage
llm = CachedLLM()

# First call - cache miss
response1 = llm.complete("What is machine learning?")

# Similar query - cache hit
response2 = llm.complete("Can you explain machine learning?")

print(f"Cache stats: {llm.stats}")

Request Batching

import asyncio
from dataclasses import dataclass
import time

@dataclass
class BatchRequest:
    prompt: str
    future: asyncio.Future
    created_at: float

class BatchProcessor:
    """Batch multiple requests for efficient processing."""
    
    def __init__(
        self,
        max_batch_size: int = 10,
        max_wait_ms: int = 100
    ):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue: list[BatchRequest] = []
        self.lock = asyncio.Lock()
        self.processing = False
    
    async def add_request(self, prompt: str) -> str:
        """Add a request to the batch queue."""
        
        future = asyncio.get_running_loop().create_future()
        request = BatchRequest(
            prompt=prompt,
            future=future,
            created_at=time.time()
        )
        
        async with self.lock:
            self.queue.append(request)
            
            # Start processing if batch is full
            if len(self.queue) >= self.max_batch_size:
                asyncio.create_task(self._process_batch())
            elif not self.processing:
                # Schedule processing after max_wait
                asyncio.create_task(self._delayed_process())
        
        return await future
    
    async def _delayed_process(self):
        """Process batch after delay."""
        
        await asyncio.sleep(self.max_wait_ms / 1000)
        await self._process_batch()
    
    async def _process_batch(self):
        """Process all queued requests."""
        
        async with self.lock:
            if not self.queue or self.processing:
                return
            
            self.processing = True
            batch = self.queue[:self.max_batch_size]
            self.queue = self.queue[self.max_batch_size:]
        
        try:
            # Process batch (in real implementation, use batch API)
            results = await self._call_llm_batch([r.prompt for r in batch])
            
            for request, result in zip(batch, results):
                request.future.set_result(result)
                
        except Exception as e:
            for request in batch:
                request.future.set_exception(e)
        
        finally:
            self.processing = False
            
            # Process remaining if any
            if self.queue:
                asyncio.create_task(self._process_batch())
    
    async def _call_llm_batch(self, prompts: list[str]) -> list[str]:
        """Call LLM for batch of prompts."""
        
        # Use asyncio.gather for parallel processing
        tasks = [
            self._call_single(prompt)
            for prompt in prompts
        ]
        
        return await asyncio.gather(*tasks)
    
    async def _call_single(self, prompt: str) -> str:
        """Call LLM for a single prompt without blocking the event loop."""
        
        # The OpenAI client used here is synchronous, so run the call in a
        # worker thread; calling it directly would block the event loop.
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}]
        )
        
        return response.choices[0].message.content

# Usage
async def main():
    batcher = BatchProcessor(max_batch_size=5, max_wait_ms=50)
    
    # Submit multiple requests
    tasks = [
        batcher.add_request(f"What is {topic}?")
        for topic in ["Python", "JavaScript", "Rust", "Go", "Java"]
    ]
    
    results = await asyncio.gather(*tasks)
    
    for topic, result in zip(["Python", "JavaScript", "Rust", "Go", "Java"], results):
        print(f"{topic}: {result[:50]}...")

# asyncio.run(main())
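
The comment in _process_batch about a batch API refers to offline batch endpoints such as OpenAI's Batch API, which trades latency (results within 24 hours) for roughly half the per-token price. A minimal sketch of that flow follows; the file name and custom IDs are illustrative, and parameter names should be checked against the current OpenAI documentation.

import json

# Offline batching via a JSONL file of requests (sketch)
batch_lines = [
    {
        "custom_id": f"req-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": f"What is {topic}?"}],
        },
    }
    for i, topic in enumerate(["Python", "JavaScript", "Rust"])
]

with open("batch_requests.jsonl", "w") as f:
    for line in batch_lines:
        f.write(json.dumps(line) + "\n")

# Upload the file and create the batch job
batch_file = client.files.create(
    file=open("batch_requests.jsonl", "rb"),
    purpose="batch"
)
batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Later: poll until the job completes, then download the results file
status = client.batches.retrieve(batch.id)
if status.status == "completed":
    output = client.files.content(status.output_file_id)
    for raw in output.text.splitlines():
        print(json.loads(raw)["custom_id"])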

Streaming for Perceived Latency

from typing import Generator, AsyncGenerator

def stream_completion(
    prompt: str,
    model: str = "gpt-4o-mini"
) -> Generator[str, None, None]:
    """Stream completion tokens."""
    
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True
    )
    
    for chunk in response:
        if chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

from openai import AsyncOpenAI

# Awaiting and `async for` require the async client, not the sync OpenAI() one
async_client = AsyncOpenAI()

async def async_stream_completion(
    prompt: str,
    model: str = "gpt-4o-mini"
) -> AsyncGenerator[str, None]:
    """Async stream completion tokens."""
    
    response = await async_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True
    )
    
    async for chunk in response:
        if chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

class StreamingBuffer:
    """Buffer streaming output for processing."""
    
    def __init__(self):
        self.buffer = ""
        self.complete_sentences: list[str] = []
    
    def add_chunk(self, chunk: str) -> list[str]:
        """Add chunk and return complete sentences."""
        
        self.buffer += chunk
        
        # Repeatedly split the buffer at the earliest sentence boundary
        new_sentences = []
        end_markers = (". ", "! ", "? ", ".\n", "!\n", "?\n")
        
        while True:
            matches = [
                (self.buffer.find(marker), marker)
                for marker in end_markers
                if marker in self.buffer
            ]
            if not matches:
                break
            
            idx, marker = min(matches)
            new_sentences.append(self.buffer[:idx + 1].strip())
            self.buffer = self.buffer[idx + len(marker):]
        
        self.complete_sentences.extend(new_sentences)
        return new_sentences
    
    def flush(self) -> str:
        """Get remaining buffer content."""
        remaining = self.buffer.strip()
        self.buffer = ""
        return remaining
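
A quick way to exercise the buffer with the synchronous streamer defined above (the prompt text is arbitrary):

# Feed streamed chunks through the buffer and act on complete sentences
buffer = StreamingBuffer()

for chunk in stream_completion("Explain what an LLM token is in three sentences."):
    for sentence in buffer.add_chunk(chunk):
        print(f"[sentence] {sentence}")

leftover = buffer.flush()
if leftover:
    print(f"[sentence] {leftover}")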

# Usage with FastAPI
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/stream")
async def stream_response(prompt: str):
    """Stream LLM response."""
    
    async def generate():
        async for chunk in async_stream_completion(prompt):
            yield f"data: {chunk}\n\n"
        yield "data: [DONE]\n\n"
    
    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )

Model Selection and Routing

from enum import Enum
from dataclasses import dataclass

class TaskComplexity(str, Enum):
    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"

@dataclass
class ModelConfig:
    name: str
    cost_per_1k_input: float
    cost_per_1k_output: float
    avg_latency_ms: float
    max_tokens: int

class ModelRouter:
    """Route requests to appropriate models based on complexity."""
    
    def __init__(self):
        self.models = {
            TaskComplexity.SIMPLE: ModelConfig(
                name="gpt-4o-mini",
                cost_per_1k_input=0.00015,
                cost_per_1k_output=0.0006,
                avg_latency_ms=500,
                max_tokens=16384
            ),
            TaskComplexity.MODERATE: ModelConfig(
                name="gpt-4o",
                cost_per_1k_input=0.0025,
                cost_per_1k_output=0.01,
                avg_latency_ms=1000,
                max_tokens=128000
            ),
            TaskComplexity.COMPLEX: ModelConfig(
                name="gpt-4o",
                cost_per_1k_input=0.0025,
                cost_per_1k_output=0.01,
                avg_latency_ms=1500,
                max_tokens=128000
            )
        }
    
    def classify_complexity(self, prompt: str) -> TaskComplexity:
        """Classify task complexity."""
        
        # Simple heuristics
        word_count = len(prompt.split())
        
        complex_indicators = [
            "analyze", "compare", "evaluate", "synthesize",
            "explain in detail", "step by step", "comprehensive"
        ]
        
        simple_indicators = [
            "what is", "define", "list", "name",
            "yes or no", "true or false"
        ]
        
        prompt_lower = prompt.lower()
        
        if any(ind in prompt_lower for ind in simple_indicators) and word_count < 50:
            return TaskComplexity.SIMPLE
        
        if any(ind in prompt_lower for ind in complex_indicators) or word_count > 200:
            return TaskComplexity.COMPLEX
        
        return TaskComplexity.MODERATE
    
    def route(self, prompt: str) -> ModelConfig:
        """Route to appropriate model."""
        
        complexity = self.classify_complexity(prompt)
        return self.models[complexity]
    
    def complete(self, prompt: str) -> tuple[str, dict]:
        """Complete with automatic routing."""
        
        config = self.route(prompt)
        
        response = client.chat.completions.create(
            model=config.name,
            messages=[{"role": "user", "content": prompt}]
        )
        
        result = response.choices[0].message.content
        
        # Calculate cost
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        
        cost = (
            (input_tokens / 1000) * config.cost_per_1k_input +
            (output_tokens / 1000) * config.cost_per_1k_output
        )
        
        metadata = {
            "model": config.name,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": cost
        }
        
        return result, metadata

# Usage
router = ModelRouter()

# Simple query -> gpt-4o-mini
result, meta = router.complete("What is Python?")
print(f"Model: {meta['model']}, Cost: ${meta['cost']:.6f}")

# Complex query -> gpt-4o
result, meta = router.complete(
    "Analyze the trade-offs between microservices and monolithic architectures, "
    "considering scalability, maintainability, and operational complexity."
)
print(f"Model: {meta['model']}, Cost: ${meta['cost']:.6f}")

Parallel Processing

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Callable

class ParallelLLM:
    """Process multiple LLM calls in parallel."""
    
    def __init__(self, max_concurrent: int = 10):
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.executor = ThreadPoolExecutor(max_workers=max_concurrent)
    
    async def _call_with_semaphore(
        self,
        prompt: str,
        model: str = "gpt-4o-mini"
    ) -> str:
        """Call LLM with a concurrency limit, without blocking the event loop."""
        
        async with self.semaphore:
            # The OpenAI client is synchronous, so run it in the thread pool
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                self.executor,
                lambda: client.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": prompt}]
                )
            )
            return response.choices[0].message.content
    
    async def batch_complete(
        self,
        prompts: list[str],
        model: str = "gpt-4o-mini"
    ) -> list[str]:
        """Complete multiple prompts in parallel."""
        
        tasks = [
            self._call_with_semaphore(prompt, model)
            for prompt in prompts
        ]
        
        return await asyncio.gather(*tasks)
    
    async def map_reduce(
        self,
        items: list[str],
        map_prompt_fn: Callable[[str], str],
        reduce_prompt: str,
        model: str = "gpt-4o-mini"
    ) -> str:
        """Map-reduce pattern for processing large datasets."""
        
        # Map phase: process items in parallel
        map_prompts = [map_prompt_fn(item) for item in items]
        map_results = await self.batch_complete(map_prompts, model)
        
        # Reduce phase: combine results
        combined = "\n\n".join([
            f"Item {i+1} result:\n{result}"
            for i, result in enumerate(map_results)
        ])
        
        final_prompt = f"{reduce_prompt}\n\nResults to combine:\n{combined}"
        
        # The reduce call also runs in a worker thread to avoid blocking
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model=model,
            messages=[{"role": "user", "content": final_prompt}]
        )
        
        return response.choices[0].message.content

# Usage
async def main():
    parallel = ParallelLLM(max_concurrent=5)
    
    # Batch processing
    prompts = [f"Summarize the concept of {topic}" for topic in [
        "machine learning", "deep learning", "neural networks",
        "natural language processing", "computer vision"
    ]]
    
    results = await parallel.batch_complete(prompts)
    
    # Map-reduce for document analysis
    documents = ["Doc 1 content...", "Doc 2 content...", "Doc 3 content..."]
    
    summary = await parallel.map_reduce(
        items=documents,
        map_prompt_fn=lambda doc: f"Extract key points from: {doc}",
        reduce_prompt="Combine these key points into a unified summary:"
    )
    
    print(summary)

# asyncio.run(main())
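
The max_concurrent limit usually exists because of provider rate limits, and a retry with jittered exponential backoff is the natural companion when a limit is still hit. A minimal sketch (retry count and delays are illustrative):

import random
from openai import RateLimitError

async def with_retries(call, max_retries: int = 3):
    """Retry an async LLM call on rate limits with jittered exponential backoff."""
    for attempt in range(max_retries + 1):
        try:
            return await call()
        except RateLimitError:
            if attempt == max_retries:
                raise
            # 1s, 2s, 4s... plus jitter so concurrent retries do not stampede
            await asyncio.sleep(2 ** attempt + random.random())

# Example: wrap the semaphore-guarded call from ParallelLLM
# result = await with_retries(lambda: parallel._call_with_semaphore("What is Rust?"))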

Production Optimization Service

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import time

app = FastAPI()

# Initialize components
cache = SemanticCache(similarity_threshold=0.92)
router = ModelRouter()

class CompletionRequest(BaseModel):
    prompt: str
    model: Optional[str] = None  # Auto-route if not specified
    use_cache: bool = True
    stream: bool = False  # reserved; not implemented in this endpoint

class CompletionResponse(BaseModel):
    content: str
    model_used: str
    cached: bool
    latency_ms: float
    tokens: dict
    cost: float

@app.post("/complete", response_model=CompletionResponse)
async def complete(request: CompletionRequest):
    """Optimized completion endpoint."""
    
    start = time.time()
    
    # Check cache first
    if request.use_cache:
        cached_response = cache.get(request.prompt)
        if cached_response:
            return CompletionResponse(
                content=cached_response,
                model_used="cache",
                cached=True,
                latency_ms=(time.time() - start) * 1000,
                tokens={"input": 0, "output": 0},
                cost=0.0
            )
    
    # Route to a model; if the caller pins a model explicitly, approximate the
    # cost with the moderate tier's pricing
    if request.model:
        model = request.model
        config = router.models[TaskComplexity.MODERATE]
    else:
        config = router.route(request.prompt)
        model = config.name
    
    # Call LLM
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": request.prompt}]
    )
    
    content = response.choices[0].message.content
    
    # Cache response
    if request.use_cache:
        cache.set(request.prompt, content)
    
    # Calculate metrics
    input_tokens = response.usage.prompt_tokens
    output_tokens = response.usage.completion_tokens
    cost = (
        (input_tokens / 1000) * config.cost_per_1k_input +
        (output_tokens / 1000) * config.cost_per_1k_output
    )
    
    return CompletionResponse(
        content=content,
        model_used=model,
        cached=False,
        latency_ms=(time.time() - start) * 1000,
        tokens={"input": input_tokens, "output": output_tokens},
        cost=cost
    )

@app.get("/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    return cache.get_stats()

@app.post("/cache/clear")
async def clear_cache():
    """Clear the cache."""
    cache.cache.clear()
    cache.embeddings.clear()
    return {"cleared": True}


Conclusion

LLM inference optimization is essential for production applications. Semantic caching eliminates redundant API calls for similar queries; even a 30% cache hit rate meaningfully reduces costs. Request batching improves throughput for high-volume applications. Streaming reduces perceived latency by showing results as they are generated. Smart model routing sends simple tasks to cheaper models and reserves expensive models for complex queries. Parallel processing accelerates batch workloads. Combine these techniques based on your specific requirements: latency-sensitive applications benefit most from caching and streaming, while high-volume offline workloads benefit most from batching and parallelization.


Discover more from Code, Cloud & Context

Subscribe to get the latest posts sent to your email.

Leave a Reply

You can use these HTML tags

<a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <s> <strike> <strong>

  

  

  

This site uses Akismet to reduce spam. Learn how your comment data is processed.