
Inference Optimization Patterns: Maximizing LLM Throughput and Efficiency

Introduction: LLM inference is expensive—both in compute and latency. Every token generated requires a forward pass through billions of parameters, and users expect responses in seconds, not minutes. Inference optimization techniques reduce costs and improve responsiveness without sacrificing output quality. This guide covers practical optimization strategies: batching requests to maximize GPU utilization, managing KV caches to avoid redundant computation, quantizing models to reduce memory footprint, speculative decoding to accelerate generation, and continuous batching for production throughput. Whether you’re serving a chatbot to millions of users or running inference on edge devices, these techniques determine whether your LLM application is viable at scale.

Inference Optimization: Request Batching, KV Cache Management, Model Quantization

Request Batching

from dataclasses import dataclass, field
from typing import Any, Optional, Callable
from datetime import datetime
import asyncio
from collections import deque

@dataclass
class InferenceRequest:
    """A single inference request."""
    
    id: str
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    created_at: datetime = field(default_factory=datetime.utcnow)
    future: Optional[asyncio.Future] = None

@dataclass
class InferenceResult:
    """Result of inference."""
    
    request_id: str
    text: str
    tokens_generated: int
    latency_ms: float

class DynamicBatcher:
    """Dynamic request batching for inference."""
    
    def __init__(
        self,
        model: Any,
        max_batch_size: int = 32,
        max_wait_ms: float = 50.0,
        max_tokens_per_batch: int = 4096
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.max_tokens_per_batch = max_tokens_per_batch
        
        self._queue: deque[InferenceRequest] = deque()
        self._lock = asyncio.Lock()
        self._batch_event = asyncio.Event()
        self._running = False
    
    async def start(self):
        """Start the batcher."""
        self._running = True
        # Keep a reference so the background task is not garbage collected
        self._task = asyncio.create_task(self._batch_loop())
    
    async def stop(self):
        """Stop the batcher."""
        self._running = False
        self._batch_event.set()
    
    async def infer(self, request: InferenceRequest) -> InferenceResult:
        """Submit request and wait for result."""
        
        request.future = asyncio.get_running_loop().create_future()
        
        async with self._lock:
            self._queue.append(request)
            self._batch_event.set()
        
        return await request.future
    
    async def _batch_loop(self):
        """Main batching loop."""
        
        while self._running:
            await self._batch_event.wait()
            self._batch_event.clear()
            
            # Wait for more requests or timeout
            await asyncio.sleep(self.max_wait_ms / 1000)
            
            # Collect batch
            batch = await self._collect_batch()
            
            if batch:
                # Process batch
                results = await self._process_batch(batch)
                
                # Deliver results
                for request, result in zip(batch, results):
                    if not request.future.done():
                        request.future.set_result(result)
    
    async def _collect_batch(self) -> list[InferenceRequest]:
        """Collect requests into a batch."""
        
        batch = []
        total_tokens = 0
        
        async with self._lock:
            while self._queue and len(batch) < self.max_batch_size:
                request = self._queue[0]
                
                # Rough token estimate (~1.3 tokens per whitespace-separated word)
                prompt_tokens = len(request.prompt.split()) * 1.3
                request_tokens = prompt_tokens + request.max_tokens
                
                # Always admit at least one request so an oversized prompt cannot stall the queue
                if batch and total_tokens + request_tokens > self.max_tokens_per_batch:
                    break
                
                batch.append(self._queue.popleft())
                total_tokens += request_tokens
            
            # Requests left behind should trigger another batching round
            if self._queue:
                self._batch_event.set()
        
        return batch
    
    async def _process_batch(
        self,
        batch: list[InferenceRequest]
    ) -> list[InferenceResult]:
        """Process a batch of requests."""
        
        start_time = datetime.utcnow()
        
        # Prepare batch inputs
        prompts = [r.prompt for r in batch]
        max_tokens = max(r.max_tokens for r in batch)
        
        # Run inference
        outputs = await self.model.generate_batch(
            prompts,
            max_tokens=max_tokens
        )
        
        end_time = datetime.utcnow()
        latency = (end_time - start_time).total_seconds() * 1000
        
        # Create results
        results = []
        for request, output in zip(batch, outputs):
            results.append(InferenceResult(
                request_id=request.id,
                text=output,
                tokens_generated=len(output.split()),
                latency_ms=latency
            ))
        
        return results

class ContinuousBatcher:
    """Continuous batching for streaming inference."""
    
    def __init__(
        self,
        model: Any,
        max_batch_size: int = 64,
        max_sequence_length: int = 2048
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        
        self._active_requests: dict[str, InferenceRequest] = {}
        self._pending_queue: deque[InferenceRequest] = deque()
        self._generated_counts: dict[str, int] = {}  # tokens generated per request
        self._lock = asyncio.Lock()
    
    async def add_request(self, request: InferenceRequest):
        """Add request to processing."""
        
        async with self._lock:
            if len(self._active_requests) < self.max_batch_size:
                self._active_requests[request.id] = request
            else:
                self._pending_queue.append(request)
    
    async def step(self) -> dict[str, str]:
        """Generate one token for all active requests."""
        
        if not self._active_requests:
            return {}
        
        # Get current state of all active requests
        requests = list(self._active_requests.values())
        
        # Generate next token for batch
        next_tokens = await self.model.generate_next_token_batch(
            [r.prompt for r in requests]
        )
        
        results = {}
        completed = []
        
        for request, token in zip(requests, next_tokens):
            # Append the new token and track how many tokens this request has generated
            request.prompt += token
            self._generated_counts[request.id] = self._generated_counts.get(request.id, 0) + 1
            results[request.id] = token
            
            # Check completion
            if self._is_complete(request, token):
                completed.append(request.id)
        
        # Remove completed requests
        async with self._lock:
            for req_id in completed:
                del self._active_requests[req_id]
                self._generated_counts.pop(req_id, None)
                
                # Promote a pending request into the freed batch slot
                if self._pending_queue:
                    new_request = self._pending_queue.popleft()
                    self._active_requests[new_request.id] = new_request
        
        return results
    
    def _is_complete(self, request: InferenceRequest, token: str) -> bool:
        """Check if request is complete."""
        
        # Check for EOS token
        if token in ["<|endoftext|>", "", "<|end|>"]:
            return True
        
        # Check max tokens against the per-request generation counter
        if self._generated_counts.get(request.id, 0) >= request.max_tokens:
            return True
        
        return False
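
Both batchers assume a model object that exposes an async generate_batch(prompts, max_tokens) method; that interface is an assumption of this guide rather than any particular library's API. A minimal usage sketch of DynamicBatcher with a stub model:

import asyncio
import uuid

class EchoModel:
    """Stub standing in for the assumed async generate_batch interface."""
    
    async def generate_batch(self, prompts: list[str], max_tokens: int = 100) -> list[str]:
        await asyncio.sleep(0.01)  # pretend this is a forward pass
        return [f"echo: {p}" for p in prompts]

async def demo_dynamic_batching():
    batcher = DynamicBatcher(model=EchoModel(), max_batch_size=8, max_wait_ms=20.0)
    await batcher.start()
    
    requests = [
        InferenceRequest(id=str(uuid.uuid4()), prompt=f"prompt {i}", max_tokens=16)
        for i in range(4)
    ]
    # All four requests land in the same batch and share one model call
    results = await asyncio.gather(*(batcher.infer(r) for r in requests))
    for result in results:
        print(result.request_id, result.text, f"{result.latency_ms:.1f} ms")
    
    await batcher.stop()

asyncio.run(demo_dynamic_batching())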

KV Cache Management

from dataclasses import dataclass
from typing import Any, Optional
import time
import torch

@dataclass
class KVCacheEntry:
    """A KV cache entry."""
    
    key: torch.Tensor
    value: torch.Tensor
    sequence_length: int
    last_accessed: float

class KVCacheManager:
    """Manage KV caches for inference."""
    
    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        max_cache_size_gb: float = 8.0,
        device: str = "cuda"
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_cache_size = int(max_cache_size_gb * 1e9)
        self.device = device
        
        self._caches: dict[str, list[KVCacheEntry]] = {}
        self._current_size = 0
    
    def get_cache(self, request_id: str) -> Optional[list[KVCacheEntry]]:
        """Get KV cache for request."""
        
        cache = self._caches.get(request_id)
        if cache:
            for entry in cache:
                entry.last_accessed = time.time()
        return cache
    
    def set_cache(
        self,
        request_id: str,
        keys: list[torch.Tensor],
        values: list[torch.Tensor],
        seq_len: int
    ):
        """Set KV cache for request."""
        
        # Calculate size
        cache_size = sum(
            k.numel() * k.element_size() + v.numel() * v.element_size()
            for k, v in zip(keys, values)
        )
        
        # Evict if necessary
        while self._current_size + cache_size > self.max_cache_size:
            if not self._evict_oldest():
                break
        
        # Store cache
        entries = []
        for k, v in zip(keys, values):
            entries.append(KVCacheEntry(
                key=k,
                value=v,
                sequence_length=seq_len,
                last_accessed=time.time()
            ))
        
        self._caches[request_id] = entries
        self._current_size += cache_size
    
    def extend_cache(
        self,
        request_id: str,
        new_keys: list[torch.Tensor],
        new_values: list[torch.Tensor]
    ):
        """Extend existing cache with new tokens."""
        
        cache = self._caches.get(request_id)
        if not cache:
            return
        
        for i, (entry, new_k, new_v) in enumerate(zip(cache, new_keys, new_values)):
            # Concatenate along sequence dimension
            entry.key = torch.cat([entry.key, new_k], dim=-2)
            entry.value = torch.cat([entry.value, new_v], dim=-2)
            entry.sequence_length += new_k.shape[-2]
    
    def delete_cache(self, request_id: str):
        """Delete cache for request."""
        
        if request_id in self._caches:
            cache = self._caches[request_id]
            cache_size = sum(
                e.key.numel() * e.key.element_size() +
                e.value.numel() * e.value.element_size()
                for e in cache
            )
            del self._caches[request_id]
            self._current_size -= cache_size
    
    def _evict_oldest(self) -> bool:
        """Evict oldest cache entry."""
        
        if not self._caches:
            return False
        
        # Find oldest
        oldest_id = min(
            self._caches.keys(),
            key=lambda k: min(e.last_accessed for e in self._caches[k])
        )
        
        self.delete_cache(oldest_id)
        return True
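
# Usage sketch for KVCacheManager (CPU tensors and toy sizes so it runs anywhere;
# the (batch, heads, seq_len, head_dim) key/value layout is an assumption of this sketch).
def _demo_kv_cache_manager():
    manager = KVCacheManager(num_layers=2, num_heads=4, head_dim=8,
                             max_cache_size_gb=0.001, device="cpu")
    
    # One (key, value) tensor pair per layer
    keys = [torch.zeros(1, 4, 10, 8) for _ in range(2)]
    values = [torch.zeros(1, 4, 10, 8) for _ in range(2)]
    manager.set_cache("req-1", keys, values, seq_len=10)
    
    entries = manager.get_cache("req-1")
    print(len(entries), entries[0].sequence_length)   # 2 layers, 10 cached positions
    
    # Append one more token's worth of KV to every layer
    manager.extend_cache("req-1",
                         [torch.zeros(1, 4, 1, 8)] * 2,
                         [torch.zeros(1, 4, 1, 8)] * 2)
    print(manager.get_cache("req-1")[0].sequence_length)  # 11
    
    manager.delete_cache("req-1")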

class PagedKVCache:
    """Paged attention KV cache (vLLM style)."""
    
    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        block_size: int = 16,
        num_blocks: int = 1000,
        device: str = "cuda"
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.device = device
        
        # Pre-allocate block pool
        self.key_blocks = torch.zeros(
            num_blocks, num_layers, block_size, num_heads, head_dim,
            device=device, dtype=torch.float16
        )
        self.value_blocks = torch.zeros(
            num_blocks, num_layers, block_size, num_heads, head_dim,
            device=device, dtype=torch.float16
        )
        
        # Block allocation tracking
        self.free_blocks = list(range(num_blocks))
        self.sequence_blocks: dict[str, list[int]] = {}
    
    def allocate_blocks(self, request_id: str, num_tokens: int) -> list[int]:
        """Allocate blocks for a sequence."""
        
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        
        if len(self.free_blocks) < num_blocks_needed:
            raise RuntimeError("Out of KV cache blocks")
        
        allocated = []
        for _ in range(num_blocks_needed):
            block_id = self.free_blocks.pop()
            allocated.append(block_id)
        
        self.sequence_blocks[request_id] = allocated
        return allocated
    
    def free_blocks_for_sequence(self, request_id: str):
        """Free blocks for a sequence."""
        
        if request_id in self.sequence_blocks:
            blocks = self.sequence_blocks.pop(request_id)
            self.free_blocks.extend(blocks)
    
    def write_kv(
        self,
        request_id: str,
        layer_idx: int,
        position: int,
        key: torch.Tensor,
        value: torch.Tensor
    ):
        """Write KV to cache."""
        
        blocks = self.sequence_blocks.get(request_id, [])
        block_idx = position // self.block_size
        offset = position % self.block_size
        
        if block_idx >= len(blocks):
            # Sequence has outgrown its allocation; grab another block from the pool
            if not self.free_blocks:
                raise RuntimeError("Out of KV cache blocks")
            new_block = self.free_blocks.pop()
            blocks.append(new_block)
            self.sequence_blocks[request_id] = blocks
        
        block_id = blocks[block_idx]
        self.key_blocks[block_id, layer_idx, offset] = key
        self.value_blocks[block_id, layer_idx, offset] = value
    
    def read_kv(
        self,
        request_id: str,
        layer_idx: int,
        positions: list[int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Read KV from cache."""
        
        blocks = self.sequence_blocks.get(request_id, [])
        
        keys = []
        values = []
        
        for pos in positions:
            block_idx = pos // self.block_size
            offset = pos % self.block_size
            
            if block_idx < len(blocks):
                block_id = blocks[block_idx]
                keys.append(self.key_blocks[block_id, layer_idx, offset])
                values.append(self.value_blocks[block_id, layer_idx, offset])
        
        return torch.stack(keys), torch.stack(values)
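
The paged variant can be exercised the same way. This sketch keeps the block pool on CPU with toy dimensions; a real deployment would size num_blocks from available GPU memory:

cache = PagedKVCache(num_layers=2, num_heads=4, head_dim=8,
                     block_size=16, num_blocks=32, device="cpu")

# 20 tokens need two 16-token blocks
block_ids = cache.allocate_blocks("req-1", num_tokens=20)
print(block_ids)

# Write one token's KV at a time; each key/value is (num_heads, head_dim)
for pos in range(20):
    k = torch.randn(4, 8, dtype=torch.float16)
    v = torch.randn(4, 8, dtype=torch.float16)
    cache.write_kv("req-1", layer_idx=0, position=pos, key=k, value=v)

keys, values = cache.read_kv("req-1", layer_idx=0, positions=list(range(20)))
print(keys.shape, values.shape)   # torch.Size([20, 4, 8]) for both

cache.free_blocks_for_sequence("req-1")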

Model Quantization

from dataclasses import dataclass
from typing import Any, Optional
from enum import Enum
import torch
import torch.nn as nn

class QuantizationType(Enum):
    """Quantization types."""
    
    INT8 = "int8"
    INT4 = "int4"
    FP8 = "fp8"
    NF4 = "nf4"  # Normal float 4-bit

@dataclass
class QuantizationConfig:
    """Quantization configuration."""
    
    quant_type: QuantizationType
    group_size: int = 128
    use_double_quant: bool = False
    compute_dtype: torch.dtype = torch.float16

class Int8Quantizer:
    """INT8 quantization."""
    
    def __init__(self, symmetric: bool = True):
        self.symmetric = symmetric
    
    def quantize(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize tensor to INT8."""
        
        if self.symmetric:
            # Symmetric quantization
            scale = tensor.abs().max() / 127
            quantized = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
            return quantized, scale
        else:
            # Asymmetric quantization
            min_val = tensor.min()
            max_val = tensor.max()
            scale = (max_val - min_val) / 255
            zero_point = (-min_val / scale).round()
            quantized = ((tensor / scale) + zero_point).round().clamp(0, 255).to(torch.uint8)
            return quantized, torch.tensor([scale, zero_point])
    
    def dequantize(
        self,
        quantized: torch.Tensor,
        scale: torch.Tensor
    ) -> torch.Tensor:
        """Dequantize INT8 to float."""
        
        if self.symmetric:
            return quantized.float() * scale
        else:
            scale_val, zero_point = scale[0], scale[1]
            return (quantized.float() - zero_point) * scale_val

class Int4Quantizer:
    """INT4 quantization with grouping."""
    
    def __init__(self, group_size: int = 128):
        self.group_size = group_size
    
    def quantize(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize tensor to INT4."""
        
        # Reshape for per-group scaling (requires numel to be a multiple of group_size)
        original_shape = tensor.shape
        if tensor.numel() % self.group_size != 0:
            raise ValueError(
                f"Tensor with {tensor.numel()} elements cannot be grouped by {self.group_size}"
            )
        tensor = tensor.reshape(-1, self.group_size)
        
        # Compute scales per group
        scales = tensor.abs().max(dim=1, keepdim=True).values / 7
        scales = scales.clamp(min=1e-8)
        
        # Quantize
        quantized = (tensor / scales).round().clamp(-8, 7).to(torch.int8)
        
        # Pack two INT4 values into one INT8
        packed = self._pack_int4(quantized)
        
        return packed, scales.reshape(-1)
    
    def _pack_int4(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pack two INT4 values into each uint8 byte."""
        
        # Ensure an even number of elements (pad on the same device as the input)
        flat = tensor.reshape(-1)
        if flat.shape[0] % 2 != 0:
            flat = torch.cat([flat, torch.zeros(1, dtype=flat.dtype, device=flat.device)])
        
        # Keep only the low nibble of each value, then pack pairs; working in uint8
        # avoids signed overflow when shifting the high nibble
        nibbles = (flat & 0x0F).to(torch.uint8)
        packed = nibbles[0::2] | (nibbles[1::2] << 4)
        
        return packed
    
    def dequantize(
        self,
        packed: torch.Tensor,
        scales: torch.Tensor,
        original_shape: tuple
    ) -> torch.Tensor:
        """Dequantize INT4 to float."""
        
        # Unpack
        unpacked = self._unpack_int4(packed)
        
        # Reshape and apply scales
        unpacked = unpacked.reshape(-1, self.group_size).float()
        scales = scales.reshape(-1, 1)
        
        dequantized = unpacked * scales
        
        return dequantized.reshape(original_shape)
    
    def _unpack_int4(self, packed: torch.Tensor) -> torch.Tensor:
        """Unpack INT4 values from INT8."""
        
        low = (packed & 0x0F).to(torch.int8)
        high = ((packed >> 4) & 0x0F).to(torch.int8)
        
        # Sign extend
        low = torch.where(low > 7, low - 16, low)
        high = torch.where(high > 7, high - 16, high)
        
        # Interleave
        unpacked = torch.stack([low, high], dim=-1).reshape(-1)
        
        return unpacked

class QuantizedLinear(nn.Module):
    """Quantized linear layer."""
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        config: QuantizationConfig
    ):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.config = config
        
        # Quantized weights (stored as INT8 or packed INT4)
        self.register_buffer("weight_quantized", None)
        self.register_buffer("weight_scales", None)
        self.register_buffer("bias", None)
    
    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        config: QuantizationConfig
    ) -> "QuantizedLinear":
        """Create from existing linear layer."""
        
        layer = cls(linear.in_features, linear.out_features, config)
        
        # Quantize weights
        if config.quant_type == QuantizationType.INT8:
            quantizer = Int8Quantizer()
            quantized, scales = quantizer.quantize(linear.weight.data)
        elif config.quant_type == QuantizationType.INT4:
            quantizer = Int4Quantizer(config.group_size)
            quantized, scales = quantizer.quantize(linear.weight.data)
        else:
            raise ValueError(f"Unsupported quantization: {config.quant_type}")
        
        layer.weight_quantized = quantized
        layer.weight_scales = scales
        
        if linear.bias is not None:
            layer.bias = linear.bias.data.clone()
        
        return layer
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with dequantization."""
        
        # Dequantize weights
        if self.config.quant_type == QuantizationType.INT8:
            quantizer = Int8Quantizer()
            weight = quantizer.dequantize(
                self.weight_quantized,
                self.weight_scales
            )
        elif self.config.quant_type == QuantizationType.INT4:
            quantizer = Int4Quantizer(self.config.group_size)
            weight = quantizer.dequantize(
                self.weight_quantized,
                self.weight_scales,
                (self.out_features, self.in_features)
            )
        else:
            weight = self.weight_quantized.float()
        
        # Match the activation dtype before the matmul
        weight = weight.to(dtype=x.dtype)
        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
        output = torch.nn.functional.linear(x, weight, bias)
        
        return output
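
A small round-trip check for the quantized layer; in_features is deliberately a multiple of group_size, which the INT4 path above requires:

import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(256, 64)

config = QuantizationConfig(quant_type=QuantizationType.INT4, group_size=128)
qlinear = QuantizedLinear.from_linear(linear, config)

x = torch.randn(2, 256)
with torch.no_grad():
    reference = linear(x)
    approximate = qlinear(x)

# INT4 with per-group scales should track the full-precision output closely
print("mean abs error:", (reference - approximate).abs().mean().item())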

Speculative Decoding

from dataclasses import dataclass
from typing import Any, Optional
import random
import torch

@dataclass
class SpeculativeConfig:
    """Speculative decoding configuration."""
    
    num_speculative_tokens: int = 4
    temperature: float = 0.0
    top_p: float = 1.0

class SpeculativeDecoder:
    """Speculative decoding for faster inference."""
    
    def __init__(
        self,
        target_model: Any,
        draft_model: Any,
        config: SpeculativeConfig
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.config = config
    
    async def generate(
        self,
        prompt: str,
        max_tokens: int
    ) -> str:
        """Generate with speculative decoding."""
        
        generated_tokens = []
        current_prompt = prompt
        
        while len(generated_tokens) < max_tokens:
            # Draft phase: generate speculative tokens
            draft_tokens = await self._draft_phase(current_prompt)
            
            # Verify phase: check with target model
            accepted_tokens = await self._verify_phase(
                current_prompt,
                draft_tokens
            )
            
            # Update state
            generated_tokens.extend(accepted_tokens)
            current_prompt += "".join(accepted_tokens)
            
            # Check for EOS
            if any(t in ["<|endoftext|>", ""] for t in accepted_tokens):
                break
        
        return "".join(generated_tokens)
    
    async def _draft_phase(self, prompt: str) -> list[str]:
        """Generate draft tokens with small model."""
        
        draft_tokens = []
        current = prompt
        
        for _ in range(self.config.num_speculative_tokens):
            token = await self.draft_model.generate_next_token(
                current,
                temperature=self.config.temperature
            )
            draft_tokens.append(token)
            current += token
        
        return draft_tokens
    
    async def _verify_phase(
        self,
        prompt: str,
        draft_tokens: list[str]
    ) -> list[str]:
        """Verify draft tokens with target model."""
        
        # Get target model probabilities for all positions
        full_sequence = prompt + "".join(draft_tokens)
        target_probs = await self.target_model.get_token_probabilities(
            full_sequence
        )
        
        # Get draft model probabilities
        draft_probs = await self.draft_model.get_token_probabilities(
            full_sequence
        )
        
        accepted = []
        
        for i, token in enumerate(draft_tokens):
            # Get probabilities at position
            p_target = target_probs[len(prompt) + i].get(token, 0)
            p_draft = draft_probs[len(prompt) + i].get(token, 0)
            
            # Acceptance criterion
            if p_draft > 0:
                acceptance_prob = min(1, p_target / p_draft)
            else:
                acceptance_prob = 1 if p_target > 0 else 0
            
            # Accept or reject
            if random.random() < acceptance_prob:
                accepted.append(token)
            else:
                # Sample from adjusted distribution
                adjusted_token = self._sample_adjusted(
                    target_probs[len(prompt) + i],
                    draft_probs[len(prompt) + i]
                )
                accepted.append(adjusted_token)
                break  # Stop at first rejection
        
        return accepted
    
    def _sample_adjusted(
        self,
        target_probs: dict[str, float],
        draft_probs: dict[str, float]
    ) -> str:
        """Sample from adjusted distribution."""
        
        # Compute adjusted probabilities
        adjusted = {}
        for token in target_probs:
            p_t = target_probs.get(token, 0)
            p_d = draft_probs.get(token, 0)
            adjusted[token] = max(0, p_t - p_d)
        
        # Normalize
        total = sum(adjusted.values())
        if total > 0:
            adjusted = {k: v / total for k, v in adjusted.items()}
        else:
            adjusted = target_probs
        
        # Sample
        r = random.random()
        cumsum = 0
        for token, prob in adjusted.items():
            cumsum += prob
            if r < cumsum:
                return token
        
        return list(adjusted.keys())[0]

class MedusaDecoder:
    """Medusa-style parallel decoding."""
    
    def __init__(
        self,
        model: Any,
        num_heads: int = 4,
        num_candidates: int = 5
    ):
        self.model = model
        self.num_heads = num_heads
        self.num_candidates = num_candidates
    
    async def generate(
        self,
        prompt: str,
        max_tokens: int
    ) -> str:
        """Generate with Medusa decoding."""
        
        generated = []
        current = prompt
        
        while len(generated) < max_tokens:
            # Generate candidates from all heads
            candidates = await self._generate_candidates(current)
            
            # Verify candidates
            accepted = await self._verify_candidates(current, candidates)
            
            generated.extend(accepted)
            current += "".join(accepted)
            
            if not accepted:
                # Fall back to single token
                token = await self.model.generate_next_token(current)
                generated.append(token)
                current += token
        
        return "".join(generated)
    
    async def _generate_candidates(
        self,
        prompt: str
    ) -> list[list[str]]:
        """Generate candidate sequences from Medusa heads."""
        
        # Each head predicts tokens at different positions
        candidates = []
        
        for head_idx in range(self.num_heads):
            head_candidates = await self.model.medusa_head_predict(
                prompt,
                head_idx,
                self.num_candidates
            )
            candidates.append(head_candidates)
        
        return candidates
    
    async def _verify_candidates(
        self,
        prompt: str,
        candidates: list[list[str]]
    ) -> list[str]:
        """Verify candidate sequences."""
        
        # Build candidate trees
        trees = self._build_candidate_trees(candidates)
        
        # Verify with single forward pass
        best_sequence = []
        
        for tree in trees:
            verified = await self.model.verify_sequence(prompt, tree)
            if len(verified) > len(best_sequence):
                best_sequence = verified
        
        return best_sequence
    
    def _build_candidate_trees(
        self,
        candidates: list[list[str]]
    ) -> list[list[str]]:
        """Build candidate trees from head predictions."""
        
        # Simplified: just return top candidates
        trees = []
        for head_candidates in candidates:
            for candidate in head_candidates[:self.num_candidates]:
                trees.append([candidate])
        return trees
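
To watch SpeculativeDecoder's accept/reject loop without real models, here is a toy sketch in which every "token" is a single character and both models implement the async interface the decoder assumes (generate_next_token and get_token_probabilities are interfaces assumed by this guide, not a standard API):

import asyncio

class StubLM:
    """Toy model: always proposes 'a', with a fixed per-position distribution."""
    
    def __init__(self, p_a: float):
        self.p_a = p_a
    
    async def generate_next_token(self, prompt: str, temperature: float = 0.0) -> str:
        return "a"
    
    async def get_token_probabilities(self, text: str) -> list[dict[str, float]]:
        # One distribution per character position, since this stub treats characters as tokens
        return [{"a": self.p_a, "b": 1.0 - self.p_a} for _ in text]

decoder = SpeculativeDecoder(
    target_model=StubLM(p_a=0.9),   # the large, accurate model
    draft_model=StubLM(p_a=0.6),    # the small, fast draft model
    config=SpeculativeConfig(num_speculative_tokens=4),
)

# The target assigns "a" higher probability than the draft, so every draft token is accepted
print(asyncio.run(decoder.generate("hi", max_tokens=8)))  # -> "aaaaaaaa"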

Production Inference Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
from enum import Enum
import asyncio

app = FastAPI()

# Initialize components
batcher = None  # Would initialize DynamicBatcher
kv_cache = None  # Would initialize KVCacheManager
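
# Wiring sketch: one way the placeholders above might be filled in at startup.
# `load_model()` is hypothetical; substitute whatever returns an object with an
# async generate_batch(prompts, max_tokens) method.
#
# @app.on_event("startup")
# async def startup_event():
#     global batcher, kv_cache
#     model = load_model()
#     batcher = DynamicBatcher(model=model, max_batch_size=32, max_wait_ms=50.0)
#     await batcher.start()
#     kv_cache = KVCacheManager(num_layers=32, num_heads=32, head_dim=128, max_cache_size_gb=8.0)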

class OptimizationLevel(str, Enum):
    NONE = "none"
    BASIC = "basic"  # Batching only
    FULL = "full"  # Batching + KV cache + quantization

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    stream: bool = False
    optimization: OptimizationLevel = OptimizationLevel.FULL

class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    latency_ms: float
    optimization_used: str

class BatchGenerateRequest(BaseModel):
    prompts: list[str]
    max_tokens: int = 100
    temperature: float = 0.7

class BatchGenerateResponse(BaseModel):
    results: list[GenerateResponse]
    total_latency_ms: float
    throughput_tokens_per_sec: float

@app.post("/v1/generate")
async def generate(request: GenerateRequest) -> GenerateResponse:
    """Generate text with optimization."""
    
    import time
    import uuid
    
    start = time.time()
    
    # Create inference request
    req = InferenceRequest(
        id=str(uuid.uuid4()),
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature
    )
    
    # Submit to batcher
    if batcher:
        result = await batcher.infer(req)
        text = result.text
        tokens = result.tokens_generated
    else:
        # Direct inference (placeholder)
        text = "Generated text placeholder"
        tokens = 10
    
    latency = (time.time() - start) * 1000
    
    return GenerateResponse(
        text=text,
        tokens_generated=tokens,
        latency_ms=latency,
        optimization_used=request.optimization.value
    )

@app.post("/v1/generate/batch")
async def batch_generate(request: BatchGenerateRequest) -> BatchGenerateResponse:
    """Batch generate text."""
    
    import time
    import uuid
    
    start = time.time()
    
    # Create requests
    requests = [
        InferenceRequest(
            id=str(uuid.uuid4()),
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )
        for prompt in request.prompts
    ]
    
    # Submit all
    if batcher:
        tasks = [batcher.infer(req) for req in requests]
        results = await asyncio.gather(*tasks)
    else:
        results = [
            InferenceResult(
                request_id=req.id,
                text="Placeholder",
                tokens_generated=10,
                latency_ms=100
            )
            for req in requests
        ]
    
    total_latency = (time.time() - start) * 1000
    total_tokens = sum(r.tokens_generated for r in results)
    throughput = total_tokens / (total_latency / 1000) if total_latency > 0 else 0
    
    return BatchGenerateResponse(
        results=[
            GenerateResponse(
                text=r.text,
                tokens_generated=r.tokens_generated,
                latency_ms=r.latency_ms,
                optimization_used="batch"
            )
            for r in results
        ],
        total_latency_ms=total_latency,
        throughput_tokens_per_sec=throughput
    )

class CacheStatsResponse(BaseModel):
    total_entries: int
    total_size_mb: float
    hit_rate: float

@app.get("/v1/cache/stats")
async def cache_stats() -> CacheStatsResponse:
    """Get KV cache statistics."""
    
    if kv_cache:
        return CacheStatsResponse(
            total_entries=len(kv_cache._caches),
            total_size_mb=kv_cache._current_size / 1e6,
            hit_rate=0.85  # Would track actual hit rate
        )
    
    return CacheStatsResponse(
        total_entries=0,
        total_size_mb=0,
        hit_rate=0
    )

@app.delete("/v1/cache")
async def clear_cache():
    """Clear KV cache."""
    
    if kv_cache:
        kv_cache._caches.clear()
        kv_cache._current_size = 0
    
    return {"status": "cleared"}

class ModelInfoResponse(BaseModel):
    model_name: str
    quantization: Optional[str]
    batch_size: int
    max_sequence_length: int

@app.get("/v1/model/info")
async def model_info() -> ModelInfoResponse:
    """Get model information."""
    
    return ModelInfoResponse(
        model_name="llama-7b",
        quantization="int4",
        batch_size=32,
        max_sequence_length=4096
    )

@app.get("/health")
async def health():
    return {"status": "healthy"}

@app.get("/metrics")
async def metrics():
    """Get inference metrics."""
    
    # Static placeholder values; a production service would compute these from
    # real request counters and latency histograms
    return {
        "requests_total": 1000,
        "requests_per_second": 50,
        "avg_latency_ms": 150,
        "p99_latency_ms": 500,
        "tokens_per_second": 2000,
        "batch_utilization": 0.75,
        "cache_hit_rate": 0.85
    }
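
To try the service locally, you could save the module as inference_service.py (the filename is arbitrary) and run it with uvicorn:

uvicorn inference_service:app --host 0.0.0.0 --port 8000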

Conclusion

Inference optimization determines whether LLM applications are economically viable at scale. Start with dynamic batching—grouping requests together maximizes GPU utilization and amortizes the cost of loading model weights. Implement KV caching to avoid recomputing attention for previously processed tokens; paged attention (vLLM-style) enables efficient memory management for variable-length sequences. Quantize models to INT8 or INT4 to reduce memory footprint and increase throughput, with minimal quality degradation for most tasks. Use speculative decoding with a small draft model to generate multiple tokens per forward pass of the large model. Continuous batching handles streaming scenarios where requests arrive and complete at different times. Monitor key metrics: tokens per second, batch utilization, cache hit rate, and latency percentiles. The key insight is that inference optimization is about maximizing useful computation per GPU cycle—every optimization technique reduces waste, whether it's redundant computation (KV cache), memory bandwidth (quantization), or idle time (batching). Combine these techniques for multiplicative improvements in throughput and cost efficiency.