
LLM Inference Optimization: From KV Cache to Speculative Decoding

Introduction: LLM inference optimization is the art of making models respond faster while using fewer resources. As LLMs grow larger and usage scales, the difference between naive and optimized inference can mean 10x cost reduction and sub-second latencies instead of multi-second waits. This guide covers the techniques that matter most: KV cache optimization to avoid redundant computation, dynamic batching to maximize GPU utilization, quantization to reduce memory footprint without sacrificing quality, speculative decoding to accelerate autoregressive generation, and continuous batching for high-throughput serving. Whether you’re serving a 7B model on a single GPU or a 70B model across a cluster, these patterns will help you achieve production-grade performance.

Inference Optimization: Dynamic Batching, KV Cache, Quantization

KV Cache Optimization

from dataclasses import dataclass, field
from typing import Any, Optional, Tuple
import torch
import numpy as np

@dataclass
class KVCacheConfig:
    """Configuration for KV cache."""
    
    num_layers: int
    num_heads: int
    head_dim: int
    max_seq_len: int
    dtype: torch.dtype = torch.float16

class KVCache:
    """Key-Value cache for transformer inference."""
    
    def __init__(self, config: KVCacheConfig, batch_size: int = 1):
        self.config = config
        self.batch_size = batch_size
        
        # Pre-allocate cache tensors
        cache_shape = (
            config.num_layers,
            2,  # key and value
            batch_size,
            config.num_heads,
            config.max_seq_len,
            config.head_dim
        )
        
        self.cache = torch.zeros(cache_shape, dtype=config.dtype)
        self.seq_len = 0
    
    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache and return full key/value."""
        
        seq_len = key.shape[2]
        
        # Store new key/value
        self.cache[layer_idx, 0, :, :, self.seq_len:self.seq_len + seq_len, :] = key
        self.cache[layer_idx, 1, :, :, self.seq_len:self.seq_len + seq_len, :] = value
        
        # Return full cached key/value
        full_key = self.cache[layer_idx, 0, :, :, :self.seq_len + seq_len, :]
        full_value = self.cache[layer_idx, 1, :, :, :self.seq_len + seq_len, :]
        
        return full_key, full_value
    
    def advance(self, num_tokens: int = 1):
        """Advance sequence position."""
        self.seq_len += num_tokens
    
    def reset(self):
        """Reset cache."""
        self.cache.zero_()
        self.seq_len = 0
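
# Example: a minimal decode-loop sketch for KVCache. The shapes are illustrative;
# in a real model the per-token key/value tensors come from each attention layer's
# K/V projections and have shape (batch, num_heads, 1, head_dim).
def _kv_cache_demo():
    cfg = KVCacheConfig(num_layers=2, num_heads=4, head_dim=8, max_seq_len=64)
    cache = KVCache(cfg, batch_size=1)
    
    for _ in range(3):  # three decode steps
        k = torch.randn(1, cfg.num_heads, 1, cfg.head_dim, dtype=cfg.dtype)
        v = torch.randn(1, cfg.num_heads, 1, cfg.head_dim, dtype=cfg.dtype)
        for layer in range(cfg.num_layers):
            full_k, full_v = cache.update(layer, k, v)
        cache.advance(1)  # advance once per token, after every layer has written
    
    assert full_k.shape == (1, cfg.num_heads, 3, cfg.head_dim)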

class PagedKVCache:
    """Paged attention KV cache for efficient memory."""
    
    def __init__(
        self,
        config: KVCacheConfig,
        page_size: int = 16,
        num_pages: int = 1024
    ):
        self.config = config
        self.page_size = page_size
        self.num_pages = num_pages
        
        # Page pool
        page_shape = (
            num_pages,
            2,  # key and value
            config.num_layers,
            config.num_heads,
            page_size,
            config.head_dim
        )
        
        self.page_pool = torch.zeros(page_shape, dtype=config.dtype)
        self.free_pages = list(range(num_pages))
        
        # Sequence to page mapping
        self.seq_pages: dict[int, list[int]] = {}
    
    def allocate_sequence(self, seq_id: int) -> bool:
        """Allocate first page for sequence."""
        
        if not self.free_pages:
            return False
        
        page_idx = self.free_pages.pop()
        self.seq_pages[seq_id] = [page_idx]
        return True
    
    def extend_sequence(self, seq_id: int) -> bool:
        """Allocate additional page for sequence."""
        
        if not self.free_pages:
            return False
        
        page_idx = self.free_pages.pop()
        self.seq_pages[seq_id].append(page_idx)
        return True
    
    def free_sequence(self, seq_id: int):
        """Free all pages for sequence."""
        
        if seq_id in self.seq_pages:
            self.free_pages.extend(self.seq_pages[seq_id])
            del self.seq_pages[seq_id]
    
    def get_cache(
        self,
        seq_id: int,
        layer_idx: int
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Get cached key/value for sequence."""
        
        pages = self.seq_pages.get(seq_id, [])
        
        if not pages:
            return None, None
        
        # Gather from pages
        keys = []
        values = []
        
        for page_idx in pages:
            keys.append(self.page_pool[page_idx, 0, layer_idx])
            values.append(self.page_pool[page_idx, 1, layer_idx])
        
        return torch.cat(keys, dim=1), torch.cat(values, dim=1)
    
    def update_cache(
        self,
        seq_id: int,
        layer_idx: int,
        position: int,
        key: torch.Tensor,
        value: torch.Tensor
    ):
        """Update cache at position."""
        
        page_idx = position // self.page_size
        offset = position % self.page_size
        
        # Extend if needed
        while len(self.seq_pages[seq_id]) <= page_idx:
            if not self.extend_sequence(seq_id):
                raise RuntimeError("Out of cache pages")
        
        actual_page = self.seq_pages[seq_id][page_idx]
        
        self.page_pool[actual_page, 0, layer_idx, :, offset, :] = key
        self.page_pool[actual_page, 1, layer_idx, :, offset, :] = value

class SlidingWindowCache:
    """Sliding window attention cache."""
    
    def __init__(
        self,
        config: KVCacheConfig,
        window_size: int = 4096
    ):
        self.config = config
        self.window_size = window_size
        
        # Circular buffer
        cache_shape = (
            config.num_layers,
            2,
            config.num_heads,
            window_size,
            config.head_dim
        )
        
        self.cache = torch.zeros(cache_shape, dtype=config.dtype)
        self.position = 0
    
    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update with circular buffer."""
        
        idx = self.position % self.window_size
        
        self.cache[layer_idx, 0, :, idx, :] = key.squeeze(2)
        self.cache[layer_idx, 1, :, idx, :] = value.squeeze(2)
        
        # Return window
        if self.position < self.window_size:
            return (
                self.cache[layer_idx, 0, :, :self.position + 1, :],
                self.cache[layer_idx, 1, :, :self.position + 1, :]
            )
        
        # Reorder for correct sequence
        start = (self.position + 1) % self.window_size
        indices = [(start + i) % self.window_size for i in range(self.window_size)]
        
        return (
            self.cache[layer_idx, 0, :, indices, :],
            self.cache[layer_idx, 1, :, indices, :]
        )
    
    def advance(self):
        """Advance position."""
        self.position += 1
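
To put these caches in concrete terms, the sketch below estimates the KV cache footprint from the same parameters KVCacheConfig uses (the 7B-scale shape is an illustrative assumption, roughly matching a Llama-2-7B-style model), then walks PagedKVCache through claiming pages from the shared pool and returning them when a sequence finishes.

def kv_cache_bytes(num_layers: int, num_heads: int, head_dim: int,
                   seq_len: int, batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """KV cache size: 2 (K and V) x layers x heads x head_dim x seq_len x batch."""
    return 2 * num_layers * num_heads * head_dim * seq_len * batch_size * bytes_per_elem

# 32 layers, 32 heads, head_dim 128, fp16: ~0.5 MB per token,
# so a 4096-token context costs roughly 2 GB per sequence.
print(kv_cache_bytes(32, 32, 128, 4096) / 1e9)  # ~2.1 GB

# PagedKVCache accounting: pages come from a shared pool and go back on free.
paged = PagedKVCache(
    KVCacheConfig(num_layers=2, num_heads=4, head_dim=8, max_seq_len=256),
    page_size=16,
    num_pages=8
)
paged.allocate_sequence(seq_id=0)
for pos in range(40):  # 40 tokens span 3 pages of 16
    kv = torch.randn(4, 8, dtype=torch.float16)
    paged.update_cache(seq_id=0, layer_idx=0, position=pos, key=kv, value=kv)
print(len(paged.seq_pages[0]), len(paged.free_pages))  # 3 pages in use, 5 free
paged.free_sequence(0)
print(len(paged.free_pages))  # all 8 pages returned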

Dynamic Batching

from dataclasses import dataclass, field
from typing import Any, Optional, Callable
from collections import deque
import asyncio
import time

@dataclass
class InferenceRequest:
    """A single inference request."""
    
    id: str
    prompt: str
    max_tokens: int = 100
    temperature: float = 1.0
    created_at: float = field(default_factory=time.time)
    future: Optional[asyncio.Future] = None

@dataclass
class BatchConfig:
    """Configuration for dynamic batching."""
    
    max_batch_size: int = 32
    max_wait_time_ms: float = 50.0
    max_tokens_per_batch: int = 4096

class DynamicBatcher:
    """Dynamic batching for inference requests."""
    
    def __init__(
        self,
        config: BatchConfig,
        inference_fn: Callable
    ):
        self.config = config
        self.inference_fn = inference_fn
        self.queue: deque[InferenceRequest] = deque()
        self.running = False
    
    async def submit(self, request: InferenceRequest) -> str:
        """Submit request and wait for result."""
        
        request.future = asyncio.Future()
        self.queue.append(request)
        
        return await request.future
    
    async def start(self):
        """Start batch processing loop."""
        
        self.running = True
        
        while self.running:
            batch = await self._collect_batch()
            
            if batch:
                await self._process_batch(batch)
            else:
                await asyncio.sleep(0.001)
    
    async def _collect_batch(self) -> list[InferenceRequest]:
        """Collect requests into a batch, waiting up to max_wait_time_ms to fill it."""
        
        if not self.queue:
            return []
        
        batch = []
        total_tokens = 0
        deadline = time.time() + self.config.max_wait_time_ms / 1000
        
        while len(batch) < self.config.max_batch_size:
            if not self.queue:
                # No more queued requests: wait for stragglers until the deadline
                if time.time() >= deadline:
                    break
                await asyncio.sleep(0.001)
                continue
            
            request = self.queue[0]
            request_tokens = len(request.prompt.split()) + request.max_tokens
            
            # Respect the token budget, but never strand an oversized single request
            if batch and total_tokens + request_tokens > self.config.max_tokens_per_batch:
                break
            
            batch.append(self.queue.popleft())
            total_tokens += request_tokens
        
        return batch
    
    async def _process_batch(self, batch: list[InferenceRequest]):
        """Process a batch of requests."""
        
        try:
            # Prepare batch inputs
            prompts = [r.prompt for r in batch]
            
            # Run inference
            results = await self.inference_fn(prompts)
            
            # Distribute results
            for request, result in zip(batch, results):
                request.future.set_result(result)
        
        except Exception as e:
            for request in batch:
                request.future.set_exception(e)
    
    def stop(self):
        """Stop batch processing."""
        self.running = False
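
# Example: wiring DynamicBatcher to a batched inference function. The async
# inference function here is a stand-in; in practice it would run the real
# batched forward pass on the GPU.
async def _dynamic_batching_demo():
    async def fake_inference(prompts: list[str]) -> list[str]:
        await asyncio.sleep(0.05)  # simulate one batched forward pass
        return [f"echo: {p}" for p in prompts]
    
    batcher = DynamicBatcher(BatchConfig(max_batch_size=8, max_wait_time_ms=20), fake_inference)
    loop_task = asyncio.create_task(batcher.start())  # background batching loop
    
    results = await asyncio.gather(*[
        batcher.submit(InferenceRequest(id=str(i), prompt=f"prompt {i}", max_tokens=16))
        for i in range(4)
    ])
    
    batcher.stop()
    loop_task.cancel()
    try:
        await loop_task
    except asyncio.CancelledError:
        pass
    return results

# asyncio.run(_dynamic_batching_demo())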

class ContinuousBatcher:
    """Continuous batching for streaming generation."""
    
    def __init__(
        self,
        max_batch_size: int = 32,
        inference_fn: Callable = None
    ):
        self.max_batch_size = max_batch_size
        self.inference_fn = inference_fn
        
        # Active sequences
        self.active: dict[str, dict] = {}
        self.pending: deque[InferenceRequest] = deque()
    
    async def submit(self, request: InferenceRequest):
        """Submit request for continuous batching."""
        
        request.future = asyncio.Future()
        self.pending.append(request)
        return request.future
    
    async def step(self):
        """Execute one generation step."""
        
        # Add pending requests if space available
        while self.pending and len(self.active) < self.max_batch_size:
            request = self.pending.popleft()
            self.active[request.id] = {
                "request": request,
                "tokens": [],
                "position": 0
            }
        
        if not self.active:
            return
        
        # Prepare batch
        batch_ids = list(self.active.keys())
        
        # Run one step of generation
        next_tokens = await self._generate_step(batch_ids)
        
        # Update sequences
        completed = []
        
        for seq_id, token in zip(batch_ids, next_tokens):
            seq = self.active[seq_id]
            seq["tokens"].append(token)
            seq["position"] += 1
            
            # Check completion
            if self._is_complete(seq):
                completed.append(seq_id)
        
        # Complete finished sequences
        for seq_id in completed:
            seq = self.active.pop(seq_id)
            result = self._decode_tokens(seq["tokens"])
            seq["request"].future.set_result(result)
    
    async def _generate_step(self, batch_ids: list[str]) -> list[int]:
        """Generate next token for each sequence."""
        
        # Would call actual model inference
        return [0] * len(batch_ids)  # Placeholder
    
    def _is_complete(self, seq: dict) -> bool:
        """Check if sequence is complete."""
        
        request = seq["request"]
        
        # Check max tokens
        if len(seq["tokens"]) >= request.max_tokens:
            return True
        
        # Check EOS token
        if seq["tokens"] and seq["tokens"][-1] == 2:  # EOS
            return True
        
        return False
    
    def _decode_tokens(self, tokens: list[int]) -> str:
        """Decode tokens to text."""
        return " ".join(str(t) for t in tokens)  # Placeholder

class PriorityBatcher:
    """Priority-based batching."""
    
    def __init__(self, max_batch_size: int = 32):
        self.max_batch_size = max_batch_size
        self.queues: dict[int, deque] = {
            0: deque(),  # High priority
            1: deque(),  # Normal priority
            2: deque()   # Low priority
        }
    
    def submit(self, request: InferenceRequest, priority: int = 1):
        """Submit with priority."""
        
        request.future = asyncio.Future()
        self.queues[priority].append(request)
        return request.future
    
    def collect_batch(self) -> list[InferenceRequest]:
        """Collect batch respecting priorities."""
        
        batch = []
        
        # Collect from high to low priority
        for priority in sorted(self.queues.keys()):
            queue = self.queues[priority]
            
            while queue and len(batch) < self.max_batch_size:
                batch.append(queue.popleft())
        
        return batch
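
A quick sketch of how the priority queues drain; the request IDs below are made up for illustration. High-priority requests are always pulled first, so under sustained load the lower tiers can starve unless you add aging or per-tier quotas.

async def _priority_demo():
    pb = PriorityBatcher(max_batch_size=4)
    pb.submit(InferenceRequest(id="batch-job", prompt="summarize logs"), priority=2)
    pb.submit(InferenceRequest(id="chat-user", prompt="hello"), priority=0)
    pb.submit(InferenceRequest(id="api-call", prompt="translate"), priority=1)
    return [r.id for r in pb.collect_batch()]

# asyncio.run(_priority_demo())  # -> ['chat-user', 'api-call', 'batch-job']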

Quantization Techniques

from dataclasses import dataclass
from typing import Any, Optional
import torch
import numpy as np

@dataclass
class QuantizationConfig:
    """Configuration for quantization."""
    
    bits: int = 4
    group_size: int = 128
    symmetric: bool = False
    compute_dtype: torch.dtype = torch.float16

class WeightQuantizer:
    """Quantize model weights."""
    
    def __init__(self, config: QuantizationConfig):
        self.config = config
    
    def quantize(self, weight: torch.Tensor) -> dict:
        """Quantize weight tensor."""
        
        if self.config.symmetric:
            return self._quantize_symmetric(weight)
        else:
            return self._quantize_asymmetric(weight)
    
    def _quantize_symmetric(self, weight: torch.Tensor) -> dict:
        """Symmetric quantization."""
        
        # Reshape for group quantization (assumes weight.numel() is divisible by group_size)
        original_shape = weight.shape
        weight = weight.reshape(-1, self.config.group_size)
        
        # Calculate per-group scale, clamped to avoid division by zero for all-zero groups
        max_val = weight.abs().max(dim=1, keepdim=True)[0]
        scale = (max_val / (2 ** (self.config.bits - 1) - 1)).clamp(min=1e-8)
        
        # Quantize
        quantized = torch.round(weight / scale).to(torch.int8)
        
        return {
            "quantized": quantized,
            "scale": scale,
            "zero_point": None,
            "original_shape": original_shape
        }
    
    def _quantize_asymmetric(self, weight: torch.Tensor) -> dict:
        """Asymmetric quantization."""
        
        original_shape = weight.shape
        weight = weight.reshape(-1, self.config.group_size)
        
        # Calculate scale and zero point, with scale clamped to avoid division by zero
        min_val = weight.min(dim=1, keepdim=True)[0]
        max_val = weight.max(dim=1, keepdim=True)[0]
        
        scale = ((max_val - min_val) / (2 ** self.config.bits - 1)).clamp(min=1e-8)
        zero_point = torch.round(-min_val / scale)
        
        # Quantize
        quantized = torch.round(weight / scale + zero_point).to(torch.uint8)
        
        return {
            "quantized": quantized,
            "scale": scale,
            "zero_point": zero_point,
            "original_shape": original_shape
        }
    
    def dequantize(self, quantized_data: dict) -> torch.Tensor:
        """Dequantize to original precision."""
        
        quantized = quantized_data["quantized"].to(self.config.compute_dtype)
        scale = quantized_data["scale"]
        zero_point = quantized_data["zero_point"]
        original_shape = quantized_data["original_shape"]
        
        if zero_point is not None:
            dequantized = (quantized - zero_point) * scale
        else:
            dequantized = quantized * scale
        
        return dequantized.reshape(original_shape)
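
# Example: a 4-bit group-quantization round trip. The relative error shrinks with
# more bits or smaller groups; the tensor here is random and purely illustrative.
def _quantization_demo():
    torch.manual_seed(0)
    weight = torch.randn(256, 256)  # numel must be divisible by group_size
    quantizer = WeightQuantizer(QuantizationConfig(bits=4, group_size=128))
    
    packed = quantizer.quantize(weight)
    restored = quantizer.dequantize(packed).float()
    
    rel_error = (weight - restored).norm() / weight.norm()
    print(f"relative reconstruction error: {rel_error:.4f}")
    print(f"stored dtype: {packed['quantized'].dtype}")  # torch.uint8 (asymmetric path)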

class AWQQuantizer:
    """Activation-aware Weight Quantization."""
    
    def __init__(self, bits: int = 4, group_size: int = 128):
        self.bits = bits
        self.group_size = group_size
    
    def quantize_layer(
        self,
        weight: torch.Tensor,
        activations: torch.Tensor
    ) -> dict:
        """Quantize with activation awareness."""
        
        # Calculate activation scales
        act_scales = activations.abs().mean(dim=0)
        
        # Find optimal scaling factors
        best_scale = self._search_scale(weight, act_scales)
        
        # Apply scaling and quantize
        scaled_weight = weight * best_scale
        
        # Standard quantization
        quantizer = WeightQuantizer(QuantizationConfig(
            bits=self.bits,
            group_size=self.group_size
        ))
        
        quantized = quantizer.quantize(scaled_weight)
        quantized["act_scale"] = best_scale
        
        return quantized
    
    def _search_scale(
        self,
        weight: torch.Tensor,
        act_scales: torch.Tensor,
        n_grid: int = 20
    ) -> torch.Tensor:
        """Search for optimal scaling."""
        
        best_error = float('inf')
        best_scale = torch.ones_like(act_scales)
        
        for ratio in np.linspace(0, 1, n_grid):
            scale = act_scales.pow(ratio)
            
            # Simulate quantization error (a simplified proxy; real AWQ rounds to the target bit-width)
            scaled = weight * scale
            quantized = torch.round(scaled)
            dequantized = quantized / scale
            
            error = (weight - dequantized).pow(2).mean()
            
            if error < best_error:
                best_error = error
                best_scale = scale
        
        return best_scale

class GPTQQuantizer:
    """GPTQ quantization."""
    
    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        damp_percent: float = 0.01
    ):
        self.bits = bits
        self.group_size = group_size
        self.damp_percent = damp_percent
    
    def quantize_layer(
        self,
        weight: torch.Tensor,
        hessian: torch.Tensor
    ) -> dict:
        """Quantize using GPTQ algorithm."""
        
        W = weight.clone()
        H = hessian.clone()
        
        # Add damping
        damp = self.damp_percent * torch.diag(H).mean()
        H += damp * torch.eye(H.shape[0], device=H.device)
        
        # Invert the damped Hessian (simplified: the reference GPTQ implementation
        # works with the Cholesky factor of H^-1 rather than H^-1 itself)
        H_inv = torch.linalg.cholesky(H)
        H_inv = torch.cholesky_inverse(H_inv)
        
        # Quantize column by column
        Q = torch.zeros_like(W)
        
        for i in range(W.shape[1]):
            w = W[:, i]
            d = H_inv[i, i]
            
            # Quantize
            q = self._quantize_column(w)
            Q[:, i] = q
            
            # Update remaining weights
            err = (w - q) / d
            W[:, i:] -= err.unsqueeze(1) * H_inv[i, i:].unsqueeze(0)
        
        return {
            "quantized": Q,
            "bits": self.bits
        }
    
    def _quantize_column(self, w: torch.Tensor) -> torch.Tensor:
        """Quantize a single column."""
        
        max_val = w.abs().max()
        scale = max_val / (2 ** (self.bits - 1) - 1)
        
        quantized = torch.round(w / scale) * scale
        return quantized
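
The payoff is easiest to see in raw memory terms. A rough back-of-the-envelope estimate of weight storage, ignoring the small overhead of per-group scales and zero points:

def model_weight_bytes(num_params: float, bits: int) -> float:
    """Approximate weight memory, ignoring scale/zero-point overhead."""
    return num_params * bits / 8

params_7b = 7e9
print(model_weight_bytes(params_7b, 16) / 1e9)  # fp16: ~14 GB
print(model_weight_bytes(params_7b, 4) / 1e9)   # 4-bit: ~3.5 GB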

Speculative Decoding

from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch

@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""
    
    num_speculative_tokens: int = 4
    temperature: float = 1.0

class SpeculativeDecoder:
    """Speculative decoding for faster generation."""
    
    def __init__(
        self,
        target_model: Any,
        draft_model: Any,
        config: SpeculativeConfig
    ):
        self.target = target_model
        self.draft = draft_model
        self.config = config
    
    def generate(
        self,
        input_ids: torch.Tensor,
        max_tokens: int = 100
    ) -> torch.Tensor:
        """Generate with speculative decoding."""
        
        generated = input_ids.clone()
        
        while generated.shape[1] < input_ids.shape[1] + max_tokens:
            # Draft phase: generate speculative tokens
            draft_tokens = self._draft_tokens(generated)
            
            # Verify phase: check with target model
            accepted, next_token = self._verify_tokens(generated, draft_tokens)
            
            # Append accepted tokens
            generated = torch.cat([generated, accepted, next_token], dim=1)
        
        return generated
    
    def _draft_tokens(self, context: torch.Tensor) -> torch.Tensor:
        """Generate draft tokens with small model."""
        
        draft_tokens = []
        current = context
        
        for _ in range(self.config.num_speculative_tokens):
            logits = self.draft(current)
            next_token = self._sample(logits[:, -1, :])
            draft_tokens.append(next_token)
            current = torch.cat([current, next_token.unsqueeze(1)], dim=1)
        
        return torch.stack(draft_tokens, dim=1)
    
    def _verify_tokens(
        self,
        context: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Verify draft tokens with target model."""
        
        # Get target model probabilities for all positions
        full_input = torch.cat([context, draft_tokens], dim=1)
        target_logits = self.target(full_input)
        
        # Get draft model probabilities
        draft_logits = self.draft(full_input)
        
        # Verify each token
        accepted = []
        
        for i in range(draft_tokens.shape[1]):
            pos = context.shape[1] + i
            
            # Get probabilities
            target_prob = torch.softmax(target_logits[:, pos - 1, :], dim=-1)
            draft_prob = torch.softmax(draft_logits[:, pos - 1, :], dim=-1)
            
            token = draft_tokens[:, i]
            
            # Acceptance probability
            p_target = target_prob.gather(1, token.unsqueeze(1))
            p_draft = draft_prob.gather(1, token.unsqueeze(1))
            
            accept_prob = torch.min(
                torch.ones_like(p_target),
                p_target / p_draft
            )
            
            # Accept or reject (this simplified check assumes batch size 1)
            if torch.rand(1) < accept_prob:
                accepted.append(token)
            else:
                # On rejection, resample from the adjusted target distribution
                adjusted = torch.clamp(target_prob - draft_prob, min=0)
                adjusted = adjusted / adjusted.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(adjusted, 1).squeeze(1)
                
                return (
                    torch.stack(accepted, dim=1) if accepted
                    else torch.empty(context.shape[0], 0, dtype=context.dtype, device=context.device),
                    next_token.unsqueeze(1)
                )
        
        # All accepted, sample next from target
        next_token = self._sample(target_logits[:, -1, :])
        
        return draft_tokens, next_token.unsqueeze(1)
    
    def _sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Sample from logits."""
        
        if self.config.temperature == 0:
            return logits.argmax(dim=-1)
        
        probs = torch.softmax(logits / self.config.temperature, dim=-1)
        return torch.multinomial(probs, 1).squeeze(1)
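
# Example: exercising SpeculativeDecoder with a tiny deterministic stand-in model.
# Using the same model as both target and draft means p_target/p_draft is exactly 1,
# so every draft token is accepted; a real deployment pairs a large target model
# with a much smaller (and therefore imperfect) draft model.
class _TinyLM:
    def __init__(self, vocab_size: int = 32, seed: int = 0):
        g = torch.Generator().manual_seed(seed)
        self.table = torch.randn(vocab_size, vocab_size, generator=g)
    
    def __call__(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Logits at each position depend only on the token at that position.
        return self.table[input_ids]

def _speculative_demo():
    lm = _TinyLM()
    decoder = SpeculativeDecoder(
        target_model=lm,
        draft_model=lm,
        config=SpeculativeConfig(num_speculative_tokens=4)
    )
    prompt = torch.tensor([[1, 5, 9]])
    out = decoder.generate(prompt, max_tokens=8)
    print(out.shape)  # prompt plus at least 8 newly generated tokens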

class MedusaDecoder:
    """Medusa-style parallel decoding."""
    
    def __init__(
        self,
        model: Any,
        num_heads: int = 4,
        num_candidates: int = 5
    ):
        self.model = model
        self.num_heads = num_heads
        self.num_candidates = num_candidates
        
        # Medusa heads predict future tokens
        self.medusa_heads = None  # Would be trained heads
    
    def generate(
        self,
        input_ids: torch.Tensor,
        max_tokens: int = 100
    ) -> torch.Tensor:
        """Generate with Medusa decoding."""
        
        generated = input_ids.clone()
        
        while generated.shape[1] < input_ids.shape[1] + max_tokens:
            # Get base model hidden states
            hidden = self.model.get_hidden_states(generated)
            
            # Get predictions from Medusa heads
            candidates = self._get_candidates(hidden)
            
            # Verify candidates with tree attention
            accepted = self._verify_tree(generated, candidates)
            
            generated = torch.cat([generated, accepted], dim=1)
        
        return generated
    
    def _get_candidates(self, hidden: torch.Tensor) -> list[torch.Tensor]:
        """Get candidate tokens from Medusa heads."""
        
        candidates = []
        
        for head in range(self.num_heads):
            # Each head predicts token at position +head
            logits = self.medusa_heads[head](hidden[:, -1, :])
            top_k = torch.topk(logits, self.num_candidates, dim=-1)
            candidates.append(top_k.indices)
        
        return candidates
    
    def _verify_tree(
        self,
        context: torch.Tensor,
        candidates: list[torch.Tensor]
    ) -> torch.Tensor:
        """Verify candidates using tree attention."""
        
        # Build candidate tree
        # Verify with single forward pass
        # Return longest accepted path
        
        # Simplified: just return first candidate
        return candidates[0][:, 0:1]

Production Inference Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Any
import asyncio
import time

app = FastAPI()

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 1.0
    stream: bool = False

class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    latency_ms: float
    tokens_per_second: float

class InferenceMetrics:
    """Track inference metrics."""
    
    def __init__(self):
        self.total_requests = 0
        self.total_tokens = 0
        self.total_latency = 0.0
        self.batch_sizes = []
    
    def record(
        self,
        tokens: int,
        latency: float,
        batch_size: int = 1
    ):
        self.total_requests += 1
        self.total_tokens += tokens
        self.total_latency += latency
        self.batch_sizes.append(batch_size)
    
    def summary(self) -> dict:
        return {
            "total_requests": self.total_requests,
            "total_tokens": self.total_tokens,
            "avg_latency_ms": self.total_latency / max(self.total_requests, 1) * 1000,
            "avg_tokens_per_second": self.total_tokens / max(self.total_latency, 0.001),
            "avg_batch_size": sum(self.batch_sizes) / max(len(self.batch_sizes), 1)
        }

metrics = InferenceMetrics()

# Mock model for demo
class MockModel:
    async def generate(self, prompts: list[str], max_tokens: int = 100) -> list[str]:
        await asyncio.sleep(0.1)  # Simulate inference
        return [f"Generated response for: {p[:50]}..." for p in prompts]

model = MockModel()

# Batcher
batcher = None  # Would be DynamicBatcher instance

@app.post("/v1/generate")
async def generate(request: GenerateRequest) -> GenerateResponse:
    """Generate text."""
    
    start_time = time.time()
    
    # Direct inference (would use batcher in production)
    results = await model.generate([request.prompt], request.max_tokens)
    
    latency = time.time() - start_time
    tokens = len(results[0].split())
    
    metrics.record(tokens, latency)
    
    return GenerateResponse(
        text=results[0],
        tokens_generated=tokens,
        latency_ms=latency * 1000,
        tokens_per_second=tokens / latency
    )

@app.get("/v1/metrics")
async def get_metrics() -> dict:
    """Get inference metrics."""
    return metrics.summary()

@app.get("/health")
async def health():
    return {"status": "healthy"}


Conclusion

LLM inference optimization is about understanding where time and memory go and eliminating waste. KV cache is the foundation—without it, every token requires recomputing attention over the entire sequence. Paged attention takes this further by eliminating memory fragmentation and enabling efficient memory sharing across requests. Dynamic batching maximizes GPU utilization by grouping requests together, while continuous batching ensures the GPU never waits for slow requests to finish. Quantization reduces memory footprint and can actually speed up inference on modern hardware that has efficient int8/int4 operations. Speculative decoding breaks the sequential bottleneck of autoregressive generation by drafting multiple tokens in parallel and verifying them in a single forward pass. In production, combine these techniques: use a quantized model with paged attention, continuous batching, and speculative decoding for maximum throughput. Monitor your metrics—tokens per second, time to first token, and GPU utilization tell you where bottlenecks remain. The key insight is that inference optimization is not about any single technique but about building a system where every component works together efficiently.


