Introduction

LLM inference is expensive—both in compute and latency. Every token generated requires a forward pass through billions of parameters, and users expect responses in seconds, not minutes. Inference optimization techniques reduce costs and improve responsiveness without sacrificing output quality. This guide covers practical optimization strategies: batching requests to maximize GPU utilization, managing KV caches to avoid redundant computation, quantizing models to reduce memory footprint, speculative decoding to accelerate generation, and continuous batching for production throughput. Whether you’re serving a chatbot to millions of users or running inference on edge devices, these techniques determine whether your LLM application is viable at scale.

Request Batching
```python
from dataclasses import dataclass, field
from typing import Any, Optional
from datetime import datetime
from collections import deque
import asyncio


@dataclass
class InferenceRequest:
    """A single inference request."""
    id: str
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    created_at: datetime = field(default_factory=datetime.utcnow)
    future: Optional[asyncio.Future] = None


@dataclass
class InferenceResult:
    """Result of inference."""
    request_id: str
    text: str
    tokens_generated: int
    latency_ms: float


class DynamicBatcher:
    """Dynamic request batching for inference."""

    def __init__(
        self,
        model: Any,
        max_batch_size: int = 32,
        max_wait_ms: float = 50.0,
        max_tokens_per_batch: int = 4096
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.max_tokens_per_batch = max_tokens_per_batch
        self._queue: deque[InferenceRequest] = deque()
        self._lock = asyncio.Lock()
        self._batch_event = asyncio.Event()
        self._running = False
        self._loop_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the batching loop."""
        self._running = True
        # Keep a reference so the background task is not garbage collected
        self._loop_task = asyncio.create_task(self._batch_loop())

    async def stop(self):
        """Stop the batching loop and wait for it to drain."""
        self._running = False
        self._batch_event.set()
        if self._loop_task:
            await self._loop_task

    async def infer(self, request: InferenceRequest) -> InferenceResult:
        """Submit a request and wait for its result."""
        request.future = asyncio.get_running_loop().create_future()
        async with self._lock:
            self._queue.append(request)
        self._batch_event.set()
        return await request.future

    async def _batch_loop(self):
        """Main batching loop."""
        while self._running:
            await self._batch_event.wait()
            self._batch_event.clear()
            # Wait briefly so more requests can accumulate into this batch
            await asyncio.sleep(self.max_wait_ms / 1000)
            # Collect batch
            batch = await self._collect_batch()
            if batch:
                # Process batch
                results = await self._process_batch(batch)
                # Deliver results
                for request, result in zip(batch, results):
                    if not request.future.done():
                        request.future.set_result(result)
            # Re-arm the loop if requests are still waiting in the queue
            if self._queue:
                self._batch_event.set()

    async def _collect_batch(self) -> list[InferenceRequest]:
        """Collect queued requests into a batch."""
        batch = []
        total_tokens = 0
        async with self._lock:
            while self._queue and len(batch) < self.max_batch_size:
                request = self._queue[0]
                # Rough estimate: ~1.3 tokens per whitespace-separated word
                prompt_tokens = len(request.prompt.split()) * 1.3
                request_tokens = prompt_tokens + request.max_tokens
                if total_tokens + request_tokens > self.max_tokens_per_batch:
                    break
                batch.append(self._queue.popleft())
                total_tokens += request_tokens
        return batch

    async def _process_batch(
        self,
        batch: list[InferenceRequest]
    ) -> list[InferenceResult]:
        """Process a batch of requests."""
        start_time = datetime.utcnow()
        # Prepare batch inputs
        prompts = [r.prompt for r in batch]
        max_tokens = max(r.max_tokens for r in batch)
        # Run inference
        outputs = await self.model.generate_batch(
            prompts,
            max_tokens=max_tokens
        )
        end_time = datetime.utcnow()
        latency = (end_time - start_time).total_seconds() * 1000
        # Create results
        results = []
        for request, output in zip(batch, outputs):
            results.append(InferenceResult(
                request_id=request.id,
                text=output,
                tokens_generated=len(output.split()),
                latency_ms=latency
            ))
        return results


class ContinuousBatcher:
    """Continuous batching for streaming inference."""

    def __init__(
        self,
        model: Any,
        max_batch_size: int = 64,
        max_sequence_length: int = 2048
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        self._active_requests: dict[str, InferenceRequest] = {}
        self._pending_queue: deque[InferenceRequest] = deque()
        # Tokens generated so far for each active request
        self._generated_counts: dict[str, int] = {}
        self._lock = asyncio.Lock()

    async def add_request(self, request: InferenceRequest):
        """Add a request to processing."""
        async with self._lock:
            if len(self._active_requests) < self.max_batch_size:
                self._active_requests[request.id] = request
                self._generated_counts[request.id] = 0
            else:
                self._pending_queue.append(request)

    async def step(self) -> dict[str, str]:
        """Generate one token for all active requests."""
        if not self._active_requests:
            return {}
        # Get current state of all active requests
        requests = list(self._active_requests.values())
        # Generate the next token for the whole batch in one pass
        next_tokens = await self.model.generate_next_token_batch(
            [r.prompt for r in requests]
        )
        results = {}
        completed = []
        for request, token in zip(requests, next_tokens):
            # Append the new token to the running sequence
            request.prompt += token
            self._generated_counts[request.id] += 1
            results[request.id] = token
            # Check completion
            if self._is_complete(request, token):
                completed.append(request.id)
        # Remove completed requests and backfill from the pending queue
        async with self._lock:
            for req_id in completed:
                del self._active_requests[req_id]
                del self._generated_counts[req_id]
                if self._pending_queue:
                    new_request = self._pending_queue.popleft()
                    self._active_requests[new_request.id] = new_request
                    self._generated_counts[new_request.id] = 0
        return results

    def _is_complete(self, request: InferenceRequest, token: str) -> bool:
        """Check if a request is complete."""
        # Check for EOS token
        if token in ["<|endoftext|>", "<|end|>", ""]:
            return True
        # Check max tokens
        return self._generated_counts[request.id] >= request.max_tokens
```
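To make the batching flow concrete, here is a minimal usage sketch. EchoModel is a hypothetical stand-in that exposes the generate_batch coroutine the DynamicBatcher expects; concurrent callers of infer() are grouped into a single forward pass.

```python
import asyncio
import uuid

class EchoModel:
    """Hypothetical stand-in model: 'generates' by echoing the prompt."""
    async def generate_batch(self, prompts: list[str], max_tokens: int) -> list[str]:
        await asyncio.sleep(0.05)  # simulate one batched forward pass
        return [f"{p} ... generated" for p in prompts]

async def main():
    batcher = DynamicBatcher(model=EchoModel(), max_batch_size=8, max_wait_ms=20.0)
    await batcher.start()
    requests = [
        InferenceRequest(id=str(uuid.uuid4()), prompt=f"Question {i}?", max_tokens=32)
        for i in range(4)
    ]
    # Four concurrent callers end up sharing one batch
    results = await asyncio.gather(*(batcher.infer(r) for r in requests))
    for r in results:
        print(r.request_id[:8], f"{r.latency_ms:.0f} ms", r.text)
    await batcher.stop()

asyncio.run(main())
```

All four callers report roughly the same latency because they share a single model call.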
KV Cache Management
```python
from dataclasses import dataclass
from typing import Optional
import time

import torch


@dataclass
class KVCacheEntry:
    """A KV cache entry for one layer."""
    key: torch.Tensor
    value: torch.Tensor
    sequence_length: int
    last_accessed: float


class KVCacheManager:
    """Manage per-request KV caches for inference."""

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        max_cache_size_gb: float = 8.0,
        device: str = "cuda"
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_cache_size = int(max_cache_size_gb * 1e9)
        self.device = device
        self._caches: dict[str, list[KVCacheEntry]] = {}
        self._current_size = 0

    def get_cache(self, request_id: str) -> Optional[list[KVCacheEntry]]:
        """Get the KV cache for a request."""
        cache = self._caches.get(request_id)
        if cache:
            for entry in cache:
                entry.last_accessed = time.time()
        return cache

    def set_cache(
        self,
        request_id: str,
        keys: list[torch.Tensor],
        values: list[torch.Tensor],
        seq_len: int
    ):
        """Set the KV cache for a request (one key/value tensor per layer)."""
        # Calculate size in bytes
        cache_size = sum(
            k.numel() * k.element_size() + v.numel() * v.element_size()
            for k, v in zip(keys, values)
        )
        # Evict if necessary
        while self._current_size + cache_size > self.max_cache_size:
            if not self._evict_oldest():
                break
        # Store cache
        entries = []
        now = time.time()
        for k, v in zip(keys, values):
            entries.append(KVCacheEntry(
                key=k,
                value=v,
                sequence_length=seq_len,
                last_accessed=now
            ))
        self._caches[request_id] = entries
        self._current_size += cache_size

    def extend_cache(
        self,
        request_id: str,
        new_keys: list[torch.Tensor],
        new_values: list[torch.Tensor]
    ):
        """Extend an existing cache with new tokens."""
        cache = self._caches.get(request_id)
        if not cache:
            return
        for entry, new_k, new_v in zip(cache, new_keys, new_values):
            # Concatenate along the sequence dimension
            entry.key = torch.cat([entry.key, new_k], dim=-2)
            entry.value = torch.cat([entry.value, new_v], dim=-2)
            entry.sequence_length += new_k.shape[-2]
            # Keep the size accounting in sync with the appended tensors
            self._current_size += (
                new_k.numel() * new_k.element_size() +
                new_v.numel() * new_v.element_size()
            )

    def delete_cache(self, request_id: str):
        """Delete the cache for a request."""
        if request_id in self._caches:
            cache = self._caches[request_id]
            cache_size = sum(
                e.key.numel() * e.key.element_size() +
                e.value.numel() * e.value.element_size()
                for e in cache
            )
            del self._caches[request_id]
            self._current_size -= cache_size

    def _evict_oldest(self) -> bool:
        """Evict the least recently used cache."""
        if not self._caches:
            return False
        # Find the request whose cache was touched longest ago
        oldest_id = min(
            self._caches.keys(),
            key=lambda k: min(e.last_accessed for e in self._caches[k])
        )
        self.delete_cache(oldest_id)
        return True


class PagedKVCache:
    """Paged-attention KV cache (vLLM style)."""

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        block_size: int = 16,
        num_blocks: int = 1000,
        device: str = "cuda"
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.device = device
        # Pre-allocate the block pool
        self.key_blocks = torch.zeros(
            num_blocks, num_layers, block_size, num_heads, head_dim,
            device=device, dtype=torch.float16
        )
        self.value_blocks = torch.zeros(
            num_blocks, num_layers, block_size, num_heads, head_dim,
            device=device, dtype=torch.float16
        )
        # Block allocation tracking
        self.free_blocks = list(range(num_blocks))
        self.sequence_blocks: dict[str, list[int]] = {}

    def allocate_blocks(self, request_id: str, num_tokens: int) -> list[int]:
        """Allocate blocks for a sequence."""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_blocks) < num_blocks_needed:
            raise RuntimeError("Out of KV cache blocks")
        allocated = [self.free_blocks.pop() for _ in range(num_blocks_needed)]
        self.sequence_blocks[request_id] = allocated
        return allocated

    def free_blocks_for_sequence(self, request_id: str):
        """Free the blocks held by a sequence."""
        if request_id in self.sequence_blocks:
            blocks = self.sequence_blocks.pop(request_id)
            self.free_blocks.extend(blocks)

    def write_kv(
        self,
        request_id: str,
        layer_idx: int,
        position: int,
        key: torch.Tensor,
        value: torch.Tensor
    ):
        """Write the KV vectors for one token position."""
        blocks = self.sequence_blocks.get(request_id, [])
        block_idx = position // self.block_size
        offset = position % self.block_size
        if block_idx >= len(blocks):
            # Grow the sequence by one block
            if not self.free_blocks:
                raise RuntimeError("Out of KV cache blocks")
            blocks.append(self.free_blocks.pop())
            self.sequence_blocks[request_id] = blocks
        block_id = blocks[block_idx]
        self.key_blocks[block_id, layer_idx, offset] = key
        self.value_blocks[block_id, layer_idx, offset] = value

    def read_kv(
        self,
        request_id: str,
        layer_idx: int,
        positions: list[int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Read the KV vectors for a list of token positions."""
        blocks = self.sequence_blocks.get(request_id, [])
        keys = []
        values = []
        for pos in positions:
            block_idx = pos // self.block_size
            offset = pos % self.block_size
            if block_idx < len(blocks):
                block_id = blocks[block_idx]
                keys.append(self.key_blocks[block_id, layer_idx, offset])
                values.append(self.value_blocks[block_id, layer_idx, offset])
        return torch.stack(keys), torch.stack(values)
```
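A small usage sketch of the paged cache, with deliberately tiny dimensions (4 layers, 8 heads, head_dim 64, CPU) so it runs anywhere: allocate blocks for a 40-token prompt, write the K/V vectors for one position, and read them back.

```python
import torch

cache = PagedKVCache(num_layers=4, num_heads=8, head_dim=64,
                     block_size=16, num_blocks=32, device="cpu")

blocks = cache.allocate_blocks("req-1", num_tokens=40)
print("allocated blocks:", blocks)  # ceil(40 / 16) = 3 block IDs

k = torch.randn(8, 64, dtype=torch.float16)  # (num_heads, head_dim)
v = torch.randn(8, 64, dtype=torch.float16)
cache.write_kv("req-1", layer_idx=0, position=5, key=k, value=v)

keys, values = cache.read_kv("req-1", layer_idx=0, positions=[5])
print(keys.shape, values.shape)  # torch.Size([1, 8, 64]) each

cache.free_blocks_for_sequence("req-1")
print("free blocks remaining:", len(cache.free_blocks))  # back to 32
```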
Model Quantization
```python
from dataclasses import dataclass
from enum import Enum

import torch
import torch.nn as nn


class QuantizationType(Enum):
    """Quantization types."""
    INT8 = "int8"
    INT4 = "int4"
    FP8 = "fp8"
    NF4 = "nf4"  # Normal float 4-bit


@dataclass
class QuantizationConfig:
    """Quantization configuration."""
    quant_type: QuantizationType
    group_size: int = 128
    use_double_quant: bool = False
    compute_dtype: torch.dtype = torch.float16


class Int8Quantizer:
    """INT8 quantization."""

    def __init__(self, symmetric: bool = True):
        self.symmetric = symmetric

    def quantize(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize a tensor to INT8."""
        if self.symmetric:
            # Symmetric quantization: one scale, zero point at 0
            scale = tensor.abs().max() / 127
            quantized = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
            return quantized, scale
        else:
            # Asymmetric quantization: scale plus zero point
            min_val = tensor.min()
            max_val = tensor.max()
            scale = (max_val - min_val) / 255
            zero_point = (-min_val / scale).round()
            quantized = ((tensor / scale) + zero_point).round().clamp(0, 255).to(torch.uint8)
            return quantized, torch.stack([scale, zero_point])

    def dequantize(
        self,
        quantized: torch.Tensor,
        scale: torch.Tensor
    ) -> torch.Tensor:
        """Dequantize INT8 back to float."""
        if self.symmetric:
            return quantized.float() * scale
        else:
            scale_val, zero_point = scale[0], scale[1]
            return (quantized.float() - zero_point) * scale_val


class Int4Quantizer:
    """INT4 quantization with per-group scales."""

    def __init__(self, group_size: int = 128):
        self.group_size = group_size

    def quantize(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize a tensor to INT4 (numel must be divisible by group_size)."""
        # Reshape for grouping
        tensor = tensor.reshape(-1, self.group_size)
        # Compute scales per group
        scales = tensor.abs().max(dim=1, keepdim=True).values / 7
        scales = scales.clamp(min=1e-8)
        # Quantize to the signed 4-bit range [-8, 7]
        quantized = (tensor / scales).round().clamp(-8, 7).to(torch.int8)
        # Pack two INT4 values into one byte
        packed = self._pack_int4(quantized)
        return packed, scales.reshape(-1)

    def _pack_int4(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pack pairs of INT4 values into single bytes."""
        # Widen to int16 so the shift below cannot overflow int8
        flat = tensor.reshape(-1).to(torch.int16)
        # Ensure an even number of elements
        if flat.shape[0] % 2 != 0:
            flat = torch.cat([flat, torch.zeros(1, dtype=flat.dtype, device=flat.device)])
        # Pack pairs: low nibble from even indices, high nibble from odd indices
        low = flat[0::2] & 0x0F
        high = (flat[1::2] & 0x0F) << 4
        return (low | high).to(torch.uint8)

    def dequantize(
        self,
        packed: torch.Tensor,
        scales: torch.Tensor,
        original_shape: tuple
    ) -> torch.Tensor:
        """Dequantize INT4 back to float."""
        # Unpack
        unpacked = self._unpack_int4(packed)
        # Reshape and apply per-group scales
        unpacked = unpacked.reshape(-1, self.group_size).float()
        scales = scales.reshape(-1, 1)
        dequantized = unpacked * scales
        return dequantized.reshape(original_shape)

    def _unpack_int4(self, packed: torch.Tensor) -> torch.Tensor:
        """Unpack INT4 values from packed bytes."""
        low = (packed & 0x0F).to(torch.int8)
        high = ((packed >> 4) & 0x0F).to(torch.int8)
        # Sign-extend the 4-bit values
        low = torch.where(low > 7, low - 16, low)
        high = torch.where(high > 7, high - 16, high)
        # Interleave back into the original order
        return torch.stack([low, high], dim=-1).reshape(-1)


class QuantizedLinear(nn.Module):
    """Linear layer with quantized weights."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config: QuantizationConfig
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.config = config
        # Quantized weights (stored as INT8 or packed INT4)
        self.register_buffer("weight_quantized", None)
        self.register_buffer("weight_scales", None)
        self.register_buffer("bias", None)

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        config: QuantizationConfig
    ) -> "QuantizedLinear":
        """Create a quantized layer from an existing nn.Linear."""
        layer = cls(linear.in_features, linear.out_features, config)
        # Quantize weights
        if config.quant_type == QuantizationType.INT8:
            quantizer = Int8Quantizer()
            quantized, scales = quantizer.quantize(linear.weight.data)
        elif config.quant_type == QuantizationType.INT4:
            quantizer = Int4Quantizer(config.group_size)
            quantized, scales = quantizer.quantize(linear.weight.data)
        else:
            raise ValueError(f"Unsupported quantization: {config.quant_type}")
        layer.weight_quantized = quantized
        layer.weight_scales = scales
        if linear.bias is not None:
            layer.bias = linear.bias.data.clone()
        return layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with on-the-fly dequantization."""
        # Dequantize weights
        if self.config.quant_type == QuantizationType.INT8:
            quantizer = Int8Quantizer()
            weight = quantizer.dequantize(
                self.weight_quantized,
                self.weight_scales
            )
        elif self.config.quant_type == QuantizationType.INT4:
            quantizer = Int4Quantizer(self.config.group_size)
            weight = quantizer.dequantize(
                self.weight_quantized,
                self.weight_scales,
                (self.out_features, self.in_features)
            )
        else:
            weight = self.weight_quantized.float()
        # Match the activation dtype, then compute
        weight = weight.to(x.dtype)
        return torch.nn.functional.linear(x, weight, self.bias)
```
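A quick sanity check, assuming the classes above live in one module: quantize a randomly initialized nn.Linear to INT4, compare its output against the full-precision layer, and look at the storage saving. The numbers are illustrative, not benchmarks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(256, 128)
x = torch.randn(4, 256)

config = QuantizationConfig(quant_type=QuantizationType.INT4, group_size=128)
qlinear = QuantizedLinear.from_linear(linear, config)

with torch.no_grad():
    ref = linear(x)
    out = qlinear(x)
    rel_err = ((ref - out).abs().mean() / ref.abs().mean()).item()

print(f"INT4 relative error: {rel_err:.2%}")
print(f"packed weight bytes: {qlinear.weight_quantized.numel()}")  # 16384 (uint8)
print(f"fp32 weight bytes:   {linear.weight.numel() * 4}")         # 131072
```

The per-group scales add a small overhead (one float per 128 weights), which is the usual trade-off for keeping the quantization error per group low.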
Speculative Decoding
```python
import random
from dataclasses import dataclass
from typing import Any


@dataclass
class SpeculativeConfig:
    """Speculative decoding configuration."""
    num_speculative_tokens: int = 4
    temperature: float = 0.0
    top_p: float = 1.0


class SpeculativeDecoder:
    """Speculative decoding: draft with a small model, verify with the target."""

    def __init__(
        self,
        target_model: Any,
        draft_model: Any,
        config: SpeculativeConfig
    ):
        self.target_model = target_model
        self.draft_model = draft_model
        self.config = config

    async def generate(
        self,
        prompt: str,
        max_tokens: int
    ) -> str:
        """Generate with speculative decoding."""
        generated_tokens = []
        current_prompt = prompt
        while len(generated_tokens) < max_tokens:
            # Draft phase: generate speculative tokens with the small model
            draft_tokens = await self._draft_phase(current_prompt)
            # Verify phase: check them against the target model
            accepted_tokens = await self._verify_phase(
                current_prompt,
                draft_tokens
            )
            # Update state
            generated_tokens.extend(accepted_tokens)
            current_prompt += "".join(accepted_tokens)
            # Check for EOS
            if any(t in ["<|endoftext|>", ""] for t in accepted_tokens):
                break
        return "".join(generated_tokens)

    async def _draft_phase(self, prompt: str) -> list[str]:
        """Generate draft tokens with the small model."""
        draft_tokens = []
        current = prompt
        for _ in range(self.config.num_speculative_tokens):
            token = await self.draft_model.generate_next_token(
                current,
                temperature=self.config.temperature
            )
            draft_tokens.append(token)
            current += token
        return draft_tokens

    async def _verify_phase(
        self,
        prompt: str,
        draft_tokens: list[str]
    ) -> list[str]:
        """Verify draft tokens with the target model."""
        # Get per-position token distributions from both models for the
        # full sequence (prompt plus draft continuation)
        full_sequence = prompt + "".join(draft_tokens)
        target_probs = await self.target_model.get_token_probabilities(
            full_sequence
        )
        draft_probs = await self.draft_model.get_token_probabilities(
            full_sequence
        )
        # The draft tokens occupy the last len(draft_tokens) positions
        offset = len(target_probs) - len(draft_tokens)
        accepted = []
        for i, token in enumerate(draft_tokens):
            # Probabilities each model assigns to this draft token
            p_target = target_probs[offset + i].get(token, 0)
            p_draft = draft_probs[offset + i].get(token, 0)
            # Acceptance criterion: accept with probability min(1, p_target / p_draft)
            if p_draft > 0:
                acceptance_prob = min(1, p_target / p_draft)
            else:
                acceptance_prob = 1 if p_target > 0 else 0
            if random.random() < acceptance_prob:
                accepted.append(token)
            else:
                # Rejected: sample a replacement from the adjusted distribution
                adjusted_token = self._sample_adjusted(
                    target_probs[offset + i],
                    draft_probs[offset + i]
                )
                accepted.append(adjusted_token)
                break  # Stop at the first rejection
        return accepted

    def _sample_adjusted(
        self,
        target_probs: dict[str, float],
        draft_probs: dict[str, float]
    ) -> str:
        """Sample from the adjusted distribution max(0, p_target - p_draft)."""
        adjusted = {}
        for token in target_probs:
            p_t = target_probs.get(token, 0)
            p_d = draft_probs.get(token, 0)
            adjusted[token] = max(0, p_t - p_d)
        # Normalize
        total = sum(adjusted.values())
        if total > 0:
            adjusted = {k: v / total for k, v in adjusted.items()}
        else:
            adjusted = target_probs
        # Sample
        r = random.random()
        cumsum = 0
        for token, prob in adjusted.items():
            cumsum += prob
            if r < cumsum:
                return token
        return list(adjusted.keys())[0]


class MedusaDecoder:
    """Medusa-style parallel decoding with multiple prediction heads."""

    def __init__(
        self,
        model: Any,
        num_heads: int = 4,
        num_candidates: int = 5
    ):
        self.model = model
        self.num_heads = num_heads
        self.num_candidates = num_candidates

    async def generate(
        self,
        prompt: str,
        max_tokens: int
    ) -> str:
        """Generate with Medusa decoding."""
        generated = []
        current = prompt
        while len(generated) < max_tokens:
            # Generate candidates from all heads
            candidates = await self._generate_candidates(current)
            # Verify candidates
            accepted = await self._verify_candidates(current, candidates)
            generated.extend(accepted)
            current += "".join(accepted)
            if not accepted:
                # Fall back to single-token decoding
                token = await self.model.generate_next_token(current)
                generated.append(token)
                current += token
        return "".join(generated)

    async def _generate_candidates(
        self,
        prompt: str
    ) -> list[list[str]]:
        """Generate candidate tokens from the Medusa heads."""
        # Each head predicts tokens at a different future position
        candidates = []
        for head_idx in range(self.num_heads):
            head_candidates = await self.model.medusa_head_predict(
                prompt,
                head_idx,
                self.num_candidates
            )
            candidates.append(head_candidates)
        return candidates

    async def _verify_candidates(
        self,
        prompt: str,
        candidates: list[list[str]]
    ) -> list[str]:
        """Verify candidate sequences against the base model."""
        # Build candidate trees
        trees = self._build_candidate_trees(candidates)
        # Verify each tree; keep the longest accepted continuation
        best_sequence = []
        for tree in trees:
            verified = await self.model.verify_sequence(prompt, tree)
            if len(verified) > len(best_sequence):
                best_sequence = verified
        return best_sequence

    def _build_candidate_trees(
        self,
        candidates: list[list[str]]
    ) -> list[list[str]]:
        """Build candidate trees from head predictions."""
        # Simplified: just return the top candidates from each head
        trees = []
        for head_candidates in candidates:
            for candidate in head_candidates[:self.num_candidates]:
                trees.append([candidate])
        return trees
```
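To show the accept/reject mechanics without real models, here is a toy sketch. ToyModel is a hypothetical character-level stand-in whose get_token_probabilities returns one distribution per position of the input sequence; the draft model disagrees with the target at a single character, so that draft token gets rejected and corrected during verification while the rest are accepted in bulk.

```python
import asyncio

class ToyModel:
    """Hypothetical character-level model that 'believes' in one fixed string."""
    def __init__(self, text: str):
        self.text = text

    async def generate_next_token(self, prompt: str, temperature: float = 0.0) -> str:
        pos = len(prompt)
        return self.text[pos] if pos < len(self.text) else "<|endoftext|>"

    async def get_token_probabilities(self, sequence: str) -> list[dict[str, float]]:
        # One distribution per position, all mass on the believed character
        return [
            {(self.text[p] if p < len(self.text) else "<|endoftext|>"): 1.0}
            for p in range(len(sequence))
        ]

async def demo():
    decoder = SpeculativeDecoder(
        target_model=ToyModel("hello world!"),
        draft_model=ToyModel("hello wurld!"),  # disagrees at one character
        config=SpeculativeConfig(num_speculative_tokens=4),
    )
    out = await decoder.generate("hel", max_tokens=9)
    print(repr(out))  # 'lo world!' (the draft's 'u' was rejected and replaced by 'o')

asyncio.run(demo())
```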
Production Inference Service
```python
import asyncio
import time
import uuid
from enum import Enum
from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

# Initialize components (wired up at startup in a real deployment)
batcher = None   # Would initialize DynamicBatcher
kv_cache = None  # Would initialize KVCacheManager


class OptimizationLevel(str, Enum):
    NONE = "none"
    BASIC = "basic"  # Batching only
    FULL = "full"    # Batching + KV cache + quantization


class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    stream: bool = False
    optimization: OptimizationLevel = OptimizationLevel.FULL


class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    latency_ms: float
    optimization_used: str


class BatchGenerateRequest(BaseModel):
    prompts: list[str]
    max_tokens: int = 100
    temperature: float = 0.7


class BatchGenerateResponse(BaseModel):
    results: list[GenerateResponse]
    total_latency_ms: float
    throughput_tokens_per_sec: float


@app.post("/v1/generate")
async def generate(request: GenerateRequest) -> GenerateResponse:
    """Generate text with optimization."""
    start = time.time()
    # Create inference request
    req = InferenceRequest(
        id=str(uuid.uuid4()),
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature
    )
    # Submit to batcher
    if batcher:
        result = await batcher.infer(req)
        text = result.text
        tokens = result.tokens_generated
    else:
        # Direct inference (placeholder)
        text = "Generated text placeholder"
        tokens = 10
    latency = (time.time() - start) * 1000
    return GenerateResponse(
        text=text,
        tokens_generated=tokens,
        latency_ms=latency,
        optimization_used=request.optimization.value
    )


@app.post("/v1/generate/batch")
async def batch_generate(request: BatchGenerateRequest) -> BatchGenerateResponse:
    """Generate text for a batch of prompts."""
    start = time.time()
    # Create requests
    requests = [
        InferenceRequest(
            id=str(uuid.uuid4()),
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )
        for prompt in request.prompts
    ]
    # Submit all
    if batcher:
        tasks = [batcher.infer(req) for req in requests]
        results = await asyncio.gather(*tasks)
    else:
        results = [
            InferenceResult(
                request_id=req.id,
                text="Placeholder",
                tokens_generated=10,
                latency_ms=100
            )
            for req in requests
        ]
    total_latency = (time.time() - start) * 1000
    total_tokens = sum(r.tokens_generated for r in results)
    throughput = total_tokens / (total_latency / 1000) if total_latency > 0 else 0
    return BatchGenerateResponse(
        results=[
            GenerateResponse(
                text=r.text,
                tokens_generated=r.tokens_generated,
                latency_ms=r.latency_ms,
                optimization_used="batch"
            )
            for r in results
        ],
        total_latency_ms=total_latency,
        throughput_tokens_per_sec=throughput
    )


class CacheStatsResponse(BaseModel):
    total_entries: int
    total_size_mb: float
    hit_rate: float


@app.get("/v1/cache/stats")
async def cache_stats() -> CacheStatsResponse:
    """Get KV cache statistics."""
    if kv_cache:
        return CacheStatsResponse(
            total_entries=len(kv_cache._caches),
            total_size_mb=kv_cache._current_size / 1e6,
            hit_rate=0.85  # Would track the actual hit rate
        )
    return CacheStatsResponse(
        total_entries=0,
        total_size_mb=0,
        hit_rate=0
    )


@app.delete("/v1/cache")
async def clear_cache():
    """Clear the KV cache."""
    if kv_cache:
        kv_cache._caches.clear()
        kv_cache._current_size = 0
    return {"status": "cleared"}


class ModelInfoResponse(BaseModel):
    model_name: str
    quantization: Optional[str]
    batch_size: int
    max_sequence_length: int


@app.get("/v1/model/info")
async def model_info() -> ModelInfoResponse:
    """Get model information."""
    return ModelInfoResponse(
        model_name="llama-7b",
        quantization="int4",
        batch_size=32,
        max_sequence_length=4096
    )


@app.get("/health")
async def health():
    return {"status": "healthy"}


@app.get("/metrics")
async def metrics():
    """Get inference metrics (placeholder values)."""
    return {
        "requests_total": 1000,
        "requests_per_second": 50,
        "avg_latency_ms": 150,
        "p99_latency_ms": 500,
        "tokens_per_second": 2000,
        "batch_utilization": 0.75,
        "cache_hit_rate": 0.85
    }
```
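Assuming the service above is saved as inference_service.py and started with uvicorn (the filename and port are placeholders), a client can exercise the endpoints like this:

```python
# Start the server first, e.g.:  uvicorn inference_service:app --port 8000
import requests

base = "http://localhost:8000"

single = requests.post(f"{base}/v1/generate", json={
    "prompt": "Explain KV caching in one sentence.",
    "max_tokens": 64,
    "optimization": "full",
}).json()
print(single["latency_ms"], single["text"])

batch = requests.post(f"{base}/v1/generate/batch", json={
    "prompts": ["What is dynamic batching?", "What is INT4 quantization?"],
    "max_tokens": 32,
}).json()
print(batch["throughput_tokens_per_sec"])

print(requests.get(f"{base}/v1/cache/stats").json())
```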
References
- vLLM: https://github.com/vllm-project/vllm
- PagedAttention: https://arxiv.org/abs/2309.06180
- Speculative Decoding: https://arxiv.org/abs/2211.17192
- GPTQ Quantization: https://arxiv.org/abs/2210.17323
- Medusa: https://arxiv.org/abs/2401.10774
Conclusion
Inference optimization determines whether LLM applications are economically viable at scale. Start with dynamic batching—grouping requests together maximizes GPU utilization and amortizes the cost of loading model weights. Implement KV caching to avoid recomputing attention for previously processed tokens; paged attention (vLLM-style) enables efficient memory management for variable-length sequences. Quantize models to INT8 or INT4 to reduce memory footprint and increase throughput, with minimal quality degradation for most tasks. Use speculative decoding with a small draft model to generate multiple tokens per forward pass of the large model. Continuous batching handles streaming scenarios where requests arrive and complete at different times. Monitor key metrics: tokens per second, batch utilization, cache hit rate, and latency percentiles. The key insight is that inference optimization is about maximizing useful computation per GPU cycle—every optimization technique reduces waste, whether it's redundant computation (KV cache), memory bandwidth (quantization), or idle time (batching). Combine these techniques for multiplicative improvements in throughput and cost efficiency.
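As a rough back-of-envelope illustration of the memory stakes, assuming a Llama-7B-class model (about 7B parameters, 32 layers, 32 heads, head_dim 128, fp16 baseline):

```python
params = 7e9
print(f"weights fp16: {params * 2 / 1e9:.1f} GB")    # ~14 GB
print(f"weights int4: {params * 0.5 / 1e9:.1f} GB")  # ~3.5 GB

layers, heads, head_dim, fp16_bytes = 32, 32, 128, 2
kv_per_token = 2 * layers * heads * head_dim * fp16_bytes  # K and V tensors
print(f"KV cache: {kv_per_token / 1024:.0f} KiB per token, "
      f"{kv_per_token * 2048 / 1e9:.2f} GB for a 2048-token sequence")
```

Quantizing the weights to INT4 frees roughly 10 GB on this class of model, which is often what makes room for the KV cache of a large batch on a single GPU.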