LLM Inference Optimization: KV Cache, Quantization, and Speculative Decoding (Part 2 of 2)

Introduction: LLM inference optimization is the art of making models respond faster while using fewer resources. As LLMs grow larger and usage scales, the difference between naive and optimized inference can mean a 10x cost reduction and sub-second latencies instead of multi-second waits. This guide covers the techniques that matter most: KV cache optimization to avoid redundant computation, dynamic batching to maximize GPU utilization, quantization to reduce memory footprint with minimal quality loss, speculative decoding to accelerate autoregressive generation, and continuous batching for high-throughput serving. Whether you’re serving a 7B model on a single GPU or a 70B model across a cluster, these patterns will help you achieve production-grade performance.

📚
SERIES: LLM Inference Optimization (Part 2 of 2)

Building on Part 1’s application-level optimizations, this article covers model-level techniques for maximizing inference performance.

  • Part 1: Application-level – caching, batching, routing, streaming
  • Part 2 (this article): Model-level – KV cache, quantization, speculative decoding

Model-Level Optimization Architecture

Model-level optimizations operate at the inference engine layer, improving performance without changing application code. This diagram shows how these techniques integrate into the inference pipeline.

flowchart TB
    subgraph Input["Input Processing"]
        T[Tokenizer]
        E[Embeddings]
    end
    
    subgraph KVCache["KV Cache Layer"]
        KC[KV Cache Manager]
        PM[PagedAttention]
        PR[Prefix Cache]
    end
    
    subgraph Model["Model Layer"]
        Q[Quantized Weights<br/>INT8/INT4]
        ATT[Attention Blocks]
        FFN[Feed-Forward]
    end
    
    subgraph Decode["Decoding"]
        SD[Speculative Decoder]
        DM[Draft Model]
        VM[Verify Model]
    end
    
    subgraph Output["Output"]
        DT[Detokenizer]
        ST[Streaming]
    end
    
    T --> E
    E --> KC
    KC --> PM
    PM --> PR
    PR --> ATT
    Q --> ATT
    ATT --> FFN
    FFN --> SD
    SD --> DM
    DM --> VM
    VM --> DT
    DT --> ST
    
    style T fill:#E3F2FD,stroke:#90CAF9,stroke-width:2px,color:#1565C0
    style E fill:#E3F2FD,stroke:#90CAF9,stroke-width:2px,color:#1565C0
    style KC fill:#F3E5F5,stroke:#CE93D8,stroke-width:2px,color:#6A1B9A
    style PM fill:#F3E5F5,stroke:#CE93D8,stroke-width:2px,color:#6A1B9A
    style PR fill:#F3E5F5,stroke:#CE93D8,stroke-width:2px,color:#6A1B9A
    style Q fill:#E8F5E9,stroke:#A5D6A7,stroke-width:2px,color:#2E7D32
    style ATT fill:#E0F2F1,stroke:#80CBC4,stroke-width:2px,color:#00695C
    style FFN fill:#E0F2F1,stroke:#80CBC4,stroke-width:2px,color:#00695C
    style SD fill:#FFF3E0,stroke:#FFCC80,stroke-width:2px,color:#E65100
    style DM fill:#FFF3E0,stroke:#FFCC80,stroke-width:2px,color:#E65100
    style VM fill:#FFF3E0,stroke:#FFCC80,stroke-width:2px,color:#E65100
    style DT fill:#ECEFF1,stroke:#90A4AE,stroke-width:2px,color:#455A64
    style ST fill:#ECEFF1,stroke:#90A4AE,stroke-width:2px,color:#455A64

Figure 1: Model-Level Optimization Pipeline

KV Cache Optimization

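The key-value (KV) cache stores the key and value projections that each attention layer computed for previous tokens, so every decoding step only processes the new token instead of re-encoding the whole sequence. For long conversations or documents, KV cache management is critical: the cache grows linearly with sequence length and batch size, and poor handling leads to memory bloat, fragmentation, and evicted entries that must be recomputed.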
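To make the memory pressure concrete, here is a small sizing sketch. The function name `kv_cache_bytes` and the model shape (32 layers, 32 KV heads, head dimension 128, FP16 cache) are illustrative assumptions for a 7B-class model, not figures taken from a specific deployment.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """Bytes needed to hold keys and values for every layer and position."""
    # Factor 2: one tensor for keys, one for values.
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_elem)

# Hypothetical 7B-class shape: 32 layers, 32 KV heads, head_dim 128, FP16.
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
full_context = kv_cache_bytes(32, 32, 128, seq_len=4096)

print(f"{per_token / 1024:.0f} KiB per token")             # 512 KiB per token
print(f"{full_context / 1024**3:.1f} GiB at 4096 tokens")  # 2.0 GiB at 4096 tokens
```

Under these assumptions a single 4096-token sequence needs about 2 GiB of cache, which is why the allocation strategy matters as much as the model weights once many sequences are served at once.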

from dataclasses import dataclass, field
from typing import Any, Optional, Tuple
import torch
import numpy as np

@dataclass
class KVCacheConfig:
    """Configuration for KV cache."""
    
    num_layers: int
    num_heads: int
    head_dim: int
    max_seq_len: int
    dtype: torch.dtype = torch.float16

class KVCache:
    """Key-Value cache for transformer inference."""
    
    def __init__(self, config: KVCacheConfig, batch_size: int = 1):
        self.config = config
        self.batch_size = batch_size
        
        # Pre-allocate cache tensors
        cache_shape = (
            config.num_layers,
            2,  # key and value
            batch_size,
            config.num_heads,
            config.max_seq_len,
            config.head_dim
        )
        
        self.cache = torch.zeros(cache_shape, dtype=config.dtype)
        self.seq_len = 0
    
    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache and return full key/value."""
        
        seq_len = key.shape[2]
        
        # Store new key/value
        self.cache[layer_idx, 0, :, :, self.seq_len:self.seq_len + seq_len, :] = key
        self.cache[layer_idx, 1, :, :, self.seq_len:self.seq_len + seq_len, :] = value
        
        # Return full cached key/value
        full_key = self.cache[layer_idx, 0, :, :, :self.seq_len + seq_len, :]
        full_value = self.cache[layer_idx, 1, :, :, :self.seq_len + seq_len, :]
        
        return full_key, full_value
    
    def advance(self, num_tokens: int = 1):
        """Advance sequence position."""
        self.seq_len += num_tokens
    
    def reset(self):
        """Reset cache."""
        self.cache.zero_()
        self.seq_len = 0

class PagedKVCache:
    """Paged attention KV cache for efficient memory."""
    
    def __init__(
        self,
        config: KVCacheConfig,
        page_size: int = 16,
        num_pages: int = 1024
    ):
        self.config = config
        self.page_size = page_size
        self.num_pages = num_pages
        
        # Page pool
        page_shape = (
            num_pages,
            2,  # key and value
            config.num_layers,
            config.num_heads,
            page_size,
            config.head_dim
        )
        
        self.page_pool = torch.zeros(page_shape, dtype=config.dtype)
        self.free_pages = list(range(num_pages))
        
        # Sequence to page mapping
        self.seq_pages: dict[int, list[int]] = {}
    
    def allocate_sequence(self, seq_id: int) -> bool:
        """Allocate first page for sequence."""
        
        if not self.free_pages:
            return False
        
        page_idx = self.free_pages.pop()
        self.seq_pages[seq_id] = [page_idx]
        return True
    
    def extend_sequence(self, seq_id: int) -> bool:
        """Allocate additional page for sequence."""
        
        if not self.free_pages:
            return False
        
        page_idx = self.free_pages.pop()
        self.seq_pages[seq_id].append(page_idx)
        return True
    
    def free_sequence(self, seq_id: int):
        """Free all pages for sequence."""
        
        if seq_id in self.seq_pages:
            self.free_pages.extend(self.seq_pages[seq_id])
            del self.seq_pages[seq_id]
    
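    def get_cache(
        self,
        seq_id: int,
        layer_idx: int
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Get cached key/value for sequence, or (None, None) if it has no pages."""
        
        pages = self.seq_pages.get(seq_id, [])
        
        if not pages:
            return None, None
        
        # Gather from pages; each entry has shape (num_heads, page_size, head_dim)
        keys = []
        values = []
        
        for page_idx in pages:
            keys.append(self.page_pool[page_idx, 0, layer_idx])
            values.append(self.page_pool[page_idx, 1, layer_idx])
        
        # Concatenate along the token dimension. The last page may be only
        # partly filled, so callers must mask positions past the sequence length.
        return torch.cat(keys, dim=1), torch.cat(values, dim=1)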
    
    def update_cache(
        self,
        seq_id: int,
        layer_idx: int,
        position: int,
        key: torch.Tensor,
        value: torch.Tensor
    ):
        """Update cache at position."""
        
        page_idx = position // self.page_size
        offset = position % self.page_size
        
        # Extend if needed
        while len(self.seq_pages[seq_id]) <= page_idx:
            if not self.extend_sequence(seq_id):
                raise RuntimeError("Out of cache pages")
        
        actual_page = self.seq_pages[seq_id][page_idx]
        
        self.page_pool[actual_page, 0, layer_idx, :, offset, :] = key
        self.page_pool[actual_page, 1, layer_idx, :, offset, :] = value

class SlidingWindowCache:
    """Sliding window attention cache."""
    
    def __init__(
        self,
        config: KVCacheConfig,
        window_size: int = 4096
    ):
        self.config = config
        self.window_size = window_size
        
        # Circular buffer
        cache_shape = (
            config.num_layers,
            2,
            config.num_heads,
            window_size,
            config.head_dim
        )
        
        self.cache = torch.zeros(cache_shape, dtype=config.dtype)
        self.position = 0
    
    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update with circular buffer."""
        
        idx = self.position % self.window_size
        
        self.cache[layer_idx, 0, :, idx, :] = key.squeeze(2)
        self.cache[layer_idx, 1, :, idx, :] = value.squeeze(2)
        
        # Return window
        if self.position < self.window_size:
            return (
                self.cache[layer_idx, 0, :, :self.position + 1, :],
                self.cache[layer_idx, 1, :, :self.position + 1, :]
            )
        
        # Reorder for correct sequence
        start = (self.position + 1) % self.window_size
        indices = [(start + i) % self.window_size for i in range(self.window_size)]
        
        return (
            self.cache[layer_idx, 0, :, indices, :],
            self.cache[layer_idx, 1, :, indices, :]
        )
    
    def advance(self):
        """Advance position."""
        self.position += 1
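The paged cache above relies on translating a logical token position into a physical page and an offset. The following sketch isolates that arithmetic in plain Python and compares reserved slots against a contiguous per-sequence allocation. The names `locate` and `pages_needed`, the page table contents, and the sequence lengths are hypothetical values chosen for illustration.

```python
import math

PAGE_SIZE = 16  # tokens per page, the same default as PagedKVCache

def locate(position, page_table, page_size=PAGE_SIZE):
    """Translate a logical token position into (physical_page, offset)."""
    logical_page, offset = divmod(position, page_size)
    return page_table[logical_page], offset

def pages_needed(seq_len, page_size=PAGE_SIZE):
    """Number of pages required to hold seq_len tokens."""
    return math.ceil(seq_len / page_size)

# Hypothetical page table: logical pages 0, 1, 2 of one sequence live in
# physical pages 7, 2, 9 of the pool (deliberately non-contiguous).
page_table = [7, 2, 9]
print(locate(0, page_table))   # (7, 0)
print(locate(37, page_table))  # (9, 5)

# Token slots reserved for three sequences of different lengths
seq_lens = [40, 300, 1000]
max_seq_len = 4096
paged_slots = sum(pages_needed(n) * PAGE_SIZE for n in seq_lens)
contiguous_slots = len(seq_lens) * max_seq_len
print(paged_slots, contiguous_slots)  # 1360 12288
```

With paging, waste is bounded by one partly filled page per sequence, whereas reserving `max_seq_len` slots up front leaves most of the allocation unused for short sequences.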

Dynamic Batching

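Dynamic batching groups requests that arrive within a short window, bounded by a maximum batch size and a maximum wait time, into a single forward pass so the GPU is not left processing one request at a time. Continuous batching goes a step further: it admits new sequences and retires finished ones at every generation step, so short requests never wait for the longest sequence in their batch. Both are essential for production inference servers handling variable traffic.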

from dataclasses import dataclass, field
from typing import Any, Optional, Callable
from collections import deque
import asyncio
import time

@dataclass
class InferenceRequest:
    """A single inference request."""
    
    id: str
    prompt: str
    max_tokens: int = 100
    temperature: float = 1.0
    created_at: float = field(default_factory=time.time)
    future: Optional[asyncio.Future] = None

@dataclass
class BatchConfig:
    """Configuration for dynamic batching."""
    
    max_batch_size: int = 32
    max_wait_time_ms: float = 50.0
    max_tokens_per_batch: int = 4096

class DynamicBatcher:
    """Dynamic batching for inference requests."""
    
    def __init__(
        self,
        config: BatchConfig,
        inference_fn: Callable
    ):
        self.config = config
        self.inference_fn = inference_fn
        self.queue: deque[InferenceRequest] = deque()
        self.running = False
    
    async def submit(self, request: InferenceRequest) -> str:
        """Submit request and wait for result."""
        
        # Bind the future to the running loop explicitly
        request.future = asyncio.get_running_loop().create_future()
        self.queue.append(request)
        
        return await request.future
    
    async def start(self):
        """Start batch processing loop."""
        
        self.running = True
        
        while self.running:
            batch = await self._collect_batch()
            
            if batch:
                await self._process_batch(batch)
            else:
                await asyncio.sleep(0.001)
    
    async def _collect_batch(self) -> list[InferenceRequest]:
        """Collect requests into a batch, waiting up to max_wait_time_ms for it to fill."""
        
        if not self.queue:
            return []
        
        batch = []
        total_tokens = 0
        deadline = time.monotonic() + self.config.max_wait_time_ms / 1000
        
        while len(batch) < self.config.max_batch_size:
            if not self.queue:
                # Queue drained: give late arrivals until the deadline to join
                if time.monotonic() >= deadline:
                    break
                await asyncio.sleep(0.001)
                continue
            
            request = self.queue[0]
            # Rough token estimate: whitespace-split words plus the generation budget
            request_tokens = len(request.prompt.split()) + request.max_tokens
            
            # Respect the token budget, but always admit at least one request
            # so an oversized request cannot block the queue forever
            if batch and total_tokens + request_tokens > self.config.max_tokens_per_batch:
                break
            
            batch.append(self.queue.popleft())
            total_tokens += request_tokens
        
        return batch
    
    async def _process_batch(self, batch: list[InferenceRequest]):
        """Process a batch of requests."""
        
        try:
            # Prepare batch inputs
            prompts = [r.prompt for r in batch]
            
            # Run inference
            results = await self.inference_fn(prompts)
            
            # Distribute results
            for request, result in zip(batch, results):
                request.future.set_result(result)
        
        except Exception as e:
            for request in batch:
                request.future.set_exception(e)
    
    def stop(self):
        """Stop batch processing."""
        self.running = False

class ContinuousBatcher:
    """Continuous batching for streaming generation."""
    
    def __init__(
        self,
        max_batch_size: int = 32,
        inference_fn: Callable = None
    ):
        self.max_batch_size = max_batch_size
        self.inference_fn = inference_fn
        
        # Active sequences
        self.active: dict[str, dict] = {}
        self.pending: deque[InferenceRequest] = deque()
    
    async def submit(self, request: InferenceRequest):
        """Submit request for continuous batching."""
        
        request.future = asyncio.get_running_loop().create_future()
        self.pending.append(request)
        return request.future
    
    async def step(self):
        """Execute one generation step."""
        
        # Add pending requests if space available
        while self.pending and len(self.active) < self.max_batch_size:
            request = self.pending.popleft()
            self.active[request.id] = {
                "request": request,
                "tokens": [],
                "position": 0
            }
        
        if not self.active:
            return
        
        # Prepare batch
        batch_ids = list(self.active.keys())
        
        # Run one step of generation
        next_tokens = await self._generate_step(batch_ids)
        
        # Update sequences
        completed = []
        
        for seq_id, token in zip(batch_ids, next_tokens):
            seq = self.active[seq_id]
            seq["tokens"].append(token)
            seq["position"] += 1
            
            # Check completion
            if self._is_complete(seq):
                completed.append(seq_id)
        
        # Complete finished sequences
        for seq_id in completed:
            seq = self.active.pop(seq_id)
            result = self._decode_tokens(seq["tokens"])
            seq["request"].future.set_result(result)
    
    async def _generate_step(self, batch_ids: list[str]) -> list[int]:
        """Generate next token for each sequence."""
        
        # Would call actual model inference
        return [0] * len(batch_ids)  # Placeholder
    
    def _is_complete(self, seq: dict) -> bool:
        """Check if sequence is complete."""
        
        request = seq["request"]
        
        # Check max tokens
        if len(seq["tokens"]) >= request.max_tokens:
            return True
        
        # Check EOS token
        if seq["tokens"] and seq["tokens"][-1] == 2:  # EOS
            return True
        
        return False
    
    def _decode_tokens(self, tokens: list[int]) -> str:
        """Decode tokens to text."""
        return " ".join(str(t) for t in tokens)  # Placeholder

class PriorityBatcher:
    """Priority-based batching."""
    
    def __init__(self, max_batch_size: int = 32):
        self.max_batch_size = max_batch_size
        self.queues: dict[int, deque] = {
            0: deque(),  # High priority
            1: deque(),  # Normal priority
            2: deque()   # Low priority
        }
    
    def submit(self, request: InferenceRequest, priority: int = 1):
        """Submit with priority."""
        
        request.future = asyncio.Future()
        self.queues[priority].append(request)
        return request.future
    
    def collect_batch(self) -> list[InferenceRequest]:
        """Collect batch respecting priorities."""
        
        batch = []
        
        # Collect from high to low priority
        for priority in sorted(self.queues.keys()):
            queue = self.queues[priority]
            
            while queue and len(batch) < self.max_batch_size:
                batch.append(queue.popleft())
        
        return batch
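To see why step-level scheduling matters, the following self-contained sketch counts decode steps under two policies. It is a toy model rather than the batchers above: `static_batch_steps` and `continuous_batch_steps` are hypothetical helper names, each request is reduced to its output length, and one step is assumed to produce one token for every active slot.

```python
from collections import deque

def static_batch_steps(lengths, batch_size):
    """Each batch runs until its longest sequence finishes."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps

def continuous_batch_steps(lengths, batch_size):
    """Finished sequences free their slot for the next pending request."""
    pending = deque(lengths)
    active = []  # remaining tokens for each running sequence
    steps = 0
    while pending or active:
        # Fill free slots before every step
        while pending and len(active) < batch_size:
            active.append(pending.popleft())
        steps += 1
        # Every active sequence emits one token; drop the finished ones
        active = [remaining - 1 for remaining in active if remaining - 1 > 0]
    return steps

output_lengths = [2, 8, 3, 8, 1, 4]  # tokens to generate per request
print(static_batch_steps(output_lengths, batch_size=2))      # 20
print(continuous_batch_steps(output_lengths, batch_size=2))  # 13
```

In this example the continuous policy reaches the lower bound of 13 steps (26 tokens across 2 slots), while static batching spends 7 extra steps holding slots open for sequences that have already finished.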

Quantization Techniques

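Quantization stores model weights at lower precision, typically int8 or int4 instead of float32 or float16, which shrinks memory requirements and often improves throughput because token generation is largely limited by memory bandwidth. Modern quantization methods like GPTQ and AWQ keep quality close to the full-precision model while cutting weight memory by roughly 50-75% relative to 16-bit weights.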

Quantization Precision Comparison

Different quantization levels trade off memory savings against quality. This diagram shows the memory reduction and typical quality impact for each precision level.

flowchart LR
    subgraph FP32["FP32 (Baseline)"]
        F1[32 bits per weight]
        F2[100% Memory]
        F3[100% Quality]
    end
    
    subgraph FP16["FP16"]
        H1[16 bits per weight]
        H2[50% Memory]
        H3[~99.9% Quality]
    end
    
    subgraph INT8["INT8"]
        I1[8 bits per weight]
        I2[25% Memory]
        I3[~99% Quality]
    end
    
    subgraph INT4["INT4 (GPTQ/AWQ)"]
        Q1[4 bits per weight]
        Q2[12.5% Memory]
        Q3[~97% Quality]
    end
    
    FP32 --> FP16 --> INT8 --> INT4
    
    style F1 fill:#FCE4EC,stroke:#F48FB1,stroke-width:2px,color:#AD1457
    style F2 fill:#FCE4EC,stroke:#F48FB1,stroke-width:2px,color:#AD1457
    style F3 fill:#FCE4EC,stroke:#F48FB1,stroke-width:2px,color:#AD1457
    style H1 fill:#FFF3E0,stroke:#FFCC80,stroke-width:2px,color:#E65100
    style H2 fill:#FFF3E0,stroke:#FFCC80,stroke-width:2px,color:#E65100
    style H3 fill:#FFF3E0,stroke:#FFCC80,stroke-width:2px,color:#E65100
    style I1 fill:#E3F2FD,stroke:#90CAF9,stroke-width:2px,color:#1565C0
    style I2 fill:#E3F2FD,stroke:#90CAF9,stroke-width:2px,color:#1565C0
    style I3 fill:#E3F2FD,stroke:#90CAF9,stroke-width:2px,color:#1565C0
    style Q1 fill:#E8F5E9,stroke:#A5D6A7,stroke-width:2px,color:#2E7D32
    style Q2 fill:#E8F5E9,stroke:#A5D6A7,stroke-width:2px,color:#2E7D32
    style Q3 fill:#E8F5E9,stroke:#A5D6A7,stroke-width:2px,color:#2E7D32

Figure 3: Quantization Precision Trade-offs
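The percentages in the figure translate directly into weight storage. The sketch below works through the arithmetic for a 7B-parameter model. The helper name `weight_memory_gb` is hypothetical, and the group overhead (one FP16 scale plus one FP16 zero point per group) is an assumption; real formats differ in how they store this metadata.

```python
def weight_memory_gb(num_params, bits, group_size=None, overhead_bits_per_group=32):
    """Approximate weight storage in GB (10^9 bytes)."""
    effective_bits = bits
    if group_size is not None:
        # Assumption: each group stores an FP16 scale and an FP16 zero point.
        effective_bits += overhead_bits_per_group / group_size
    return num_params * effective_bits / 8 / 1e9

params = 7e9
formats = [
    ("FP32", 32, None),
    ("FP16", 16, None),
    ("INT8", 8, None),
    ("INT4 g128", 4, 128),
]
for name, bits, group in formats:
    print(f"{name:10s} {weight_memory_gb(params, bits, group):6.2f} GB")
```

Group-wise metadata adds about 0.25 bits per weight at a group size of 128 under this assumption, so a 4-bit model lands slightly above the ideal 12.5% of FP32. The KV cache and activations are not included in these numbers.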

from dataclasses import dataclass
from typing import Any, Optional
import torch
import numpy as np

@dataclass
class QuantizationConfig:
    """Configuration for quantization."""
    
    bits: int = 4
    group_size: int = 128
    symmetric: bool = False
    compute_dtype: torch.dtype = torch.float16

class WeightQuantizer:
    """Quantize model weights."""
    
    def __init__(self, config: QuantizationConfig):
        self.config = config
    
    def quantize(self, weight: torch.Tensor) -> dict:
        """Quantize weight tensor."""
        
        if self.config.symmetric:
            return self._quantize_symmetric(weight)
        else:
            return self._quantize_asymmetric(weight)
    
    def _quantize_symmetric(self, weight: torch.Tensor) -> dict:
        """Symmetric quantization."""
        
        # Reshape for group quantization
        original_shape = weight.shape
        weight = weight.reshape(-1, self.config.group_size)
        
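        # Calculate scale; all-zero groups fall back to scale 1 to avoid 0/0
        qmax = 2 ** (self.config.bits - 1) - 1
        max_val = weight.abs().max(dim=1, keepdim=True)[0]
        scale = max_val / qmax
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        
        # Quantize, clamping to the representable signed range
        quantized = torch.round(weight / scale).clamp(-qmax, qmax).to(torch.int8)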
        
        return {
            "quantized": quantized,
            "scale": scale,
            "zero_point": None,
            "original_shape": original_shape
        }
    
    def _quantize_asymmetric(self, weight: torch.Tensor) -> dict:
        """Asymmetric quantization."""
        
        original_shape = weight.shape
        weight = weight.reshape(-1, self.config.group_size)
        
        # Calculate scale and zero point
        min_val = weight.min(dim=1, keepdim=True)[0]
        max_val = weight.max(dim=1, keepdim=True)[0]
        
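        qmax = 2 ** self.config.bits - 1
        scale = (max_val - min_val) / qmax
        # Constant groups have zero range; fall back to scale 1 to avoid dividing by zero
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        zero_point = torch.round(-min_val / scale)
        
        # Quantize, clamping so rounding cannot push values outside [0, qmax]
        quantized = torch.round(weight / scale + zero_point).clamp(0, qmax).to(torch.uint8)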
        
        return {
            "quantized": quantized,
            "scale": scale,
            "zero_point": zero_point,
            "original_shape": original_shape
        }
    
    def dequantize(self, quantized_data: dict) -> torch.Tensor:
        """Dequantize to original precision."""
        
        quantized = quantized_data["quantized"].to(self.config.compute_dtype)
        scale = quantized_data["scale"]
        zero_point = quantized_data["zero_point"]
        original_shape = quantized_data["original_shape"]
        
        if zero_point is not None:
            dequantized = (quantized - zero_point) * scale
        else:
            dequantized = quantized * scale
        
        return dequantized.reshape(original_shape)

class AWQQuantizer:
    """Activation-aware Weight Quantization."""
    
    def __init__(self, bits: int = 4, group_size: int = 128):
        self.bits = bits
        self.group_size = group_size
    
    def quantize_layer(
        self,
        weight: torch.Tensor,
        activations: torch.Tensor
    ) -> dict:
        """Quantize with activation awareness."""
        
        # Calculate activation scales
        act_scales = activations.abs().mean(dim=0)
        
        # Find optimal scaling factors
        best_scale = self._search_scale(weight, act_scales)
        
        # Apply scaling and quantize
        scaled_weight = weight * best_scale
        
        # Standard quantization
        quantizer = WeightQuantizer(QuantizationConfig(
            bits=self.bits,
            group_size=self.group_size
        ))
        
        quantized = quantizer.quantize(scaled_weight)
        quantized["act_scale"] = best_scale
        
        return quantized
    
    def _search_scale(
        self,
        weight: torch.Tensor,
        act_scales: torch.Tensor,
        n_grid: int = 20
    ) -> torch.Tensor:
        """Search for optimal scaling."""
        
        best_error = float('inf')
        best_scale = torch.ones_like(act_scales)
        
        for ratio in np.linspace(0, 1, n_grid):
            scale = act_scales.pow(ratio)
            
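            # Simulate quantization error. This is a simplified proxy: it rounds
            # to the nearest integer rather than to the configured bit width, and
            # it measures weight error instead of the layer-output error that
            # AWQ minimizes.
            scaled = weight * scale
            quantized = torch.round(scaled)
            dequantized = quantized / scale
            
            error = (weight - dequantized).pow(2).mean()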
            
            if error < best_error:
                best_error = error
                best_scale = scale
        
        return best_scale

class GPTQQuantizer:
    """GPTQ quantization."""
    
    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        damp_percent: float = 0.01
    ):
        self.bits = bits
        self.group_size = group_size
        self.damp_percent = damp_percent
    
    def quantize_layer(
        self,
        weight: torch.Tensor,
        hessian: torch.Tensor
    ) -> dict:
        """Quantize using GPTQ algorithm."""
        
        W = weight.clone()
        H = hessian.clone()
        
        # Add damping
        damp = self.damp_percent * torch.diag(H).mean()
        H += damp * torch.eye(H.shape[0], device=H.device)
        
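        # Invert H via Cholesky, then take the upper Cholesky factor of the
        # inverse; its rows supply the error-propagation coefficients below
        H_inv = torch.linalg.cholesky(H)
        H_inv = torch.cholesky_inverse(H_inv)
        H_inv = torch.linalg.cholesky(H_inv, upper=True)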
        
        # Quantize column by column
        Q = torch.zeros_like(W)
        
        for i in range(W.shape[1]):
            w = W[:, i]
            d = H_inv[i, i]
            
            # Quantize
            q = self._quantize_column(w)
            Q[:, i] = q
            
            # Update remaining weights
            err = (w - q) / d
            W[:, i:] -= err.unsqueeze(1) * H_inv[i, i:].unsqueeze(0)
        
        return {
            "quantized": Q,
            "bits": self.bits
        }
    
    def _quantize_column(self, w: torch.Tensor) -> torch.Tensor:
        """Quantize a single column."""
        
        max_val = w.abs().max()
        scale = max_val / (2 ** (self.bits - 1) - 1)
        
        quantized = torch.round(w / scale) * scale
        return quantized
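The core round trip (scale, round, clamp, rescale) is easier to inspect without tensors. The sketch below is a plain-Python illustration with hypothetical helper names and made-up weights. It uses a single scale for the whole list rather than per-group scales, and it checks that the reconstruction error stays within half a quantization step.

```python
def quantize_symmetric(values, bits=8):
    """Map floats to signed integers in [-qmax, qmax] with one shared scale."""
    qmax = 2 ** (bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale

def dequantize(q, scale):
    """Recover approximate floats from the integer codes."""
    return [qi * scale for qi in q]

weights = [0.02, -0.51, 0.33, 1.27, -0.08]  # illustrative values
for bits in (8, 4):
    q, scale = quantize_symmetric(weights, bits)
    restored = dequantize(q, scale)
    worst = max(abs(a - b) for a, b in zip(weights, restored))
    print(f"{bits}-bit codes {q}, max error {worst:.4f}, bound {scale / 2:.4f}")
```

At 4 bits the two smallest weights collapse to zero because the single large value sets the scale for everyone. Smaller groups and activation-aware scaling exist to limit exactly this effect.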

Speculative Decoding

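Speculative decoding uses a small, fast "draft" model to propose several tokens, then verifies them with the main model in a single forward pass. Because the main model checks every proposed token, output quality is preserved. When the draft model predicts correctly (which happens often for common patterns), several tokens are produced for the cost of one main-model pass, and speedups of 2-3x are commonly reported.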
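How much speedup to expect depends on how often draft tokens are accepted. The sketch below uses the standard geometric-series estimate, assuming every draft token is accepted independently with the same probability. The function names and the `draft_cost_ratio` parameter (cost of one draft step relative to one main-model step) are hypothetical, introduced only for this illustration.

```python
def expected_tokens_per_pass(acceptance_rate, num_draft_tokens):
    """Expected tokens emitted per main-model forward pass.

    Assumes each draft token is accepted independently with the same
    probability; every pass also yields one token from the main model.
    """
    a, k = acceptance_rate, num_draft_tokens
    if a >= 1.0:
        return float(k + 1)
    return (1 - a ** (k + 1)) / (1 - a)

def estimated_speedup(acceptance_rate, num_draft_tokens, draft_cost_ratio):
    """Rough speedup over plain decoding, ignoring all other overheads."""
    tokens = expected_tokens_per_pass(acceptance_rate, num_draft_tokens)
    cost = num_draft_tokens * draft_cost_ratio + 1  # in main-model steps
    return tokens / cost

for rate in (0.5, 0.7, 0.9):
    tokens = expected_tokens_per_pass(rate, 4)
    speedup = estimated_speedup(rate, 4, draft_cost_ratio=0.05)
    print(f"acceptance {rate:.1f}: {tokens:.2f} tokens/pass, ~{speedup:.2f}x")
```

The estimate shows why the gain is never "free": with a low acceptance rate most drafted tokens are discarded, and the drafting cost eats into the benefit.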

Speculative Decoding Flow

This diagram shows one draft-and-verify round: the draft model proposes K tokens, the main model scores them in a single forward pass, and generation continues from the first rejected position.

sequenceDiagram
    participant Input
    participant Draft as Draft Model<br/>(68M params)
    participant Main as Main Model<br/>(7B params)
    participant Output
    
    Input->>Draft: Prompt
    
    loop Generate K tokens
        Draft->>Draft: Fast decode
    end
    
    Draft->>Main: [tok1, tok2, tok3, tok4, tok5]
    
    Note over Main: Single forward pass<br/>verifies all tokens
    Main->>Main: Compute probabilities
    
    alt All tokens accepted
        Main-->>Output: [tok1, tok2, tok3, tok4, tok5] ✓
        Note over Output: 5 tokens from one main-model pass
    else Rejection at tok3
        Main-->>Output: [tok1, tok2, new_tok3]
        Note over Output: 3 tokens from one main-model pass
    end

Figure 2: Speculative Decoding Sequence

from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch

@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""
    
    num_speculative_tokens: int = 4
    temperature: float = 1.0

class SpeculativeDecoder:
    """Speculative decoding for faster generation."""
    
    def __init__(
        self,
        target_model: Any,
        draft_model: Any,
        config: SpeculativeConfig
    ):
        self.target = target_model
        self.draft = draft_model
        self.config = config
    
    def generate(
        self,
        input_ids: torch.Tensor,
        max_tokens: int = 100
    ) -> torch.Tensor:
        """Generate with speculative decoding."""
        
        generated = input_ids.clone()
        
        while generated.shape[1] < input_ids.shape[1] + max_tokens:
            # Draft phase: generate speculative tokens
            draft_tokens = self._draft_tokens(generated)
            
            # Verify phase: check with target model
            accepted, next_token = self._verify_tokens(generated, draft_tokens)
            
            # Append accepted tokens
            generated = torch.cat([generated, accepted, next_token], dim=1)
        
        return generated
    
    def _draft_tokens(self, context: torch.Tensor) -> torch.Tensor:
        """Generate draft tokens with small model."""
        
        draft_tokens = []
        current = context
        
        for _ in range(self.config.num_speculative_tokens):
            logits = self.draft(current)
            next_token = self._sample(logits[:, -1, :])
            draft_tokens.append(next_token)
            current = torch.cat([current, next_token.unsqueeze(1)], dim=1)
        
        return torch.stack(draft_tokens, dim=1)
    
    def _verify_tokens(
        self,
        context: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Verify draft tokens with target model."""
        
        # Get target model probabilities for all positions
        full_input = torch.cat([context, draft_tokens], dim=1)
        target_logits = self.target(full_input)
        
        # Get draft model probabilities
        draft_logits = self.draft(full_input)
        
        # Verify each token
        accepted = []
        
        for i in range(draft_tokens.shape[1]):
            pos = context.shape[1] + i
            
            # Get probabilities at the same temperature the draft was sampled
            # with; otherwise the acceptance test compares mismatched distributions
            temp = max(self.config.temperature, 1e-6)
            target_prob = torch.softmax(target_logits[:, pos - 1, :].float() / temp, dim=-1)
            draft_prob = torch.softmax(draft_logits[:, pos - 1, :].float() / temp, dim=-1)
            
            token = draft_tokens[:, i]
            
            # Acceptance probability
            p_target = target_prob.gather(1, token.unsqueeze(1))
            p_draft = draft_prob.gather(1, token.unsqueeze(1))
            
            accept_prob = torch.min(
                torch.ones_like(p_target),
                p_target / p_draft
            )
            
            # Accept or reject (this comparison assumes batch size 1)
            if torch.rand(1, device=accept_prob.device) < accept_prob:
                accepted.append(token)
            else:
                # Sample from adjusted distribution
                adjusted = torch.clamp(target_prob - draft_prob, min=0)
                adjusted = adjusted / adjusted.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(adjusted, 1).squeeze(1)
                
                # Keep the integer dtype of the context when nothing was accepted
                accepted_tokens = (
                    torch.stack(accepted, dim=1) if accepted
                    else context.new_empty((context.shape[0], 0))
                )
                return accepted_tokens, next_token.unsqueeze(1)
        
        # All accepted, sample next from target
        next_token = self._sample(target_logits[:, -1, :])
        
        return draft_tokens, next_token.unsqueeze(1)
    
    def _sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Sample from logits."""
        
        if self.config.temperature == 0:
            return logits.argmax(dim=-1)
        
        probs = torch.softmax(logits / self.config.temperature, dim=-1)
        return torch.multinomial(probs, 1).squeeze(1)

class MedusaDecoder:
    """Medusa-style parallel decoding."""
    
    def __init__(
        self,
        model: Any,
        num_heads: int = 4,
        num_candidates: int = 5
    ):
        self.model = model
        self.num_heads = num_heads
        self.num_candidates = num_candidates
        
        # Medusa heads predict future tokens
        self.medusa_heads = None  # Would be trained heads
    
    def generate(
        self,
        input_ids: torch.Tensor,
        max_tokens: int = 100
    ) -> torch.Tensor:
        """Generate with Medusa decoding."""
        
        generated = input_ids.clone()
        
        while generated.shape[1] < input_ids.shape[1] + max_tokens:
            # Get base model hidden states
            hidden = self.model.get_hidden_states(generated)
            
            # Get predictions from Medusa heads
            candidates = self._get_candidates(hidden)
            
            # Verify candidates with tree attention
            accepted = self._verify_tree(generated, candidates)
            
            generated = torch.cat([generated, accepted], dim=1)
        
        return generated
    
    def _get_candidates(self, hidden: torch.Tensor) -> list[torch.Tensor]:
        """Get candidate tokens from Medusa heads."""
        
        candidates = []
        
        for head in range(self.num_heads):
            # Each head predicts token at position +head
            logits = self.medusa_heads[head](hidden[:, -1, :])
            top_k = torch.topk(logits, self.num_candidates, dim=-1)
            candidates.append(top_k.indices)
        
        return candidates
    
    def _verify_tree(
        self,
        context: torch.Tensor,
        candidates: list[torch.Tensor]
    ) -> torch.Tensor:
        """Verify candidates using tree attention."""
        
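        # A full implementation would:
        #   1. Build a tree of candidate continuations from the per-head top-k tokens
        #   2. Score every path in a single forward pass using a tree attention mask
        #   3. Return the longest path the base model accepts
        
        # Placeholder: accept head 0's top-1 token without verification,
        # so each loop iteration appends exactly one token
        return candidates[0][:, 0:1]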
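
The `_verify_tree` placeholder above skips the acceptance step. As a minimal sketch of what that step could look like under greedy acceptance, the function below picks the longest candidate path whose tokens all match what the base model would have chosen, then appends one token from the base model itself. The names `longest_accepted_path` and `base_next_token` are hypothetical, and the callable stands in for the single tree-attention forward pass; a real implementation would read these predictions from batched logits rather than call the model per prefix.

```python
from typing import Callable, Sequence


def longest_accepted_path(
    paths: Sequence[Sequence[int]],
    base_next_token: Callable[[tuple], int],
) -> list[int]:
    """Return the longest verified candidate prefix plus one base-model token.

    paths: candidate continuations (token ids) proposed by the draft heads.
    base_next_token: hypothetical stand-in for the verification pass; maps the
        tokens accepted so far to the base model's greedy next token.
    """
    best: list[int] = []
    for path in paths:
        accepted: list[int] = []
        for token in path:
            # Stop at the first token the base model would not have produced
            if base_next_token(tuple(accepted)) != token:
                break
            accepted.append(token)
        if len(accepted) > len(best):
            best = accepted
    # The base model's own prediction after the accepted prefix is always valid,
    # so every step makes progress even if no candidate is accepted
    return best + [base_next_token(tuple(best))]


# Toy base model: a lookup table from accepted prefix to greedy next token
toy_base = {(): 5, (5,): 7, (5, 7): 2}


def toy_next_token(prefix: tuple) -> int:
    return toy_base.get(prefix, 0)


accepted = longest_accepted_path([[5, 9], [5, 7, 9], [4, 1]], toy_next_token)
print(accepted)  # [5, 7, 2]
```

Under greedy acceptance the output matches what the base model would have generated token by token; the gain comes from accepting several tokens per verification pass when the heads guess well.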

Production Inference Service

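The following FastAPI service is a minimal sketch of an inference endpoint with request and response schemas and basic metrics tracking. It uses a mock model and calls it directly; error handling, logging, configuration management, and the batcher are left out and would be needed in production.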

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Any
import asyncio
import time

app = FastAPI()

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 1.0
    stream: bool = False

class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    latency_ms: float
    tokens_per_second: float

class InferenceMetrics:
    """Track inference metrics."""
    
    def __init__(self):
        self.total_requests = 0
        self.total_tokens = 0
        self.total_latency = 0.0
        self.batch_sizes = []
    
    def record(
        self,
        tokens: int,
        latency: float,
        batch_size: int = 1
    ):
        self.total_requests += 1
        self.total_tokens += tokens
        self.total_latency += latency
        self.batch_sizes.append(batch_size)
    
    def summary(self) -> dict:
        return {
            "total_requests": self.total_requests,
            "total_tokens": self.total_tokens,
            "avg_latency_ms": self.total_latency / max(self.total_requests, 1) * 1000,
            "avg_tokens_per_second": self.total_tokens / max(self.total_latency, 0.001),
            "avg_batch_size": sum(self.batch_sizes) / max(len(self.batch_sizes), 1)
        }

metrics = InferenceMetrics()

# Mock model for demo
class MockModel:
    async def generate(self, prompts: list[str], max_tokens: int = 100) -> list[str]:
        await asyncio.sleep(0.1)  # Simulate inference
        return [f"Generated response for: {p[:50]}..." for p in prompts]

model = MockModel()

# Batcher
batcher = None  # Would be DynamicBatcher instance

@app.post("/v1/generate")
async def generate(request: GenerateRequest) -> GenerateResponse:
    """Generate text."""
    
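    # perf_counter is monotonic, which makes it safer than time.time() for intervals
    start_time = time.perf_counter()
    
    # Direct inference (would use batcher in production)
    results = await model.generate([request.prompt], request.max_tokens)
    
    latency = time.perf_counter() - start_time
    # Whitespace split is only a rough proxy; use the tokenizer's count with a real model
    tokens = len(results[0].split())
    
    metrics.record(tokens, latency)
    
    return GenerateResponse(
        text=results[0],
        tokens_generated=tokens,
        latency_ms=latency * 1000,
        # Guard against a zero interval on very fast calls
        tokens_per_second=tokens / max(latency, 1e-6)
    )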

@app.get("/v1/metrics")
async def get_metrics() -> dict:
    """Get inference metrics."""
    return metrics.summary()

@app.get("/health")
async def health():
    return {"status": "healthy"}
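
The endpoint above calls the model directly and leaves `batcher` as a placeholder. As a rough illustration of how concurrent requests could be funneled into shared model calls, here is a minimal asyncio sketch. `MicroBatcher`, `fake_generate`, and the `max_batch_size` and `max_wait_s` parameters are hypothetical names for this example, not the `DynamicBatcher` referenced in the service code.

```python
import asyncio
from typing import Awaitable, Callable


class MicroBatcher:
    """Hypothetical batcher: groups concurrent prompts into one model call."""

    def __init__(
        self,
        generate_fn: Callable[[list[str]], Awaitable[list[str]]],
        max_batch_size: int = 8,
        max_wait_s: float = 0.01,
    ):
        self.generate_fn = generate_fn
        self.max_batch_size = max_batch_size
        self.max_wait_s = max_wait_s
        self.queue: asyncio.Queue = asyncio.Queue()
        self.batch_sizes: list[int] = []

    async def submit(self, prompt: str) -> str:
        """Enqueue one prompt and wait for its result."""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((prompt, future))
        return await future

    async def run(self) -> None:
        """Worker loop: collect a batch, call the model once, fan results out."""
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self.queue.get()]
            deadline = loop.time() + self.max_wait_s
            # Keep filling until the batch is full or the wait budget is spent
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            try:
                outputs = await self.generate_fn([prompt for prompt, _ in batch])
                for (_, future), output in zip(batch, outputs):
                    if not future.done():
                        future.set_result(output)
            except Exception as exc:
                # Propagate a model failure to every caller in the batch
                for _, future in batch:
                    if not future.done():
                        future.set_exception(exc)


async def fake_generate(prompts: list[str]) -> list[str]:
    await asyncio.sleep(0.01)  # Simulate one batched forward pass
    return [prompt.upper() for prompt in prompts]


async def demo() -> tuple[list[str], list[int]]:
    batcher = MicroBatcher(fake_generate, max_batch_size=4, max_wait_s=0.05)
    worker = asyncio.create_task(batcher.run())
    results = await asyncio.gather(*(batcher.submit(f"prompt {i}") for i in range(4)))
    worker.cancel()
    return list(results), batcher.batch_sizes


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The `max_wait_s` budget is the main trade-off in this sketch: a longer wait yields larger batches and better GPU utilization, at the cost of added latency for the first request in each batch.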

Conclusion

LLM inference optimization is about understanding where time and memory go and eliminating waste. The KV cache is the foundation: without it, every new token requires recomputing keys and values for the entire preceding sequence. Paged attention takes this further by largely eliminating memory fragmentation and enabling efficient memory sharing across requests.

Dynamic batching raises GPU utilization by grouping requests together, while continuous batching admits new requests as soon as others finish, so the GPU is not held up by the slowest request in a batch. Quantization reduces memory footprint and can also speed up inference on modern hardware with efficient INT8/INT4 operations. Speculative decoding breaks the sequential bottleneck of autoregressive generation by cheaply drafting several tokens ahead and verifying them in a single forward pass of the target model.

In production, combine these techniques: use a quantized model with paged attention, continuous batching, and speculative decoding for maximum throughput. Monitor your metrics: tokens per second, time to first token, and GPU utilization show where bottlenecks remain. The key insight is that inference optimization is not about any single technique but about building a system where every component works together efficiently.
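
Two of the metrics named above, time to first token and tokens per second, can be measured from any streamed response. The sketch below assumes the stream is exposed as a plain Python iterable of tokens; `measure_stream` and its injectable `clock` parameter are hypothetical names chosen for this example.

```python
import time
from typing import Callable, Iterable, Optional


def measure_stream(
    tokens: Iterable[str],
    clock: Callable[[], float] = time.perf_counter,
) -> dict:
    """Consume a token stream and report latency and throughput figures."""
    start = clock()
    first: Optional[float] = None
    count = 0
    for _ in tokens:
        now = clock()
        if first is None:
            first = now  # Arrival time of the first token
        count += 1
    end = clock()

    total = end - start
    ttft = (first - start) if first is not None else None
    # Decode rate excludes the first token, whose latency is dominated by prefill
    decode_time = (end - first) if first is not None else 0.0
    return {
        "tokens": count,
        "time_to_first_token_s": ttft,
        "total_time_s": total,
        "tokens_per_second": count / total if total > 0 else 0.0,
        "decode_tokens_per_second": (count - 1) / decode_time if decode_time > 0 else 0.0,
    }


if __name__ == "__main__":
    def slow_stream():
        time.sleep(0.05)  # Simulated prefill
        for word in ["hello", "from", "the", "model"]:
            time.sleep(0.01)  # Simulated per-token decode step
            yield word

    print(measure_stream(slow_stream()))
```

Reporting the decode rate separately from time to first token helps distinguish a slow prefill (long prompts, cold prefix cache) from a slow decode loop (memory bandwidth, small batches).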

Key Takeaways

  • Cache aggressively - Semantic caching provides 50%+ cost savings for many workloads
  • Stream responses - Improves perceived latency from seconds to milliseconds
  • Route intelligently - Match query complexity to model capability and cost
  • Batch when possible - Non-real-time workloads benefit greatly from batching
  • Measure everything - Optimization without metrics is just guessing

Series Conclusion

LLM inference optimization spans application-level techniques (caching, routing, batching) and model-level approaches (KV cache, quantization, speculative decoding). The best production systems combine multiple techniques, measuring impact at each step to ensure optimizations deliver real value.
