
LLM Routing and Load Balancing: Optimizing Cost and Performance Across Model Fleets

Introduction: LLM routing and load balancing are critical for building cost-effective, reliable AI systems at scale. Not every query needs GPT-4; many can be handled by smaller, faster, cheaper models with comparable quality for that task. Intelligent routing analyzes each incoming request and directs it to the most appropriate model based on complexity, cost constraints, latency requirements, and current system load. This guide covers the techniques that make LLM routing effective: request classification, model selection strategies, load balancing algorithms, fallback handling, and cost optimization. Whether you’re managing a fleet of models or optimizing a single provider’s offerings, these patterns will help you cut cost without giving up quality where it matters, while maintaining high availability.
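As a rough illustration of why this matters, the sketch below prices a single request with 1,000 input and 500 output tokens using the per-1K-token prices from the model registry later in this article, then computes the blended cost of a hypothetical split that sends 70% of traffic to claude-3-haiku and 30% to gpt-4-turbo. The token counts and the split are assumptions, and the savings only hold if the small model's answers are acceptable for the share of traffic it handles.

```python
# Prices in USD per 1K tokens (input, output), copied from the registry
# defaults shown later in this article. The traffic split is hypothetical.
PRICES = {
    "gpt-4-turbo": (0.01, 0.03),
    "claude-3-haiku": (0.00025, 0.00125),
}


def request_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Cost of one request: tokens / 1000 * price, for input and output."""
    price_in, price_out = PRICES[model]
    return input_tokens / 1000 * price_in + output_tokens / 1000 * price_out


def blended_cost(mix: dict[str, float], input_tokens: int, output_tokens: int) -> float:
    """Expected cost per request; `mix` maps model -> traffic share (sums to 1)."""
    return sum(
        share * request_cost(model, input_tokens, output_tokens)
        for model, share in mix.items()
    )


all_large = blended_cost({"gpt-4-turbo": 1.0}, 1000, 500)
routed = blended_cost({"gpt-4-turbo": 0.3, "claude-3-haiku": 0.7}, 1000, 500)
savings = 1 - routed / all_large

print(f"all gpt-4-turbo: ${all_large:.4f}/request")
print(f"70/30 routed:    ${routed:.4f}/request ({savings:.0%} cheaper)")
```

With these inputs a request costs $0.025 on gpt-4-turbo alone versus about $0.0081 blended, roughly 68% less. Running the same arithmetic with your own traffic mix and measured token counts is a quick way to size the opportunity before building a router.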

Figure: LLM Router: Request Classification, Model Selection, Load Balancing

Request Classification

from dataclasses import dataclass, field
from typing import Any, Optional, List
from enum import Enum
import re

class RequestComplexity(Enum):
    """Request complexity levels."""
    
    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"
    EXPERT = "expert"

class RequestType(Enum):
    """Types of LLM requests."""
    
    CLASSIFICATION = "classification"
    EXTRACTION = "extraction"
    SUMMARIZATION = "summarization"
    GENERATION = "generation"
    REASONING = "reasoning"
    CODE = "code"
    CONVERSATION = "conversation"

@dataclass
class ClassifiedRequest:
    """A classified LLM request."""
    
    prompt: str
    complexity: RequestComplexity
    request_type: RequestType
    estimated_tokens: int
    requires_reasoning: bool
    requires_creativity: bool
    metadata: dict = field(default_factory=dict)

class RequestClassifier:
    """Classify incoming LLM requests."""
    
    def __init__(self):
        # Patterns for request type detection
        self.type_patterns = {
            RequestType.CLASSIFICATION: [
                r"classify", r"categorize", r"is this", r"which type",
                r"sentiment", r"positive or negative"
            ],
            RequestType.EXTRACTION: [
                r"extract", r"find all", r"list the", r"what are the",
                r"identify", r"parse"
            ],
            RequestType.SUMMARIZATION: [
                r"summarize", r"summary", r"tldr", r"brief",
                r"main points", r"key takeaways"
            ],
            RequestType.GENERATION: [
                r"write", r"create", r"generate", r"compose",
                r"draft", r"produce"
            ],
            RequestType.REASONING: [
                r"why", r"explain", r"reason", r"analyze",
                r"compare", r"evaluate", r"think through"
            ],
            RequestType.CODE: [
                r"code", r"function", r"implement", r"debug",
                r"python", r"javascript", r"sql", r"algorithm"
            ],
            RequestType.CONVERSATION: [
                r"chat", r"talk", r"hello", r"hi",
                r"how are you", r"thanks"
            ]
        }
        
        # Complexity indicators
        self.complexity_indicators = {
            "simple": [
                r"yes or no", r"true or false", r"one word",
                r"simple", r"quick", r"brief"
            ],
            "complex": [
                r"step by step", r"detailed", r"comprehensive",
                r"in depth", r"thorough", r"analyze"
            ],
            "expert": [
                r"expert", r"advanced", r"technical",
                r"research", r"academic", r"professional"
            ]
        }
    
    def classify(self, prompt: str) -> ClassifiedRequest:
        """Classify a request."""
        
        prompt_lower = prompt.lower()
        
        # Detect request type
        request_type = self._detect_type(prompt_lower)
        
        # Estimate complexity
        complexity = self._estimate_complexity(prompt_lower, request_type)
        
        # Estimate tokens
        estimated_tokens = self._estimate_tokens(prompt)
        
        # Check for reasoning/creativity requirements
        requires_reasoning = self._requires_reasoning(prompt_lower, request_type)
        requires_creativity = self._requires_creativity(prompt_lower, request_type)
        
        return ClassifiedRequest(
            prompt=prompt,
            complexity=complexity,
            request_type=request_type,
            estimated_tokens=estimated_tokens,
            requires_reasoning=requires_reasoning,
            requires_creativity=requires_creativity
        )
    
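    def _detect_type(self, prompt: str) -> RequestType:
        """Detect request type by counting keyword hits per type."""
        
        scores = {}
        
        for req_type, patterns in self.type_patterns.items():
            # \b word boundaries stop short keywords such as "hi" or "code"
            # from matching inside longer words like "this" or "decode"
            score = sum(1 for p in patterns if re.search(rf"\b{p}\b", prompt))
            scores[req_type] = score
        
        # Default to GENERATION when no keyword matches
        if max(scores.values()) == 0:
            return RequestType.GENERATION
        
        return max(scores, key=scores.get)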
    
    def _estimate_complexity(
        self,
        prompt: str,
        request_type: RequestType
    ) -> RequestComplexity:
        """Estimate request complexity."""
        
        # Check explicit indicators, strongest level first, so a prompt that
        # mentions both "brief" and "expert" is not downgraded to SIMPLE
        level_order = [
            ("expert", RequestComplexity.EXPERT),
            ("complex", RequestComplexity.COMPLEX),
            ("simple", RequestComplexity.SIMPLE),
        ]
        for level, complexity in level_order:
            patterns = self.complexity_indicators[level]
            if any(re.search(rf"\b{p}\b", prompt) for p in patterns):
                return complexity
        
        # Heuristics based on prompt length and type
        word_count = len(prompt.split())
        
        if word_count < 20:
            base_complexity = RequestComplexity.SIMPLE
        elif word_count < 100:
            base_complexity = RequestComplexity.MODERATE
        else:
            base_complexity = RequestComplexity.COMPLEX
        
        # Adjust based on request type
        if request_type in [RequestType.REASONING, RequestType.CODE]:
            if base_complexity == RequestComplexity.SIMPLE:
                return RequestComplexity.MODERATE
            elif base_complexity == RequestComplexity.MODERATE:
                return RequestComplexity.COMPLEX
        
        return base_complexity
    
    def _estimate_tokens(self, prompt: str) -> int:
        """Estimate token count."""
        
        # Rough estimate: ~4 characters per token
        return len(prompt) // 4
    
    def _requires_reasoning(self, prompt: str, request_type: RequestType) -> bool:
        """Check if request requires reasoning."""
        
        if request_type == RequestType.REASONING:
            return True
        
        reasoning_keywords = [
            "why", "how", "explain", "reason", "because",
            "analyze", "compare", "evaluate", "think"
        ]
        
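        # Match at word starts so "how" does not fire inside "show", while
        # stems such as "reason" still match "reasoning"
        return any(re.search(rf"\b{kw}", prompt) for kw in reasoning_keywords)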
    
    def _requires_creativity(self, prompt: str, request_type: RequestType) -> bool:
        """Check if request requires creativity."""
        
        if request_type == RequestType.GENERATION:
            return True
        
        creativity_keywords = [
            "creative", "original", "unique", "innovative",
            "story", "poem", "imagine", "brainstorm"
        ]
        
        # Match at word starts so "story" does not fire inside "history"
        return any(re.search(rf"\b{kw}", prompt) for kw in creativity_keywords)

class MLClassifier:
    """ML-based request classifier."""
    
    def __init__(self, model_path: Optional[str] = None):
        self.model = None
        self.tokenizer = None
        # Rule-based classifier used until a trained model is loaded
        self._fallback = RequestClassifier()
        
        if model_path:
            self._load_model(model_path)
    
    def _load_model(self, path: str):
        """Load classification model."""
        
        # Placeholder: load a trained model exposing the predict_* methods used below
        pass
    
    def classify(self, prompt: str) -> ClassifiedRequest:
        """Classify using ML model."""
        
        if self.model is None:
            # Fall back to the rule-based classifier
            return self._fallback.classify(prompt)
        
        # Get model predictions (predict_* is the assumed interface of the
        # placeholder model, not a specific library's API)
        features = self._extract_features(prompt)
        
        complexity_pred = self.model.predict_complexity(features)
        type_pred = self.model.predict_type(features)
        
        return ClassifiedRequest(
            prompt=prompt,
            complexity=RequestComplexity(complexity_pred),
            request_type=RequestType(type_pred),
            estimated_tokens=len(prompt) // 4,
            requires_reasoning=self.model.predict_reasoning(features),
            requires_creativity=self.model.predict_creativity(features)
        )
    
    def _extract_features(self, prompt: str) -> dict:
        """Extract features for classification."""
        
        return {
            "length": len(prompt),
            "word_count": len(prompt.split()),
            "question_marks": prompt.count("?"),
            "has_code": bool(re.search(r"```|def |function |class ", prompt))
        }
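Keyword rules like these are sensitive to how patterns are matched. A plain substring test, or re.search with an unanchored pattern, lets short keywords fire inside unrelated words: "hi" inside "this", "code" inside "decoder", "how" inside "show", "story" inside "history". The self-contained sketch below contrasts that with word-boundary matching; the keyword list and prompts are made up for illustration.

```python
import re

# Hypothetical keyword list, chosen to expose substring false positives
KEYWORDS = ["hi", "code", "how", "story"]


def substring_hits(prompt: str) -> list[str]:
    """Naive matching: a keyword counts if it appears anywhere in the text."""
    text = prompt.lower()
    return [kw for kw in KEYWORDS if kw in text]


def word_hits(prompt: str) -> list[str]:
    """Whole-word matching using regex word boundaries around the escaped keyword."""
    text = prompt.lower()
    return [kw for kw in KEYWORDS if re.search(rf"\b{re.escape(kw)}\b", text)]


print(substring_hits("Show me the history of this decoder"))
print(word_hits("Show me the history of this decoder"))
print(word_hits("Hi, how do I write this code?"))
```

Running it prints ['hi', 'code', 'how', 'story'], then [], then ['hi', 'code', 'how']. Whole-word matching trades some recall for precision: a pattern for "explain" no longer matches "explanation", so inflected forms need their own entries, or a word-start-only pattern (\b at the front only) when a stem should match its longer forms.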

Model Selection

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelConfig:
    """Configuration for an LLM model."""
    
    name: str
    provider: str
    cost_per_1k_input: float   # USD per 1K input tokens
    cost_per_1k_output: float  # USD per 1K output tokens
    max_tokens: int            # Context window size in tokens
    latency_ms: float          # Average latency
    quality_score: float       # 0-1 quality rating
    capabilities: set[str] = field(default_factory=set)

class ModelRegistry:
    """Registry of available models."""
    
    def __init__(self):
        self.models: dict[str, ModelConfig] = {}
        self._register_defaults()
    
    def _register_defaults(self):
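        """Register default models.
        
        Prices are USD per 1K tokens (provider list prices circa 2024).
        Treat latency_ms and quality_score as rough placeholders to replace
        with your own measurements.
        """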
        
        self.register(ModelConfig(
            name="gpt-4-turbo",
            provider="openai",
            cost_per_1k_input=0.01,
            cost_per_1k_output=0.03,
            max_tokens=128000,
            latency_ms=2000,
            quality_score=0.95,
            capabilities={"reasoning", "code", "creativity", "vision"}
        ))
        
        self.register(ModelConfig(
            name="gpt-3.5-turbo",
            provider="openai",
            cost_per_1k_input=0.0005,
            cost_per_1k_output=0.0015,
            max_tokens=16385,
            latency_ms=500,
            quality_score=0.80,
            capabilities={"general", "code", "conversation"}
        ))
        
        self.register(ModelConfig(
            name="claude-3-opus",
            provider="anthropic",
            cost_per_1k_input=0.015,
            cost_per_1k_output=0.075,
            max_tokens=200000,
            latency_ms=3000,
            quality_score=0.97,
            capabilities={"reasoning", "code", "creativity", "analysis"}
        ))
        
        self.register(ModelConfig(
            name="claude-3-sonnet",
            provider="anthropic",
            cost_per_1k_input=0.003,
            cost_per_1k_output=0.015,
            max_tokens=200000,
            latency_ms=1500,
            quality_score=0.90,
            capabilities={"reasoning", "code", "general"}
        ))
        
        self.register(ModelConfig(
            name="claude-3-haiku",
            provider="anthropic",
            cost_per_1k_input=0.00025,
            cost_per_1k_output=0.00125,
            max_tokens=200000,
            latency_ms=300,
            quality_score=0.75,
            capabilities={"general", "conversation", "extraction"}
        ))
        
        self.register(ModelConfig(
            name="llama-3-70b",
            provider="together",
            cost_per_1k_input=0.0009,
            cost_per_1k_output=0.0009,
            max_tokens=8192,
            latency_ms=800,
            quality_score=0.85,
            capabilities={"general", "code", "reasoning"}
        ))
        
        self.register(ModelConfig(
            name="mixtral-8x7b",
            provider="together",
            cost_per_1k_input=0.0006,
            cost_per_1k_output=0.0006,
            max_tokens=32768,
            latency_ms=600,
            quality_score=0.82,
            capabilities={"general", "code", "multilingual"}
        ))
    
    def register(self, config: ModelConfig):
        """Register a model."""
        
        self.models[config.name] = config
    
    def get(self, name: str) -> Optional[ModelConfig]:
        """Get model by name."""
        
        return self.models.get(name)
    
    def list_by_capability(self, capability: str) -> list[ModelConfig]:
        """List models with a capability."""
        
        return [
            m for m in self.models.values()
            if capability in m.capabilities
        ]

class ModelSelector:
    """Select optimal model for a request."""
    
    def __init__(self, registry: ModelRegistry = None):
        self.registry = registry or ModelRegistry()
    
    def select(
        self,
        request: ClassifiedRequest,
        constraints: dict = None
    ) -> ModelConfig:
        """Select best model for request."""
        
        constraints = constraints or {}
        
        # Get candidate models
        candidates = self._get_candidates(request, constraints)
        
        if not candidates:
            # Fall back to a default model when no candidate passes the filters
            return self.registry.get("gpt-3.5-turbo")
        
        # Return the highest-scoring candidate; max() with a key avoids the
        # TypeError that sorting (score, ModelConfig) tuples raises on a tie
        return max(
            candidates,
            key=lambda m: self._score_model(m, request, constraints)
        )
    
    def _get_candidates(
        self,
        request: ClassifiedRequest,
        constraints: dict
    ) -> list[ModelConfig]:
        """Get candidate models."""
        
        candidates = list(self.registry.models.values())
        
        # Filter by max_cost, treated as a cap on price per 1K input tokens
        if "max_cost" in constraints:
            max_cost = constraints["max_cost"]
            candidates = [
                m for m in candidates
                if m.cost_per_1k_input <= max_cost
            ]
        
        # Filter by max latency
        if "max_latency_ms" in constraints:
            max_latency = constraints["max_latency_ms"]
            candidates = [
                m for m in candidates
                if m.latency_ms <= max_latency
            ]
        
        # Filter by required capabilities
        required_caps = set()
        
        if request.requires_reasoning:
            required_caps.add("reasoning")
        if request.requires_creativity:
            required_caps.add("creativity")
        if request.request_type == RequestType.CODE:
            required_caps.add("code")
        
        if required_caps:
            candidates = [
                m for m in candidates
                if required_caps.issubset(m.capabilities)
            ]
        
        # Filter by context window, leaving room for an output about as long as the prompt
        if request.estimated_tokens > 0:
            candidates = [
                m for m in candidates
                if m.max_tokens >= request.estimated_tokens * 2
            ]
        
        return candidates
    
    def _score_model(
        self,
        model: ModelConfig,
        request: ClassifiedRequest,
        constraints: dict
    ) -> float:
        """Score a model for the request."""
        
        # Base score from quality
        score = model.quality_score * 100
        
        # Adjust for complexity match
        complexity_scores = {
            RequestComplexity.SIMPLE: 0.7,
            RequestComplexity.MODERATE: 0.85,
            RequestComplexity.COMPLEX: 0.95,
            RequestComplexity.EXPERT: 1.0
        }
        
        required_quality = complexity_scores[request.complexity]
        
        if model.quality_score >= required_quality:
            # Bonus for meeting quality threshold
            score += 10
        else:
            # Penalty for insufficient quality
            score -= 20
        
        # Cost efficiency (lower is better)
        cost_factor = 1 / (1 + model.cost_per_1k_input * 100)
        score += cost_factor * 20
        
        # Latency factor (lower is better)
        latency_factor = 1 / (1 + model.latency_ms / 1000)
        score += latency_factor * 15
        
        # Capability match bonus
        if request.requires_reasoning and "reasoning" in model.capabilities:
            score += 5
        if request.requires_creativity and "creativity" in model.capabilities:
            score += 5
        
        return score

class CostOptimizedSelector(ModelSelector):
    """Selector optimized for cost."""
    
    def _score_model(
        self,
        model: ModelConfig,
        request: ClassifiedRequest,
        constraints: dict
    ) -> float:
        """Score with heavy cost weighting."""
        
        base_score = super()._score_model(model, request, constraints)
        
        # Additional cost penalty
        cost_penalty = model.cost_per_1k_input * 1000
        
        return base_score - cost_penalty

class QualityOptimizedSelector(ModelSelector):
    """Selector optimized for quality."""
    
    def _score_model(
        self,
        model: ModelConfig,
        request: ClassifiedRequest,
        constraints: dict
    ) -> float:
        """Score with heavy quality weighting."""
        
        base_score = super()._score_model(model, request, constraints)
        
        # Additional quality bonus
        quality_bonus = model.quality_score * 50
        
        return base_score + quality_bonus
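The weighted score above folds quality, cost, and latency into one number, which makes it hard to predict which model wins. A simpler policy, sketched below as an alternative rather than the article's method, is to pick the cheapest model whose quality_score clears the complexity threshold, estimating cost from both input tokens and an expected output length (the selectors above price candidates on input tokens only). The thresholds mirror complexity_scores in ModelSelector._score_model and the fleet copies the registry defaults; the Candidate type and the output-length estimate are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Candidate:
    name: str
    cost_per_1k_input: float   # USD per 1K input tokens
    cost_per_1k_output: float  # USD per 1K output tokens
    quality_score: float       # 0-1, illustrative
    latency_ms: float


# Minimum quality per complexity level, mirroring ModelSelector._score_model
MIN_QUALITY = {"simple": 0.70, "moderate": 0.85, "complex": 0.95, "expert": 1.0}


def estimated_cost(c: Candidate, input_tokens: int, output_tokens: int) -> float:
    return (input_tokens / 1000 * c.cost_per_1k_input
            + output_tokens / 1000 * c.cost_per_1k_output)


def cheapest_adequate(
    candidates: list[Candidate],
    complexity: str,
    input_tokens: int,
    output_tokens: int,
    max_latency_ms: Optional[float] = None,
) -> Optional[Candidate]:
    """Cheapest candidate meeting the quality bar (and latency cap, if given)."""
    bar = MIN_QUALITY[complexity]
    eligible = [
        c for c in candidates
        if c.quality_score >= bar
        and (max_latency_ms is None or c.latency_ms <= max_latency_ms)
    ]
    if not eligible:
        return None  # caller decides: relax constraints or escalate
    # Ties on cost go to the higher-quality model
    return min(eligible, key=lambda c: (estimated_cost(c, input_tokens, output_tokens),
                                        -c.quality_score))


FLEET = [
    Candidate("gpt-4-turbo", 0.01, 0.03, 0.95, 2000),
    Candidate("gpt-3.5-turbo", 0.0005, 0.0015, 0.80, 500),
    Candidate("claude-3-opus", 0.015, 0.075, 0.97, 3000),
    Candidate("claude-3-sonnet", 0.003, 0.015, 0.90, 1500),
    Candidate("claude-3-haiku", 0.00025, 0.00125, 0.75, 300),
    Candidate("llama-3-70b", 0.0009, 0.0009, 0.85, 800),
    Candidate("mixtral-8x7b", 0.0006, 0.0006, 0.82, 600),
]

for level in ("simple", "moderate", "complex", "expert"):
    choice = cheapest_adequate(FLEET, level, input_tokens=1000, output_tokens=500)
    print(level, "->", choice.name if choice else None)
```

With 1,000 input and 500 output tokens this picks claude-3-haiku for simple, llama-3-70b for moderate, and gpt-4-turbo for complex, and returns None for expert because no registry model reaches a quality_score of 1.0. The weighted scorer above handles that case by penalizing every model equally.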

Load Balancing

from dataclasses import dataclass, field
from typing import Any, Optional, List
from datetime import datetime, timedelta
import random
import asyncio

@dataclass
class EndpointHealth:
    """Health status of an endpoint."""
    
    endpoint: str
    is_healthy: bool = True
    last_check: datetime = field(default_factory=datetime.now)
    consecutive_failures: int = 0
    latency_ms: float = 0
    requests_per_minute: int = 0
    error_rate: float = 0

class LoadBalancer:
    """Base load balancer."""
    
    def __init__(self, endpoints: list[str]):
        self.endpoints = endpoints
        self.health: dict[str, EndpointHealth] = {
            ep: EndpointHealth(endpoint=ep) for ep in endpoints
        }
    
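    # Seconds after its last failure before an unhealthy endpoint is retried
    recovery_after_s: float = 30.0
    
    def select(self) -> str:
        """Select an endpoint."""
        
        now = datetime.now()
        # Unhealthy endpoints become eligible again after recovery_after_s, so
        # a later report_success() can restore them; without this they would
        # only ever be retried when every endpoint is unhealthy
        healthy = [
            ep for ep in self.endpoints
            if self.health[ep].is_healthy
            or (now - self.health[ep].last_check).total_seconds() > self.recovery_after_s
        ]
        
        if not healthy:
            # All unhealthy, try any
            return random.choice(self.endpoints)
        
        return self._select_from_healthy(healthy)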
    
    def _select_from_healthy(self, healthy: list[str]) -> str:
        """Select from healthy endpoints."""
        
        raise NotImplementedError
    
    def report_success(self, endpoint: str, latency_ms: float):
        """Report successful request."""
        
        health = self.health[endpoint]
        health.is_healthy = True
        health.consecutive_failures = 0
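        # Running average (each new sample weighted 0.5), seeded with the
        # first sample instead of the initial 0
        if health.latency_ms == 0:
            health.latency_ms = latency_ms
        else:
            health.latency_ms = (health.latency_ms + latency_ms) / 2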
        health.last_check = datetime.now()
    
    def report_failure(self, endpoint: str):
        """Report failed request."""
        
        health = self.health[endpoint]
        health.consecutive_failures += 1
        health.last_check = datetime.now()
        
        if health.consecutive_failures >= 3:
            health.is_healthy = False

class RoundRobinBalancer(LoadBalancer):
    """Round-robin load balancing."""
    
    def __init__(self, endpoints: list[str]):
        super().__init__(endpoints)
        self.current_index = 0
    
    def _select_from_healthy(self, healthy: list[str]) -> str:
        """Select next endpoint in rotation."""
        
        endpoint = healthy[self.current_index % len(healthy)]
        self.current_index += 1
        return endpoint

class WeightedBalancer(LoadBalancer):
    """Weighted load balancing."""
    
    def __init__(self, endpoints: list[str], weights: dict[str, float] = None):
        super().__init__(endpoints)
        # Endpoints without an explicit weight default to 1.0
        self.weights = {ep: (weights or {}).get(ep, 1.0) for ep in endpoints}
    
    def _select_from_healthy(self, healthy: list[str]) -> str:
        """Select based on weights."""
        
        total_weight = sum(self.weights[ep] for ep in healthy)
        
        r = random.uniform(0, total_weight)
        cumulative = 0
        
        for ep in healthy:
            cumulative += self.weights[ep]
            if r <= cumulative:
                return ep
        
        return healthy[-1]

class LeastConnectionsBalancer(LoadBalancer):
    """Least connections load balancing."""
    
    def __init__(self, endpoints: list[str]):
        super().__init__(endpoints)
        self.active_connections: dict[str, int] = {ep: 0 for ep in endpoints}
    
    def _select_from_healthy(self, healthy: list[str]) -> str:
        """Select endpoint with fewest connections."""
        
        return min(healthy, key=lambda ep: self.active_connections[ep])
    
    def acquire(self, endpoint: str):
        """Mark connection as active."""
        
        self.active_connections[endpoint] += 1
    
    def release(self, endpoint: str):
        """Mark connection as released."""
        
        self.active_connections[endpoint] = max(
            0, self.active_connections[endpoint] - 1
        )

class LatencyBasedBalancer(LoadBalancer):
    """Latency-based load balancing."""
    
    def __init__(self, endpoints: list[str]):
        super().__init__(endpoints)
        self.latencies: dict[str, list[float]] = {ep: [] for ep in endpoints}
    
    def _select_from_healthy(self, healthy: list[str]) -> str:
        """Select endpoint with lowest latency."""
        
        avg_latencies = {}
        
        for ep in healthy:
            if self.latencies[ep]:
                # Use recent latencies
                recent = self.latencies[ep][-10:]
                avg_latencies[ep] = sum(recent) / len(recent)
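            else:
                # Untried endpoints count as 0 ms so each is sampled at least
                # once; with inf they would never be picked once a peer has data
                avg_latencies[ep] = 0.0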
        
        return min(healthy, key=lambda ep: avg_latencies[ep])
    
    def report_success(self, endpoint: str, latency_ms: float):
        """Record latency."""
        
        super().report_success(endpoint, latency_ms)
        
        self.latencies[endpoint].append(latency_ms)
        
        # Keep only recent
        if len(self.latencies[endpoint]) > 100:
            self.latencies[endpoint] = self.latencies[endpoint][-100:]

class AdaptiveBalancer(LoadBalancer):
    """Adaptive load balancing based on multiple factors."""
    
    def __init__(self, endpoints: list[str]):
        super().__init__(endpoints)
        self.metrics: dict[str, dict] = {
            ep: {
                "latencies": [],
                "errors": 0,
                "successes": 0,
                "last_error": None
            }
            for ep in endpoints
        }
    
    def _select_from_healthy(self, healthy: list[str]) -> str:
        """Select based on adaptive scoring."""
        
        scores = {}
        
        for ep in healthy:
            scores[ep] = self._compute_score(ep)
        
        # Weighted random selection based on scores
        total = sum(scores.values())
        
        if total == 0:
            return random.choice(healthy)
        
        r = random.uniform(0, total)
        cumulative = 0
        
        for ep, score in scores.items():
            cumulative += score
            if r <= cumulative:
                return ep
        
        return healthy[-1]
    
    def _compute_score(self, endpoint: str) -> float:
        """Compute endpoint score."""
        
        metrics = self.metrics[endpoint]
        
        # Base score
        score = 100
        
        # Latency factor
        if metrics["latencies"]:
            recent = metrics["latencies"][-10:]
            avg_latency = sum(recent) / len(recent)
            score -= avg_latency / 100  # Penalty for high latency
        
        # Error rate factor
        total = metrics["errors"] + metrics["successes"]
        if total > 0:
            error_rate = metrics["errors"] / total
            score -= error_rate * 50  # Heavy penalty for errors
        
        # Recent error penalty
        if metrics["last_error"]:
            time_since_error = (datetime.now() - metrics["last_error"]).total_seconds()
            if time_since_error < 60:
                score -= 20  # Recent error penalty
        
        return max(score, 1)  # Minimum score of 1
    
    def report_success(self, endpoint: str, latency_ms: float):
        """Record success."""
        
        super().report_success(endpoint, latency_ms)
        
        self.metrics[endpoint]["latencies"].append(latency_ms)
        self.metrics[endpoint]["successes"] += 1
        
        # Trim latencies
        if len(self.metrics[endpoint]["latencies"]) > 100:
            self.metrics[endpoint]["latencies"] = self.metrics[endpoint]["latencies"][-100:]
    
    def report_failure(self, endpoint: str):
        """Record failure."""
        
        super().report_failure(endpoint)
        
        self.metrics[endpoint]["errors"] += 1
        self.metrics[endpoint]["last_error"] = datetime.now()
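WeightedBalancer draws endpoints at random in proportion to their weights, so over short windows the observed split can drift from the target ratio. A deterministic alternative is smooth weighted round-robin, the scheme behind nginx's weighted round-robin upstream balancing: on each pick every endpoint's counter grows by its weight, the largest counter wins, and the winner's counter drops by the total weight. The standalone class below is a sketch of that technique, not part of the article's LoadBalancer hierarchy, and the endpoint names and weights are illustrative.

```python
from typing import Optional


class SmoothWeightedRoundRobin:
    """Deterministic weighted rotation that spreads picks evenly over time."""

    def __init__(self, weights: dict[str, int]):
        if not weights or any(w <= 0 for w in weights.values()):
            raise ValueError("weights must map endpoints to positive integers")
        self.weights = dict(weights)
        self.current = {ep: 0 for ep in weights}

    def select(self, healthy: Optional[list[str]] = None) -> str:
        """Pick the next endpoint, optionally restricted to a healthy subset.

        An empty or missing subset falls back to all endpoints, mirroring
        LoadBalancer.select's "all unhealthy, try any" behavior.
        """
        pool = healthy if healthy else list(self.weights)
        total = sum(self.weights[ep] for ep in pool)
        for ep in pool:
            self.current[ep] += self.weights[ep]
        # max() returns the first endpoint when counters tie
        best = max(pool, key=lambda ep: self.current[ep])
        self.current[best] -= total
        return best


rr = SmoothWeightedRoundRobin({"a": 5, "b": 1, "c": 1})
print([rr.select() for _ in range(7)])
```

With weights 5:1:1 the seven picks come out as a, a, b, a, c, a, a, after which every counter is back at zero and the cycle repeats. The weight-1 endpoints are interleaved with the heavy one rather than served in a burst.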

Fallback and Retry

from dataclasses import dataclass
from typing import Any, Optional, Callable
import asyncio
import time

@dataclass
class RetryConfig:
    """Retry configuration."""
    
    max_retries: int = 3
    base_delay_ms: int = 100
    max_delay_ms: int = 5000
    exponential_base: float = 2.0
    jitter: bool = True

class RetryHandler:
    """Handle retries with exponential backoff."""
    
    def __init__(self, config: RetryConfig = None):
        self.config = config or RetryConfig()
    
    async def execute(
        self,
        func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """Execute with retries."""
        
        last_error = None
        
        for attempt in range(self.config.max_retries + 1):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                last_error = e
                
                if attempt < self.config.max_retries:
                    delay = self._calculate_delay(attempt)
                    await asyncio.sleep(delay / 1000)
        
        raise last_error
    
    def _calculate_delay(self, attempt: int) -> float:
        """Calculate delay for attempt."""
        
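        delay = self.config.base_delay_ms * (
            self.config.exponential_base ** attempt
        )
        
        if self.config.jitter:
            import random
            # Spread retries by +/-50% so clients that failed together do
            # not retry in lockstep
            delay *= random.uniform(0.5, 1.5)
        
        # Cap after jitter so max_delay_ms is a hard upper bound
        return min(delay, self.config.max_delay_ms)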

class FallbackChain:
    """Chain of fallback models."""
    
    def __init__(self, models: list[str]):
        self.models = models
        self.retry_handler = RetryHandler()
    
    async def execute(
        self,
        request: dict,
        call_model: Callable
    ) -> Any:
        """Execute with fallbacks."""
        
        errors = []
        
        for model in self.models:
            try:
                return await self.retry_handler.execute(
                    call_model, model, request
                )
            except Exception as e:
                errors.append((model, e))
                continue
        
        # All models failed
        raise Exception(f"All models failed: {errors}")

class CircuitBreaker:
    """Circuit breaker for model calls."""
    
    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 30
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        
        self.failures: dict[str, int] = {}
        self.last_failure: dict[str, float] = {}
        self.state: dict[str, str] = {}  # closed, open, half-open
    
    def can_execute(self, model: str) -> bool:
        """Check if model can be called."""
        
        state = self.state.get(model, "closed")
        
        if state == "closed":
            return True
        
        if state == "open":
            # Check if recovery timeout passed
            last = self.last_failure.get(model, 0)
            if time.time() - last > self.recovery_timeout:
                self.state[model] = "half-open"
                return True
            return False
        
        if state == "half-open":
            return True
        
        return True
    
    def record_success(self, model: str):
        """Record successful call."""
        
        self.failures[model] = 0
        self.state[model] = "closed"
    
    def record_failure(self, model: str):
        """Record failed call."""
        
        self.failures[model] = self.failures.get(model, 0) + 1
        self.last_failure[model] = time.time()
        
        if self.failures[model] >= self.failure_threshold:
            self.state[model] = "open"

class RateLimiter:
    """Rate limiter for model calls."""
    
    def __init__(self, requests_per_minute: int = 60):
        self.rpm = requests_per_minute
        self.requests: dict[str, list[float]] = {}
    
    def can_execute(self, model: str) -> bool:
        """Check if within rate limit."""
        
        now = time.time()
        
        if model not in self.requests:
            self.requests[model] = []
        
        # Remove old requests
        self.requests[model] = [
            t for t in self.requests[model]
            if now - t < 60
        ]
        
        return len(self.requests[model]) < self.rpm
    
    def record_request(self, model: str):
        """Record a request."""
        
        if model not in self.requests:
            self.requests[model] = []
        
        self.requests[model].append(time.time())
    
    async def wait_if_needed(self, model: str):
        """Wait if rate limited."""
        
        while not self.can_execute(model):
            await asyncio.sleep(0.1)
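FallbackChain and CircuitBreaker above are independent pieces, and the route handler below makes a single attempt. One way to combine them is to walk an ordered model list, skip models whose circuit is open, record each outcome, and report how many models were passed over. The sketch uses a simplified stand-in breaker so it runs on its own; SimpleBreaker, call_with_fallback, and demo_call are illustrative names, not part of the article's code.

```python
import asyncio
import time
from typing import Awaitable, Callable


class SimpleBreaker:
    """Per-model breaker: opens after `threshold` consecutive failures and
    allows a probe again once `cooldown_s` seconds have passed."""

    def __init__(self, threshold: int = 3, cooldown_s: float = 30.0):
        self.threshold = threshold
        self.cooldown_s = cooldown_s
        self.failures: dict[str, int] = {}
        self.opened_at: dict[str, float] = {}

    def allows(self, model: str) -> bool:
        opened = self.opened_at.get(model)
        return opened is None or time.monotonic() - opened >= self.cooldown_s

    def record_success(self, model: str) -> None:
        self.failures[model] = 0
        self.opened_at.pop(model, None)

    def record_failure(self, model: str) -> None:
        self.failures[model] = self.failures.get(model, 0) + 1
        if self.failures[model] >= self.threshold:
            self.opened_at[model] = time.monotonic()


async def call_with_fallback(
    models: list[str],
    prompt: str,
    call: Callable[[str, str], Awaitable[str]],
    breaker: SimpleBreaker,
) -> tuple[str, str, int]:
    """Return (model, response, fallbacks_used) from the first model that answers."""
    errors: list[tuple[str, str]] = []
    for position, model in enumerate(models):
        if not breaker.allows(model):
            errors.append((model, "circuit open"))
            continue
        try:
            response = await call(model, prompt)
        except Exception as exc:  # production code would catch narrower errors
            breaker.record_failure(model)
            errors.append((model, repr(exc)))
            continue
        breaker.record_success(model)
        # fallbacks_used = number of models skipped or failed before this one
        return model, response, position
    raise RuntimeError(f"all models failed or were skipped: {errors}")


async def demo_call(model: str, prompt: str) -> str:
    """Hypothetical model call in which the primary always times out."""
    if model == "primary":
        raise TimeoutError("upstream timeout")
    return f"{model} answered: {prompt}"


breaker = SimpleBreaker(threshold=1)
print(asyncio.run(call_with_fallback(["primary", "backup"], "hello", demo_call, breaker)))
print(breaker.allows("primary"))
```

Whether models skipped because their circuit is open should count toward fallbacks_used is a reporting choice; this sketch counts them, so the value reads as the position of the model that finally answered.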

Production Router Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
import time
import asyncio

app = FastAPI()

class RoutingRequest(BaseModel):
    prompt: str
    max_cost: Optional[float] = None
    max_latency_ms: Optional[int] = None
    preferred_models: Optional[list[str]] = None

class RoutingResponse(BaseModel):
    selected_model: str
    response: str
    latency_ms: float
    cost: float
    fallbacks_used: int

class RouterMetrics:
    """Track routing metrics."""
    
    def __init__(self):
        self.requests = 0
        self.model_usage: dict[str, int] = {}
        self.total_cost = 0.0
        self.total_latency = 0.0
        self.fallbacks = 0
    
    def record(
        self,
        model: str,
        cost: float,
        latency: float,
        fallbacks: int
    ):
        self.requests += 1
        self.model_usage[model] = self.model_usage.get(model, 0) + 1
        self.total_cost += cost
        self.total_latency += latency
        self.fallbacks += fallbacks
    
    def stats(self) -> dict:
        return {
            "total_requests": self.requests,
            "model_distribution": self.model_usage,
            "total_cost": self.total_cost,
            "avg_latency_ms": self.total_latency / self.requests if self.requests > 0 else 0,
            "fallback_rate": self.fallbacks / self.requests if self.requests > 0 else 0
        }

metrics = RouterMetrics()

# Initialize components (classes defined in the earlier sections of this article)
classifier = RequestClassifier()
registry = ModelRegistry()
selector = ModelSelector(registry)
balancer = AdaptiveBalancer(["endpoint1", "endpoint2", "endpoint3"])  # placeholder endpoints; not used by the handler below
circuit_breaker = CircuitBreaker()
rate_limiter = RateLimiter()

async def call_model(model: str, prompt: str) -> tuple[str, float]:
    """Call LLM model (mock implementation)."""
    
    await asyncio.sleep(0.1)  # Simulate latency
    return f"Response from {model}", 0.001

@app.post("/v1/route")
async def route_request(request: RoutingRequest) -> RoutingResponse:
    """Route request to optimal model."""
    
    start = time.time()
    
    # Classify request
    classified = classifier.classify(request.prompt)
    
    # Build constraints
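    constraints = {}
    # Compare against None so an explicit 0 is still passed through as a constraint
    if request.max_cost is not None:
        constraints["max_cost"] = request.max_cost
    if request.max_latency_ms is not None:
        constraints["max_latency_ms"] = request.max_latency_ms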
    
    # Select model
    model_config = selector.select(classified, constraints)
    
    # Check circuit breaker
    if not circuit_breaker.can_execute(model_config.name):
        raise HTTPException(status_code=503, detail="Model circuit open")
    
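    # Check rate limit and reserve the slot right away (no await in between),
    # so concurrent requests cannot all pass the same check before any is counted
    await rate_limiter.wait_if_needed(model_config.name)
    rate_limiter.record_request(model_config.name)
    
    # Call model (single attempt: no fallback chain is wired in here, so fallbacks_used stays 0)
    fallbacks_used = 0
    
    try:
        response, cost = await call_model(model_config.name, request.prompt)
        circuit_breaker.record_success(model_config.name)
    except Exception as e:
        circuit_breaker.record_failure(model_config.name)
        raise HTTPException(status_code=500, detail=str(e)) from e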
    
    latency = (time.time() - start) * 1000
    
    # Record metrics
    metrics.record(model_config.name, cost, latency, fallbacks_used)
    
    return RoutingResponse(
        selected_model=model_config.name,
        response=response,
        latency_ms=latency,
        cost=cost,
        fallbacks_used=fallbacks_used
    )
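The handler above makes a single attempt, so it never moves past the selected model. Below is a minimal sketch of a fallback chain that tries models in order and reports how far down the chain it had to go. `call_with_fallbacks`, `AllModelsFailed`, and the ordered `chain` of model names are hypothetical. In the service you might pass `call_model` as `call` and `circuit_breaker.can_execute` as `is_available`, then record successes and failures inside the loop as the handler does.

```python
import asyncio
from typing import Awaitable, Callable


class AllModelsFailed(Exception):
    """Raised when every model in the chain was skipped or failed."""

    def __init__(self, errors: dict[str, str]):
        super().__init__(f"all models failed: {errors}")
        self.errors = errors


async def call_with_fallbacks(
    chain: list[str],
    call: Callable[[str, str], Awaitable[tuple[str, float]]],
    prompt: str,
    is_available: Callable[[str], bool] = lambda model: True,
) -> tuple[str, str, float, int]:
    """Try models in order; return (model, response, cost, fallbacks_used).

    fallbacks_used counts the chain entries that were skipped or failed
    before the one that answered.
    """
    errors: dict[str, str] = {}
    for position, model in enumerate(chain):
        if not is_available(model):  # e.g. circuit breaker open
            errors[model] = "unavailable"
            continue
        try:
            response, cost = await call(model, prompt)
        except Exception as exc:  # remember the error and try the next model
            errors[model] = str(exc)
            continue
        return model, response, cost, position
    raise AllModelsFailed(errors)
```

Raising a dedicated exception once the chain is exhausted lets the endpoint map it to a 503, rather than surfacing whichever error happened last.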

@app.get("/v1/router/stats")
async def get_stats() -> dict:
    """Get router statistics."""
    return metrics.stats()

@app.get("/v1/models")
async def list_models() -> list[dict]:
    """List available models."""
    return [
        {
            "name": m.name,
            "provider": m.provider,
            "cost_per_1k_input": m.cost_per_1k_input,
            "quality_score": m.quality_score
        }
        for m in registry.models.values()
    ]

@app.get("/health")
async def health():
    return {"status": "healthy"}

Conclusion

LLM routing changes how you manage model cost and performance at scale. The key insight is that request complexity varies widely: simple classification tasks don't need GPT-4, while complex reasoning tasks suffer on smaller models. Start with rule-based classification using prompt patterns and length heuristics, which can handle roughly 80% of cases well. For the rest, consider training a small classifier on your actual traffic. Model selection should balance quality, cost, and latency against your specific constraints, because there is no universal "best" model.

Load balancing across multiple endpoints and providers improves reliability. It can also reduce cost by shifting traffic toward whichever provider is cheapest. Circuit breakers and rate limiters are essential in production because they keep one provider's problems from cascading into failures across the whole system. A fallback chain lets a request still succeed when its primary model fails, as long as some model further down the chain is healthy.

Monitor your routing decisions closely. Track which models handle which request types, measure quality degradation when routing to cheaper models, and keep tuning your classification thresholds. A well-tuned router can cut LLM costs by 50-70% while keeping output quality equivalent for most requests. The actual saving depends on how much traffic can go to cheaper models and on the price gap between them.
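The 50-70% figure depends on two numbers: the share of traffic routed to the cheaper model and the price ratio between the models. Under the simplifying assumption that both models see the same token counts per request, the saving is `routed_share * (1 - small_price / large_price)`. The prices below are hypothetical and chosen only to illustrate a 20x gap.

```python
def routing_savings(routed_share: float, small_price: float, large_price: float) -> float:
    """Fraction of spend saved versus sending every request to the large model.

    Simplifying assumption: both models see the same token counts per request,
    so cost scales with price alone.
    """
    if not 0.0 <= routed_share <= 1.0:
        raise ValueError("routed_share must be between 0 and 1")
    blended = routed_share * small_price + (1 - routed_share) * large_price
    return 1 - blended / large_price


# Hypothetical prices per 1M tokens: small model 0.5, large model 10.0
print(round(routing_savings(0.7, 0.5, 10.0), 3))  # 0.665
print(round(routing_savings(0.3, 0.5, 10.0), 3))  # 0.285
```

Even with a 20x price gap, savings in the 50-70% range require routing well over half of all requests to the cheaper model.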


About the Author

I am a Cloud Architect and Developer passionate about solving complex problems with modern technology. My blog explores the intersection of Cloud Architecture, Artificial Intelligence, and Software Engineering. I share tutorials, deep dives, and insights into building scalable, intelligent systems.

Areas of Expertise

Cloud Architecture (Azure, AWS)
Artificial Intelligence & LLMs
DevOps & Kubernetes
Backend Dev (C#, .NET, Python, Node.js)
© 2025 Code, Cloud & Context | Built by Nithin Mohan TK | Powered by Passion