Introduction: LLM inference optimization is the art of making models respond faster while using fewer resources. As LLMs grow larger and usage scales, the difference between naive and optimized inference can mean 10x cost reduction and sub-second latencies instead of multi-second waits. This guide covers the techniques that matter most: KV cache optimization to avoid redundant computation, dynamic batching to maximize GPU utilization, quantization to reduce memory footprint without sacrificing quality, speculative decoding to accelerate autoregressive generation, and continuous batching for high-throughput serving. Whether you’re serving a 7B model on a single GPU or a 70B model across a cluster, these patterns will help you achieve production-grade performance.

KV Cache Optimization
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple
import torch
import numpy as np
@dataclass
class KVCacheConfig:
"""Configuration for KV cache."""
num_layers: int
num_heads: int
head_dim: int
max_seq_len: int
dtype: torch.dtype = torch.float16
class KVCache:
"""Key-Value cache for transformer inference."""
def __init__(self, config: KVCacheConfig, batch_size: int = 1):
self.config = config
self.batch_size = batch_size
# Pre-allocate cache tensors
cache_shape = (
config.num_layers,
2, # key and value
batch_size,
config.num_heads,
config.max_seq_len,
config.head_dim
)
self.cache = torch.zeros(cache_shape, dtype=config.dtype)
self.seq_len = 0
def update(
self,
layer_idx: int,
key: torch.Tensor,
value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update cache and return full key/value."""
seq_len = key.shape[2]
# Store new key/value
self.cache[layer_idx, 0, :, :, self.seq_len:self.seq_len + seq_len, :] = key
self.cache[layer_idx, 1, :, :, self.seq_len:self.seq_len + seq_len, :] = value
# Return full cached key/value
full_key = self.cache[layer_idx, 0, :, :, :self.seq_len + seq_len, :]
full_value = self.cache[layer_idx, 1, :, :, :self.seq_len + seq_len, :]
return full_key, full_value
def advance(self, num_tokens: int = 1):
"""Advance sequence position."""
self.seq_len += num_tokens
def reset(self):
"""Reset cache."""
self.cache.zero_()
self.seq_len = 0
class PagedKVCache:
"""Paged attention KV cache for efficient memory."""
def __init__(
self,
config: KVCacheConfig,
page_size: int = 16,
num_pages: int = 1024
):
self.config = config
self.page_size = page_size
self.num_pages = num_pages
# Page pool
page_shape = (
num_pages,
2, # key and value
config.num_layers,
config.num_heads,
page_size,
config.head_dim
)
self.page_pool = torch.zeros(page_shape, dtype=config.dtype)
self.free_pages = list(range(num_pages))
# Sequence to page mapping
self.seq_pages: dict[int, list[int]] = {}
def allocate_sequence(self, seq_id: int) -> bool:
"""Allocate first page for sequence."""
if not self.free_pages:
return False
page_idx = self.free_pages.pop()
self.seq_pages[seq_id] = [page_idx]
return True
def extend_sequence(self, seq_id: int) -> bool:
"""Allocate additional page for sequence."""
if not self.free_pages:
return False
page_idx = self.free_pages.pop()
self.seq_pages[seq_id].append(page_idx)
return True
def free_sequence(self, seq_id: int):
"""Free all pages for sequence."""
if seq_id in self.seq_pages:
self.free_pages.extend(self.seq_pages[seq_id])
del self.seq_pages[seq_id]
def get_cache(
self,
seq_id: int,
layer_idx: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get cached key/value for sequence."""
pages = self.seq_pages.get(seq_id, [])
if not pages:
return None, None
# Gather from pages
keys = []
values = []
for page_idx in pages:
keys.append(self.page_pool[page_idx, 0, layer_idx])
values.append(self.page_pool[page_idx, 1, layer_idx])
return torch.cat(keys, dim=1), torch.cat(values, dim=1)
def update_cache(
self,
seq_id: int,
layer_idx: int,
position: int,
key: torch.Tensor,
value: torch.Tensor
):
"""Update cache at position."""
page_idx = position // self.page_size
offset = position % self.page_size
# Extend if needed
while len(self.seq_pages[seq_id]) <= page_idx:
if not self.extend_sequence(seq_id):
raise RuntimeError("Out of cache pages")
actual_page = self.seq_pages[seq_id][page_idx]
self.page_pool[actual_page, 0, layer_idx, :, offset, :] = key
self.page_pool[actual_page, 1, layer_idx, :, offset, :] = value
class SlidingWindowCache:
"""Sliding window attention cache."""
def __init__(
self,
config: KVCacheConfig,
window_size: int = 4096
):
self.config = config
self.window_size = window_size
# Circular buffer
cache_shape = (
config.num_layers,
2,
config.num_heads,
window_size,
config.head_dim
)
self.cache = torch.zeros(cache_shape, dtype=config.dtype)
self.position = 0
def update(
self,
layer_idx: int,
key: torch.Tensor,
value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update with circular buffer."""
idx = self.position % self.window_size
self.cache[layer_idx, 0, :, idx, :] = key.squeeze(2)
self.cache[layer_idx, 1, :, idx, :] = value.squeeze(2)
# Return window
if self.position < self.window_size:
return (
self.cache[layer_idx, 0, :, :self.position + 1, :],
self.cache[layer_idx, 1, :, :self.position + 1, :]
)
# Reorder for correct sequence
start = (self.position + 1) % self.window_size
indices = [(start + i) % self.window_size for i in range(self.window_size)]
return (
self.cache[layer_idx, 0, :, indices, :],
self.cache[layer_idx, 1, :, indices, :]
)
def advance(self):
"""Advance position."""
self.position += 1
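As a quick usage sketch (not part of the classes above), the snippet below wires KVCache into a toy attention computation: a prefill pass caches the whole prompt once, and each decode step appends a single key/value per layer and attends over everything cached so far. The random tensors and the toy_attention helper are stand-ins for a real transformer layer's projections.
import torch

# Usage sketch for KVCacheConfig / KVCache defined above. Queries, keys and
# values are random stand-ins for what a real layer would project.
config = KVCacheConfig(num_layers=2, num_heads=4, head_dim=8,
                       max_seq_len=64, dtype=torch.float32)
cache = KVCache(config, batch_size=1)

def toy_attention(layer_idx: int, q, k, v, cache: KVCache) -> torch.Tensor:
    """Attend over every position cached so far for this layer."""
    full_k, full_v = cache.update(layer_idx, k, v)  # (B, H, T, D)
    scores = q @ full_k.transpose(-2, -1) / config.head_dim ** 0.5
    return torch.softmax(scores, dim=-1) @ full_v

def step(num_new_tokens: int) -> None:
    """Process num_new_tokens positions through all layers, then advance."""
    for layer in range(config.num_layers):
        q = torch.randn(1, config.num_heads, num_new_tokens, config.head_dim)
        k, v = torch.randn_like(q), torch.randn_like(q)
        _ = toy_attention(layer, q, k, v, cache)
    cache.advance(num_new_tokens)

step(5)             # prefill: cache a 5-token prompt once
for _ in range(3):  # decode: each new token adds one key/value per layer
    step(1)
print("cached positions:", cache.seq_len)  # 8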
Dynamic Batching
from dataclasses import dataclass, field
from typing import Any, Optional, Callable
from collections import deque
import asyncio
import time
@dataclass
class InferenceRequest:
"""A single inference request."""
id: str
prompt: str
max_tokens: int = 100
temperature: float = 1.0
created_at: float = field(default_factory=time.time)
    future: Optional[asyncio.Future] = None
@dataclass
class BatchConfig:
"""Configuration for dynamic batching."""
max_batch_size: int = 32
max_wait_time_ms: float = 50.0
max_tokens_per_batch: int = 4096
class DynamicBatcher:
"""Dynamic batching for inference requests."""
def __init__(
self,
config: BatchConfig,
inference_fn: Callable
):
self.config = config
self.inference_fn = inference_fn
self.queue: deque[InferenceRequest] = deque()
self.running = False
async def submit(self, request: InferenceRequest) -> str:
"""Submit request and wait for result."""
request.future = asyncio.Future()
self.queue.append(request)
return await request.future
async def start(self):
"""Start batch processing loop."""
self.running = True
while self.running:
batch = await self._collect_batch()
if batch:
await self._process_batch(batch)
else:
await asyncio.sleep(0.001)
    async def _collect_batch(self) -> list[InferenceRequest]:
        """Collect requests into a batch."""
        if not self.queue:
            return []
        # Dispatch once the batch can be filled, or once the oldest request
        # has waited max_wait_time_ms, whichever comes first.
        waited_ms = (time.time() - self.queue[0].created_at) * 1000
        if (len(self.queue) < self.config.max_batch_size
                and waited_ms < self.config.max_wait_time_ms):
            return []
        batch = []
        total_tokens = 0
        while self.queue:
            # Check batch limits
            if len(batch) >= self.config.max_batch_size:
                break
            request = self.queue[0]
            # Rough token estimate: whitespace-split prompt plus generation budget
            request_tokens = len(request.prompt.split()) + request.max_tokens
            if batch and total_tokens + request_tokens > self.config.max_tokens_per_batch:
                break
            batch.append(self.queue.popleft())
            total_tokens += request_tokens
        return batch
async def _process_batch(self, batch: list[InferenceRequest]):
"""Process a batch of requests."""
try:
# Prepare batch inputs
prompts = [r.prompt for r in batch]
# Run inference
results = await self.inference_fn(prompts)
# Distribute results
for request, result in zip(batch, results):
request.future.set_result(result)
except Exception as e:
for request in batch:
request.future.set_exception(e)
def stop(self):
"""Stop batch processing."""
self.running = False
class ContinuousBatcher:
"""Continuous batching for streaming generation."""
def __init__(
self,
max_batch_size: int = 32,
inference_fn: Callable = None
):
self.max_batch_size = max_batch_size
self.inference_fn = inference_fn
# Active sequences
self.active: dict[str, dict] = {}
self.pending: deque[InferenceRequest] = deque()
async def submit(self, request: InferenceRequest):
"""Submit request for continuous batching."""
request.future = asyncio.Future()
self.pending.append(request)
return request.future
async def step(self):
"""Execute one generation step."""
# Add pending requests if space available
while self.pending and len(self.active) < self.max_batch_size:
request = self.pending.popleft()
self.active[request.id] = {
"request": request,
"tokens": [],
"position": 0
}
if not self.active:
return
# Prepare batch
batch_ids = list(self.active.keys())
# Run one step of generation
next_tokens = await self._generate_step(batch_ids)
# Update sequences
completed = []
for seq_id, token in zip(batch_ids, next_tokens):
seq = self.active[seq_id]
seq["tokens"].append(token)
seq["position"] += 1
# Check completion
if self._is_complete(seq):
completed.append(seq_id)
# Complete finished sequences
for seq_id in completed:
seq = self.active.pop(seq_id)
result = self._decode_tokens(seq["tokens"])
seq["request"].future.set_result(result)
async def _generate_step(self, batch_ids: list[str]) -> list[int]:
"""Generate next token for each sequence."""
# Would call actual model inference
return [0] * len(batch_ids) # Placeholder
def _is_complete(self, seq: dict) -> bool:
"""Check if sequence is complete."""
request = seq["request"]
# Check max tokens
if len(seq["tokens"]) >= request.max_tokens:
return True
# Check EOS token
if seq["tokens"] and seq["tokens"][-1] == 2: # EOS
return True
return False
def _decode_tokens(self, tokens: list[int]) -> str:
"""Decode tokens to text."""
return " ".join(str(t) for t in tokens) # Placeholder
class PriorityBatcher:
"""Priority-based batching."""
def __init__(self, max_batch_size: int = 32):
self.max_batch_size = max_batch_size
self.queues: dict[int, deque] = {
0: deque(), # High priority
1: deque(), # Normal priority
2: deque() # Low priority
}
def submit(self, request: InferenceRequest, priority: int = 1):
"""Submit with priority."""
request.future = asyncio.Future()
self.queues[priority].append(request)
return request.future
def collect_batch(self) -> list[InferenceRequest]:
"""Collect batch respecting priorities."""
batch = []
# Collect from high to low priority
for priority in sorted(self.queues.keys()):
queue = self.queues[priority]
while queue and len(batch) < self.max_batch_size:
batch.append(queue.popleft())
return batch
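To show how the DynamicBatcher above might be driven, here is a small self-contained sketch. The mock_inference coroutine and the generated request IDs are illustrative stand-ins for a real batched forward pass.
import asyncio
import time
import uuid

# Stand-in for a batched model call: echoes prompts after a short delay.
async def mock_inference(prompts: list[str]) -> list[str]:
    await asyncio.sleep(0.05)
    return [f"echo: {p}" for p in prompts]

async def main():
    batcher = DynamicBatcher(BatchConfig(max_batch_size=8), mock_inference)
    loop_task = asyncio.create_task(batcher.start())

    # Fire several requests concurrently; they should be served in one batch.
    requests = [
        InferenceRequest(id=str(uuid.uuid4()), prompt=f"question {i}", max_tokens=16)
        for i in range(5)
    ]
    results = await asyncio.gather(*(batcher.submit(r) for r in requests))
    for r in results:
        print(r)

    batcher.stop()
    await loop_task  # loop exits on its next iteration once running is False

asyncio.run(main())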
Quantization Techniques
from dataclasses import dataclass
from typing import Any, Optional
import torch
import numpy as np
@dataclass
class QuantizationConfig:
"""Configuration for quantization."""
bits: int = 4
group_size: int = 128
symmetric: bool = False
compute_dtype: torch.dtype = torch.float16
class WeightQuantizer:
"""Quantize model weights."""
def __init__(self, config: QuantizationConfig):
self.config = config
def quantize(self, weight: torch.Tensor) -> dict:
"""Quantize weight tensor."""
if self.config.symmetric:
return self._quantize_symmetric(weight)
else:
return self._quantize_asymmetric(weight)
def _quantize_symmetric(self, weight: torch.Tensor) -> dict:
"""Symmetric quantization."""
# Reshape for group quantization
original_shape = weight.shape
weight = weight.reshape(-1, self.config.group_size)
# Calculate scale
max_val = weight.abs().max(dim=1, keepdim=True)[0]
scale = max_val / (2 ** (self.config.bits - 1) - 1)
# Quantize
quantized = torch.round(weight / scale).to(torch.int8)
return {
"quantized": quantized,
"scale": scale,
"zero_point": None,
"original_shape": original_shape
}
def _quantize_asymmetric(self, weight: torch.Tensor) -> dict:
"""Asymmetric quantization."""
original_shape = weight.shape
weight = weight.reshape(-1, self.config.group_size)
# Calculate scale and zero point
min_val = weight.min(dim=1, keepdim=True)[0]
max_val = weight.max(dim=1, keepdim=True)[0]
scale = (max_val - min_val) / (2 ** self.config.bits - 1)
zero_point = torch.round(-min_val / scale)
# Quantize
quantized = torch.round(weight / scale + zero_point).to(torch.uint8)
return {
"quantized": quantized,
"scale": scale,
"zero_point": zero_point,
"original_shape": original_shape
}
def dequantize(self, quantized_data: dict) -> torch.Tensor:
"""Dequantize to original precision."""
quantized = quantized_data["quantized"].to(self.config.compute_dtype)
scale = quantized_data["scale"]
zero_point = quantized_data["zero_point"]
original_shape = quantized_data["original_shape"]
if zero_point is not None:
dequantized = (quantized - zero_point) * scale
else:
dequantized = quantized * scale
return dequantized.reshape(original_shape)
class AWQQuantizer:
"""Activation-aware Weight Quantization."""
def __init__(self, bits: int = 4, group_size: int = 128):
self.bits = bits
self.group_size = group_size
def quantize_layer(
self,
weight: torch.Tensor,
activations: torch.Tensor
) -> dict:
"""Quantize with activation awareness."""
# Calculate activation scales
act_scales = activations.abs().mean(dim=0)
# Find optimal scaling factors
best_scale = self._search_scale(weight, act_scales)
# Apply scaling and quantize
scaled_weight = weight * best_scale
# Standard quantization
quantizer = WeightQuantizer(QuantizationConfig(
bits=self.bits,
group_size=self.group_size
))
quantized = quantizer.quantize(scaled_weight)
quantized["act_scale"] = best_scale
return quantized
def _search_scale(
self,
weight: torch.Tensor,
act_scales: torch.Tensor,
n_grid: int = 20
) -> torch.Tensor:
"""Search for optimal scaling."""
best_error = float('inf')
best_scale = torch.ones_like(act_scales)
for ratio in np.linspace(0, 1, n_grid):
scale = act_scales.pow(ratio)
# Simulate quantization error
scaled = weight * scale
quantized = torch.round(scaled)
dequantized = quantized / scale
error = (weight - dequantized).pow(2).mean()
if error < best_error:
best_error = error
best_scale = scale
return best_scale
class GPTQQuantizer:
"""GPTQ quantization."""
def __init__(
self,
bits: int = 4,
group_size: int = 128,
damp_percent: float = 0.01
):
self.bits = bits
self.group_size = group_size
self.damp_percent = damp_percent
def quantize_layer(
self,
weight: torch.Tensor,
hessian: torch.Tensor
) -> dict:
"""Quantize using GPTQ algorithm."""
W = weight.clone()
H = hessian.clone()
# Add damping
damp = self.damp_percent * torch.diag(H).mean()
H += damp * torch.eye(H.shape[0], device=H.device)
        # Invert the damped Hessian via its Cholesky factor
        L = torch.linalg.cholesky(H)
        H_inv = torch.cholesky_inverse(L)
# Quantize column by column
Q = torch.zeros_like(W)
for i in range(W.shape[1]):
w = W[:, i]
d = H_inv[i, i]
# Quantize
q = self._quantize_column(w)
Q[:, i] = q
# Update remaining weights
err = (w - q) / d
W[:, i:] -= err.unsqueeze(1) * H_inv[i, i:].unsqueeze(0)
return {
"quantized": Q,
"bits": self.bits
}
def _quantize_column(self, w: torch.Tensor) -> torch.Tensor:
"""Quantize a single column."""
max_val = w.abs().max()
scale = max_val / (2 ** (self.bits - 1) - 1)
quantized = torch.round(w / scale) * scale
return quantized
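A minimal round-trip with the WeightQuantizer above illustrates what quantization buys: the reconstruction error stays small while weight storage shrinks by roughly 4x versus fp16. The size figures below are a back-of-the-envelope estimate that assumes ideal 4-bit packing.
import torch

# Quantize a weight matrix, dequantize it, and inspect error and memory.
config = QuantizationConfig(bits=4, group_size=128, symmetric=False)
quantizer = WeightQuantizer(config)

weight = torch.randn(4096, 4096, dtype=torch.float32)
packed = quantizer.quantize(weight)
restored = quantizer.dequantize(packed).to(torch.float32)

rel_error = ((weight - restored).norm() / weight.norm()).item()
print(f"relative reconstruction error: {rel_error:.4f}")

# Rough memory estimate: fp16 weights vs. ideally packed 4-bit codes plus a
# per-group fp16 scale and zero point. (The sketch above keeps codes in uint8,
# so its actual footprint is twice the ideal 4-bit figure.)
n = weight.numel()
n_groups = n // config.group_size
fp16_bytes = n * 2
int4_bytes = n // 2 + n_groups * 2 * 2
print(f"fp16: {fp16_bytes / 1e6:.1f} MB, 4-bit (ideal packing): {int4_bytes / 1e6:.1f} MB")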
Speculative Decoding
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
num_speculative_tokens: int = 4
temperature: float = 1.0
class SpeculativeDecoder:
"""Speculative decoding for faster generation."""
def __init__(
self,
target_model: Any,
draft_model: Any,
config: SpeculativeConfig
):
self.target = target_model
self.draft = draft_model
self.config = config
def generate(
self,
input_ids: torch.Tensor,
max_tokens: int = 100
) -> torch.Tensor:
"""Generate with speculative decoding."""
generated = input_ids.clone()
while generated.shape[1] < input_ids.shape[1] + max_tokens:
# Draft phase: generate speculative tokens
draft_tokens = self._draft_tokens(generated)
# Verify phase: check with target model
accepted, next_token = self._verify_tokens(generated, draft_tokens)
# Append accepted tokens
generated = torch.cat([generated, accepted, next_token], dim=1)
return generated
def _draft_tokens(self, context: torch.Tensor) -> torch.Tensor:
"""Generate draft tokens with small model."""
draft_tokens = []
current = context
for _ in range(self.config.num_speculative_tokens):
logits = self.draft(current)
next_token = self._sample(logits[:, -1, :])
draft_tokens.append(next_token)
current = torch.cat([current, next_token.unsqueeze(1)], dim=1)
return torch.stack(draft_tokens, dim=1)
def _verify_tokens(
self,
context: torch.Tensor,
draft_tokens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Verify draft tokens with target model."""
# Get target model probabilities for all positions
full_input = torch.cat([context, draft_tokens], dim=1)
target_logits = self.target(full_input)
# Get draft model probabilities
draft_logits = self.draft(full_input)
# Verify each token
accepted = []
for i in range(draft_tokens.shape[1]):
pos = context.shape[1] + i
# Get probabilities
target_prob = torch.softmax(target_logits[:, pos - 1, :], dim=-1)
draft_prob = torch.softmax(draft_logits[:, pos - 1, :], dim=-1)
token = draft_tokens[:, i]
# Acceptance probability
p_target = target_prob.gather(1, token.unsqueeze(1))
p_draft = draft_prob.gather(1, token.unsqueeze(1))
accept_prob = torch.min(
torch.ones_like(p_target),
p_target / p_draft
)
# Accept or reject
if torch.rand(1) < accept_prob:
accepted.append(token)
else:
# Sample from adjusted distribution
adjusted = torch.clamp(target_prob - draft_prob, min=0)
adjusted = adjusted / adjusted.sum(dim=-1, keepdim=True)
next_token = torch.multinomial(adjusted, 1).squeeze(1)
                # Return an empty token tensor (matching dtype) when nothing was accepted,
                # so the concatenation in generate() does not mix float and int tensors.
                if accepted:
                    accepted_tokens = torch.stack(accepted, dim=1)
                else:
                    accepted_tokens = draft_tokens.new_empty(context.shape[0], 0)
                return accepted_tokens, next_token.unsqueeze(1)
# All accepted, sample next from target
next_token = self._sample(target_logits[:, -1, :])
return draft_tokens, next_token.unsqueeze(1)
def _sample(self, logits: torch.Tensor) -> torch.Tensor:
"""Sample from logits."""
if self.config.temperature == 0:
return logits.argmax(dim=-1)
probs = torch.softmax(logits / self.config.temperature, dim=-1)
return torch.multinomial(probs, 1).squeeze(1)
class MedusaDecoder:
"""Medusa-style parallel decoding."""
def __init__(
self,
model: Any,
num_heads: int = 4,
num_candidates: int = 5
):
self.model = model
self.num_heads = num_heads
self.num_candidates = num_candidates
# Medusa heads predict future tokens
self.medusa_heads = None # Would be trained heads
def generate(
self,
input_ids: torch.Tensor,
max_tokens: int = 100
) -> torch.Tensor:
"""Generate with Medusa decoding."""
generated = input_ids.clone()
while generated.shape[1] < input_ids.shape[1] + max_tokens:
# Get base model hidden states
hidden = self.model.get_hidden_states(generated)
# Get predictions from Medusa heads
candidates = self._get_candidates(hidden)
# Verify candidates with tree attention
accepted = self._verify_tree(generated, candidates)
generated = torch.cat([generated, accepted], dim=1)
return generated
def _get_candidates(self, hidden: torch.Tensor) -> list[torch.Tensor]:
"""Get candidate tokens from Medusa heads."""
candidates = []
for head in range(self.num_heads):
# Each head predicts token at position +head
logits = self.medusa_heads[head](hidden[:, -1, :])
top_k = torch.topk(logits, self.num_candidates, dim=-1)
candidates.append(top_k.indices)
return candidates
def _verify_tree(
self,
context: torch.Tensor,
candidates: list[torch.Tensor]
) -> torch.Tensor:
"""Verify candidates using tree attention."""
# Build candidate tree
# Verify with single forward pass
# Return longest accepted path
# Simplified: just return first candidate
return candidates[0][:, 0:1]
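The payoff of speculative decoding can be estimated with a little arithmetic: if the target model accepts each draft token with probability alpha, a round of k drafted tokens plus one verification pass yields (1 - alpha^(k+1)) / (1 - alpha) tokens in expectation (see the speculative decoding paper in the references). The sketch below turns that into a rough speedup number; the cost model, one target pass plus k draft passes per round, is a simplifying assumption.
def expected_speedup(alpha: float, k: int, draft_cost: float) -> float:
    """Rough speculative-decoding speedup vs. plain autoregressive decoding.

    alpha: probability the target model accepts a draft token
    k: number of draft tokens per round
    draft_cost: draft-model cost per token relative to the target model
    """
    tokens_per_round = (1 - alpha ** (k + 1)) / (1 - alpha)
    cost_per_round = 1 + draft_cost * k  # one target pass plus k draft passes
    return tokens_per_round / cost_per_round

for alpha in (0.6, 0.8, 0.9):
    print(f"alpha={alpha}: ~{expected_speedup(alpha, k=4, draft_cost=0.05):.2f}x")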
Production Inference Service
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Any
import asyncio
import time
app = FastAPI()
class GenerateRequest(BaseModel):
prompt: str
max_tokens: int = 100
temperature: float = 1.0
stream: bool = False
class GenerateResponse(BaseModel):
text: str
tokens_generated: int
latency_ms: float
tokens_per_second: float
class InferenceMetrics:
"""Track inference metrics."""
def __init__(self):
self.total_requests = 0
self.total_tokens = 0
self.total_latency = 0.0
self.batch_sizes = []
def record(
self,
tokens: int,
latency: float,
batch_size: int = 1
):
self.total_requests += 1
self.total_tokens += tokens
self.total_latency += latency
self.batch_sizes.append(batch_size)
def summary(self) -> dict:
return {
"total_requests": self.total_requests,
"total_tokens": self.total_tokens,
"avg_latency_ms": self.total_latency / max(self.total_requests, 1) * 1000,
"avg_tokens_per_second": self.total_tokens / max(self.total_latency, 0.001),
"avg_batch_size": sum(self.batch_sizes) / max(len(self.batch_sizes), 1)
}
metrics = InferenceMetrics()
# Mock model for demo
class MockModel:
async def generate(self, prompts: list[str], max_tokens: int = 100) -> list[str]:
await asyncio.sleep(0.1) # Simulate inference
return [f"Generated response for: {p[:50]}..." for p in prompts]
model = MockModel()
# Batcher
batcher = None # Would be DynamicBatcher instance
@app.post("/v1/generate")
async def generate(request: GenerateRequest) -> GenerateResponse:
"""Generate text."""
start_time = time.time()
# Direct inference (would use batcher in production)
results = await model.generate([request.prompt], request.max_tokens)
latency = time.time() - start_time
tokens = len(results[0].split())
metrics.record(tokens, latency)
return GenerateResponse(
text=results[0],
tokens_generated=tokens,
latency_ms=latency * 1000,
tokens_per_second=tokens / latency
)
@app.get("/v1/metrics")
async def get_metrics() -> dict:
"""Get inference metrics."""
return metrics.summary()
@app.get("/health")
async def health():
return {"status": "healthy"}
References
- vLLM: https://github.com/vllm-project/vllm
- PagedAttention Paper: https://arxiv.org/abs/2309.06180
- Speculative Decoding: https://arxiv.org/abs/2211.17192
- AWQ Paper: https://arxiv.org/abs/2306.00978
Conclusion
LLM inference optimization is about understanding where time and memory go and eliminating waste. KV cache is the foundation—without it, every token requires recomputing attention over the entire sequence. Paged attention takes this further by eliminating memory fragmentation and enabling efficient memory sharing across requests.

Dynamic batching maximizes GPU utilization by grouping requests together, while continuous batching ensures the GPU never waits for slow requests to finish. Quantization reduces memory footprint and can actually speed up inference on modern hardware that has efficient int8/int4 operations. Speculative decoding breaks the sequential bottleneck of autoregressive generation by drafting multiple tokens in parallel and verifying them in a single forward pass.

In production, combine these techniques: use a quantized model with paged attention, continuous batching, and speculative decoding for maximum throughput. Monitor your metrics—tokens per second, time to first token, and GPU utilization tell you where bottlenecks remain. The key insight is that inference optimization is not about any single technique but about building a system where every component works together efficiently.