
Structured Output from LLMs: JSON Mode, Function Calling, and Pydantic Patterns

Introduction: Getting reliable, structured data from LLMs is one of the most practical challenges in building AI applications. Whether you’re extracting entities from text, generating API parameters, or building data pipelines, you need JSON that actually parses and validates against your schema. This guide covers the evolution of structured output techniques—from prompt engineering hacks to native JSON mode and the Instructor library—with production patterns for handling the inevitable edge cases where models don’t follow instructions perfectly.

Structured Output: From Prompts to Validated Objects

OpenAI JSON Mode

from openai import OpenAI
import json

client = OpenAI()

# Basic JSON mode
response = client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[
        {
            "role": "system",
            "content": "Extract information and return valid JSON."
        },
        {
            "role": "user",
            "content": """Extract the following from this text:
- Person's name
- Company
- Role
- Contact email

Text: "Hi, I'm Sarah Chen, Senior Engineer at TechCorp. Reach me at sarah@techcorp.com"

Return JSON with keys: name, company, role, email"""
        }
    ],
    response_format={"type": "json_object"}
)

data = json.loads(response.choices[0].message.content)
print(data)
# {"name": "Sarah Chen", "company": "TechCorp", "role": "Senior Engineer", "email": "sarah@techcorp.com"}

# Structured Outputs with JSON Schema (GPT-4o and later)
response = client.chat.completions.create(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "system", "content": "Extract meeting details."},
        {"role": "user", "content": "Schedule a meeting with John tomorrow at 3pm to discuss Q4 planning"}
    ],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "meeting",
            "schema": {
                "type": "object",
                "properties": {
                    "attendees": {"type": "array", "items": {"type": "string"}},
                    "date": {"type": "string", "description": "ISO date format"},
                    "time": {"type": "string", "description": "24-hour format"},
                    "topic": {"type": "string"},
                    "duration_minutes": {"type": "integer"}
                },
                "required": ["attendees", "date", "time", "topic"],
                "additionalProperties": False
            },
            "strict": True
        }
    }
)

meeting = json.loads(response.choices[0].message.content)
print(meeting)

Function Calling for Structured Output

from openai import OpenAI
import json

client = OpenAI()

# Define the schema as a function
tools = [
    {
        "type": "function",
        "function": {
            "name": "extract_product_info",
            "description": "Extract product information from text",
            "parameters": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "description": "Product name"
                    },
                    "price": {
                        "type": "number",
                        "description": "Price in USD"
                    },
                    "currency": {
                        "type": "string",
                        "enum": ["USD", "EUR", "GBP"]
                    },
                    "features": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of product features"
                    },
                    "in_stock": {
                        "type": "boolean"
                    },
                    "category": {
                        "type": "string",
                        "enum": ["electronics", "clothing", "home", "other"]
                    }
                },
                "required": ["name", "price", "category"]
            }
        }
    }
]

response = client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[
        {
            "role": "user",
            "content": """Extract product info from:
"The new iPhone 15 Pro is available for $999. Features include A17 chip, 
titanium design, and USB-C. Currently in stock."""
        }
    ],
    tools=tools,
    tool_choice={"type": "function", "function": {"name": "extract_product_info"}}
)

# Parse the function call arguments
tool_call = response.choices[0].message.tool_calls[0]
product = json.loads(tool_call.function.arguments)
print(product)
# {"name": "iPhone 15 Pro", "price": 999, "currency": "USD", 
#  "features": ["A17 chip", "titanium design", "USB-C"], 
#  "in_stock": true, "category": "electronics"}

Pydantic with Instructor

# pip install instructor

import instructor
from openai import OpenAI
from pydantic import BaseModel, Field, field_validator
from typing import Optional
from enum import Enum

# Patch OpenAI client with Instructor
client = instructor.from_openai(OpenAI())

class Priority(str, Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

class Task(BaseModel):
    """A task extracted from natural language."""
    title: str = Field(description="Brief task title")
    description: Optional[str] = Field(default=None, description="Detailed description")
    assignee: Optional[str] = Field(default=None, description="Person assigned")
    priority: Priority = Field(default=Priority.MEDIUM)
    due_date: Optional[str] = Field(default=None, description="Due date in YYYY-MM-DD format")
    tags: list[str] = Field(default_factory=list)
    
    @field_validator("due_date")
    @classmethod
    def validate_date(cls, v):
        if v is None:
            return v
        from datetime import datetime
        try:
            datetime.strptime(v, "%Y-%m-%d")
            return v
        except ValueError:
            raise ValueError("Date must be in YYYY-MM-DD format")

# Extract with automatic validation and retry
task = client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[
        {
            "role": "user",
            "content": "Create a high priority task for John to review the Q4 budget by next Friday"
        }
    ],
    response_model=Task,
    max_retries=3  # Automatically retry on validation failure
)

print(task.model_dump_json(indent=2))

# Complex nested structures
class Address(BaseModel):
    street: str
    city: str
    state: str
    zip_code: str
    country: str = "USA"

class Person(BaseModel):
    name: str
    email: str
    phone: Optional[str] = None
    address: Optional[Address] = None
    tags: list[str] = Field(default_factory=list)

class Company(BaseModel):
    name: str
    industry: str
    employees: list[Person]
    headquarters: Address
    founded_year: Optional[int] = None

# Extract complex nested data
company = client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[
        {
            "role": "user",
            "content": """Extract company info:
TechStartup Inc, a fintech company founded in 2020, is headquartered at 
123 Innovation Way, San Francisco, CA 94105. Key employees include 
CEO Jane Smith (jane@techstartup.com) and CTO Bob Johnson (bob@techstartup.com)."""
        }
    ],
    response_model=Company
)

print(company.model_dump_json(indent=2))

Handling Extraction Failures

import instructor
from openai import OpenAI
from pydantic import BaseModel, Field, ValidationError
from typing import Optional, Union
import json

client = instructor.from_openai(OpenAI())

class ExtractionResult(BaseModel):
    """Wrapper for extraction with confidence."""
    data: dict
    confidence: float = Field(ge=0, le=1, description="Confidence score 0-1")
    missing_fields: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)

def robust_extract(
    text: str,
    schema: type[BaseModel],
    max_retries: int = 3
) -> Union[BaseModel, dict]:
    """Extract with fallback handling."""
    
    # First attempt: strict extraction
    try:
        result = client.chat.completions.create(
            model="gpt-4-turbo-preview",
            messages=[
                {
                    "role": "system",
                    "content": f"Extract information matching this schema. If information is missing, use null."
                },
                {"role": "user", "content": text}
            ],
            response_model=schema,
            max_retries=max_retries
        )
        return result
    except ValidationError as e:
        print(f"Validation failed: {e}")
    
    # Fallback: lenient JSON extraction
    try:
        raw_client = OpenAI()
        response = raw_client.chat.completions.create(
            model="gpt-4-turbo-preview",
            messages=[
                {
                    "role": "system",
                    "content": "Extract as much information as possible. Return valid JSON."
                },
                {"role": "user", "content": text}
            ],
            response_format={"type": "json_object"}
        )
        
        data = json.loads(response.choices[0].message.content)
        
        # Try to coerce into schema
        try:
            return schema.model_validate(data)
        except ValidationError:
            # Return raw data with warning
            return {
                "raw_data": data,
                "schema_validation_failed": True,
                "original_text": text[:200]
            }
    except Exception as e:
        return {"error": str(e), "original_text": text[:200]}

# Batch extraction with progress
from tqdm import tqdm

def batch_extract(
    texts: list[str],
    schema: type[BaseModel],
    show_progress: bool = True
) -> list[dict]:
    """Extract from multiple texts with error handling."""
    
    results = []
    iterator = tqdm(texts) if show_progress else texts
    
    for text in iterator:
        try:
            result = robust_extract(text, schema)
            if isinstance(result, BaseModel):
                results.append({"success": True, "data": result.model_dump()})
            else:
                results.append({"success": False, "data": result})
        except Exception as e:
            results.append({"success": False, "error": str(e)})
    
    success_rate = sum(1 for r in results if r["success"]) / len(results)
    print(f"Success rate: {success_rate:.1%}")
    
    return results

Streaming Structured Output

import instructor
from openai import OpenAI
from pydantic import BaseModel
from typing import Iterable

client = instructor.from_openai(OpenAI())

class SearchResult(BaseModel):
    title: str
    url: str
    snippet: str
    relevance_score: float

# Stream partial objects as they're generated
def stream_extraction(query: str) -> Iterable[SearchResult]:
    """Stream extracted results one at a time."""
    
    return client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[
            {
                "role": "user",
                "content": f"Generate 5 search results for: {query}"
            }
        ],
        response_model=Iterable[SearchResult],
        stream=True
    )

# Process results as they arrive
for result in stream_extraction("python async programming"):
    print(f"Found: {result.title} (score: {result.relevance_score})")

# Partial streaming for long extractions
from instructor import Partial

class Article(BaseModel):
    title: str
    author: str
    summary: str
    key_points: list[str]
    conclusion: str

# Get partial results during generation
for partial_article in client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[{"role": "user", "content": "Summarize the benefits of microservices"}],
    response_model=Partial[Article],
    stream=True
):
    # Access fields as they become available
    if partial_article.title:
        print(f"Title: {partial_article.title}")
    if partial_article.key_points:
        print(f"Points so far: {len(partial_article.key_points)}")

Claude and Other Models

import instructor
from anthropic import Anthropic
from pydantic import BaseModel

# Instructor works with Claude too
client = instructor.from_anthropic(Anthropic())

class Sentiment(BaseModel):
    text: str
    sentiment: str  # positive, negative, neutral
    confidence: float
    key_phrases: list[str]

result = client.messages.create(
    model="claude-3-5-sonnet-20241022",
    max_tokens=1024,
    messages=[
        {
            "role": "user",
            "content": "Analyze sentiment: 'This product exceeded my expectations!'"
        }
    ],
    response_model=Sentiment
)

print(result)

# For models without native JSON mode, use prompt engineering
def extract_with_prompt(text: str, schema: dict) -> dict:
    """Extract using careful prompting for any model."""
    
    import json
    
    schema_str = json.dumps(schema, indent=2)
    
    prompt = f"""Extract information from the text below and return ONLY valid JSON matching this schema:

Schema:
{schema_str}

Text:
{text}

Important:
- Return ONLY the JSON object, no other text
- Use null for missing values
- Ensure all required fields are present

JSON:"""

    # Works with any model
    response = some_llm_client.generate(prompt)
    
    # Clean and parse
    json_str = response.strip()
    if json_str.startswith("```"):
        json_str = json_str.split("```")[1]
        if json_str.startswith("json"):
            json_str = json_str[4:]
    
    return json.loads(json_str)

Production Patterns

import hashlib
import instructor
from openai import OpenAI
from pydantic import BaseModel, Field
from typing import TypeVar, Generic
from datetime import datetime

T = TypeVar("T", bound=BaseModel)

class ExtractionMetadata(BaseModel):
    """Metadata for tracking extractions."""
    extraction_id: str
    model: str
    timestamp: datetime
    input_hash: str
    tokens_used: int
    latency_ms: float

class ExtractionResponse(BaseModel, Generic[T]):
    """Wrapper with metadata for production use."""
    data: T
    metadata: ExtractionMetadata
    
class ProductionExtractor:
    """Production-ready structured extraction."""
    
    def __init__(self, model: str = "gpt-4-turbo-preview"):
        self.client = instructor.from_openai(OpenAI())
        self.model = model
        self.cache = {}  # Use Redis in production
    
    def extract(
        self,
        text: str,
        schema: type[T],
        use_cache: bool = True
    ) -> ExtractionResponse[T]:
        """Extract with caching and metadata."""
        
        import time
        import uuid
        
        # Check cache
        input_hash = hashlib.md5(text.encode()).hexdigest()
        cache_key = f"{schema.__name__}:{input_hash}"
        
        if use_cache and cache_key in self.cache:
            return self.cache[cache_key]
        
        # Extract
        start = time.time()
        
        result = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": text}],
            response_model=schema,
            max_retries=3
        )
        
        latency = (time.time() - start) * 1000
        
        # Build response
        response = ExtractionResponse(
            data=result,
            metadata=ExtractionMetadata(
                extraction_id=str(uuid.uuid4()),
                model=self.model,
                timestamp=datetime.now(),
                input_hash=input_hash,
                tokens_used=0,  # Get from response in production
                latency_ms=latency
            )
        )
        
        # Cache
        if use_cache:
            self.cache[cache_key] = response
        
        return response

# Usage
extractor = ProductionExtractor()
response = extractor.extract(
    "John Smith, CEO of Acme Corp, john@acme.com",
    Person
)
print(f"Extracted in {response.metadata.latency_ms:.0f}ms")
print(response.data)

Conclusion

Structured output transforms LLMs from text generators into reliable data extraction engines. The combination of JSON mode for guaranteed valid JSON, function calling for schema enforcement, and libraries like Instructor for Pydantic integration gives you multiple tools for different situations. Start with OpenAI’s native structured outputs for the most reliable results, fall back to function calling when you need enum constraints, and use Instructor when you want the full power of Pydantic validation. Always implement retry logic—even the best models occasionally produce invalid output. For production systems, add caching, monitoring, and graceful degradation to handle the edge cases that will inevitably occur at scale.
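
As a minimal illustration of that retry advice, here is a sketch that re-asks the model when its JSON fails to parse. It assumes the OpenAI client shown earlier; the helper name, model choice, and retry count are illustrative, not prescriptive.

import json
from openai import OpenAI

client = OpenAI()

def complete_json_with_retry(prompt: str, max_attempts: int = 3) -> dict:
    """Retry until the model returns parseable JSON (sketch, not production-hardened)."""
    last_error = None
    for attempt in range(max_attempts):
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )
        content = response.choices[0].message.content
        try:
            return json.loads(content)
        except json.JSONDecodeError as error:
            last_error = error
            # Feed the parse error back so the next attempt can correct it
            prompt = f"{prompt}\n\nYour previous output was invalid JSON ({error}). Return only valid JSON."
    raise ValueError(f"No valid JSON after {max_attempts} attempts: {last_error}")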

LLM Inference Optimization: Caching, Batching, and Smart Routing

Introduction: LLM inference can be slow and expensive, especially at scale. Optimizing inference is crucial for production applications where latency and cost directly impact user experience and business viability. This guide covers practical optimization techniques: semantic caching to avoid redundant API calls, request batching for throughput, streaming for perceived latency, model quantization for self-hosted models, and architectural patterns that balance quality with speed. These techniques can reduce costs by 50-90% and cut latency significantly without sacrificing output quality.

Inference Optimization: Caching, Batching, and Quantization

Semantic Caching

from openai import OpenAI
import hashlib
import json
import numpy as np
from typing import Optional
from datetime import datetime, timedelta

client = OpenAI()

class SemanticCache:
    """Cache LLM responses based on semantic similarity."""
    
    def __init__(
        self,
        similarity_threshold: float = 0.95,
        ttl_hours: int = 24
    ):
        self.cache: dict[str, dict] = {}
        self.embeddings: dict[str, list[float]] = {}
        self.similarity_threshold = similarity_threshold
        self.ttl = timedelta(hours=ttl_hours)
    
    def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text."""
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
    
    def _compute_similarity(self, emb1: list[float], emb2: list[float]) -> float:
        """Compute cosine similarity."""
        return np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
    
    def _is_expired(self, entry: dict) -> bool:
        """Check if cache entry is expired."""
        created = datetime.fromisoformat(entry["created_at"])
        return datetime.now() - created > self.ttl
    
    def get(self, prompt: str) -> Optional[str]:
        """Get cached response if similar prompt exists."""
        
        prompt_embedding = self._get_embedding(prompt)
        
        best_match = None
        best_similarity = 0
        
        for key, embedding in self.embeddings.items():
            similarity = self._compute_similarity(prompt_embedding, embedding)
            
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                entry = self.cache.get(key)
                if entry and not self._is_expired(entry):
                    best_match = entry
                    best_similarity = similarity
        
        if best_match:
            return best_match["response"]
        
        return None
    
    def set(self, prompt: str, response: str):
        """Cache a response."""
        
        key = hashlib.md5(prompt.encode()).hexdigest()
        embedding = self._get_embedding(prompt)
        
        self.cache[key] = {
            "prompt": prompt,
            "response": response,
            "created_at": datetime.now().isoformat()
        }
        self.embeddings[key] = embedding
    
    def get_stats(self) -> dict:
        """Get cache statistics."""
        
        valid_entries = sum(
            1 for entry in self.cache.values()
            if not self._is_expired(entry)
        )
        
        return {
            "total_entries": len(self.cache),
            "valid_entries": valid_entries,
            "expired_entries": len(self.cache) - valid_entries
        }

class CachedLLM:
    """LLM client with semantic caching."""
    
    def __init__(self, cache: SemanticCache = None):
        self.cache = cache or SemanticCache()
        self.stats = {"hits": 0, "misses": 0}
    
    def complete(
        self,
        prompt: str,
        model: str = "gpt-4o-mini",
        use_cache: bool = True
    ) -> str:
        """Get completion with caching."""
        
        if use_cache:
            cached = self.cache.get(prompt)
            if cached:
                self.stats["hits"] += 1
                return cached
        
        self.stats["misses"] += 1
        
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}]
        )
        
        result = response.choices[0].message.content
        
        if use_cache:
            self.cache.set(prompt, result)
        
        return result

# Usage
llm = CachedLLM()

# First call - cache miss
response1 = llm.complete("What is machine learning?")

# Similar query - cache hit
response2 = llm.complete("Can you explain machine learning?")

print(f"Cache stats: {llm.stats}")

Request Batching

import asyncio
from dataclasses import dataclass
from typing import Callable
import time

@dataclass
class BatchRequest:
    prompt: str
    future: asyncio.Future
    created_at: float

class BatchProcessor:
    """Batch multiple requests for efficient processing."""
    
    def __init__(
        self,
        max_batch_size: int = 10,
        max_wait_ms: int = 100
    ):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue: list[BatchRequest] = []
        self.lock = asyncio.Lock()
        self.processing = False
    
    async def add_request(self, prompt: str) -> str:
        """Add a request to the batch queue."""
        
        future = asyncio.Future()
        request = BatchRequest(
            prompt=prompt,
            future=future,
            created_at=time.time()
        )
        
        async with self.lock:
            self.queue.append(request)
            
            # Start processing if batch is full
            if len(self.queue) >= self.max_batch_size:
                asyncio.create_task(self._process_batch())
            elif not self.processing:
                # Schedule processing after max_wait
                asyncio.create_task(self._delayed_process())
        
        return await future
    
    async def _delayed_process(self):
        """Process batch after delay."""
        
        await asyncio.sleep(self.max_wait_ms / 1000)
        await self._process_batch()
    
    async def _process_batch(self):
        """Process all queued requests."""
        
        async with self.lock:
            if not self.queue or self.processing:
                return
            
            self.processing = True
            batch = self.queue[:self.max_batch_size]
            self.queue = self.queue[self.max_batch_size:]
        
        try:
            # Process batch (in real implementation, use batch API)
            results = await self._call_llm_batch([r.prompt for r in batch])
            
            for request, result in zip(batch, results):
                request.future.set_result(result)
                
        except Exception as e:
            for request in batch:
                request.future.set_exception(e)
        
        finally:
            self.processing = False
            
            # Process remaining if any
            if self.queue:
                asyncio.create_task(self._process_batch())
    
    async def _call_llm_batch(self, prompts: list[str]) -> list[str]:
        """Call LLM for batch of prompts."""
        
        # Use asyncio.gather for parallel processing
        tasks = [
            self._call_single(prompt)
            for prompt in prompts
        ]
        
        return await asyncio.gather(*tasks)
    
    async def _call_single(self, prompt: str) -> str:
        """Call LLM for a single prompt without blocking the event loop."""
        
        # The OpenAI client call is synchronous, so run it in a worker thread
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}]
        )
        
        return response.choices[0].message.content

# Usage
async def main():
    batcher = BatchProcessor(max_batch_size=5, max_wait_ms=50)
    
    # Submit multiple requests
    tasks = [
        batcher.add_request(f"What is {topic}?")
        for topic in ["Python", "JavaScript", "Rust", "Go", "Java"]
    ]
    
    results = await asyncio.gather(*tasks)
    
    for topic, result in zip(["Python", "JavaScript", "Rust", "Go", "Java"], results):
        print(f"{topic}: {result[:50]}...")

# asyncio.run(main())

Streaming for Perceived Latency

from typing import Generator, AsyncGenerator

def stream_completion(
    prompt: str,
    model: str = "gpt-4o-mini"
) -> Generator[str, None, None]:
    """Stream completion tokens."""
    
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True
    )
    
    for chunk in response:
        if chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

# The async variant needs the async client; the sync client above cannot be awaited
from openai import AsyncOpenAI

async_client = AsyncOpenAI()

async def async_stream_completion(
    prompt: str,
    model: str = "gpt-4o-mini"
) -> AsyncGenerator[str, None]:
    """Async stream completion tokens."""
    
    response = await async_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True
    )
    
    async for chunk in response:
        if chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

class StreamingBuffer:
    """Buffer streaming output for processing."""
    
    def __init__(self):
        self.buffer = ""
        self.complete_sentences: list[str] = []
    
    def add_chunk(self, chunk: str) -> list[str]:
        """Add chunk and return complete sentences."""
        
        self.buffer += chunk
        
        # Check for sentence boundaries
        new_sentences = []
        
        while True:
            # Find sentence end
            for end_char in [". ", "! ", "? ", ".\n", "!\n", "?\n"]:
                idx = self.buffer.find(end_char)
                if idx != -1:
                    sentence = self.buffer[:idx + 1].strip()
                    self.buffer = self.buffer[idx + len(end_char):]
                    new_sentences.append(sentence)
                    break
            else:
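                # for-else: runs only when no sentence terminator was found, so stop scanning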
                break
        
        self.complete_sentences.extend(new_sentences)
        return new_sentences
    
    def flush(self) -> str:
        """Get remaining buffer content."""
        remaining = self.buffer.strip()
        self.buffer = ""
        return remaining

# Usage with FastAPI
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/stream")
async def stream_response(prompt: str):
    """Stream LLM response."""
    
    async def generate():
        async for chunk in async_stream_completion(prompt):
            yield f"data: {chunk}\n\n"
        yield "data: [DONE]\n\n"
    
    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )

Model Selection and Routing

from enum import Enum
from dataclasses import dataclass

class TaskComplexity(str, Enum):
    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"

@dataclass
class ModelConfig:
    name: str
    cost_per_1k_input: float
    cost_per_1k_output: float
    avg_latency_ms: float
    max_tokens: int

class ModelRouter:
    """Route requests to appropriate models based on complexity."""
    
    def __init__(self):
        self.models = {
            TaskComplexity.SIMPLE: ModelConfig(
                name="gpt-4o-mini",
                cost_per_1k_input=0.00015,
                cost_per_1k_output=0.0006,
                avg_latency_ms=500,
                max_tokens=16384
            ),
            TaskComplexity.MODERATE: ModelConfig(
                name="gpt-4o",
                cost_per_1k_input=0.0025,
                cost_per_1k_output=0.01,
                avg_latency_ms=1000,
                max_tokens=128000
            ),
            TaskComplexity.COMPLEX: ModelConfig(
                name="gpt-4o",
                cost_per_1k_input=0.0025,
                cost_per_1k_output=0.01,
                avg_latency_ms=1500,
                max_tokens=128000
            )
        }
    
    def classify_complexity(self, prompt: str) -> TaskComplexity:
        """Classify task complexity."""
        
        # Simple heuristics
        word_count = len(prompt.split())
        
        complex_indicators = [
            "analyze", "compare", "evaluate", "synthesize",
            "explain in detail", "step by step", "comprehensive"
        ]
        
        simple_indicators = [
            "what is", "define", "list", "name",
            "yes or no", "true or false"
        ]
        
        prompt_lower = prompt.lower()
        
        if any(ind in prompt_lower for ind in simple_indicators) and word_count < 50:
            return TaskComplexity.SIMPLE
        
        if any(ind in prompt_lower for ind in complex_indicators) or word_count > 200:
            return TaskComplexity.COMPLEX
        
        return TaskComplexity.MODERATE
    
    def route(self, prompt: str) -> ModelConfig:
        """Route to appropriate model."""
        
        complexity = self.classify_complexity(prompt)
        return self.models[complexity]
    
    def complete(self, prompt: str) -> tuple[str, dict]:
        """Complete with automatic routing."""
        
        config = self.route(prompt)
        
        response = client.chat.completions.create(
            model=config.name,
            messages=[{"role": "user", "content": prompt}]
        )
        
        result = response.choices[0].message.content
        
        # Calculate cost
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        
        cost = (
            (input_tokens / 1000) * config.cost_per_1k_input +
            (output_tokens / 1000) * config.cost_per_1k_output
        )
        
        metadata = {
            "model": config.name,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": cost
        }
        
        return result, metadata

# Usage
router = ModelRouter()

# Simple query -> gpt-4o-mini
result, meta = router.complete("What is Python?")
print(f"Model: {meta['model']}, Cost: ${meta['cost']:.6f}")

# Complex query -> gpt-4o
result, meta = router.complete(
    "Analyze the trade-offs between microservices and monolithic architectures, "
    "considering scalability, maintainability, and operational complexity."
)
print(f"Model: {meta['model']}, Cost: ${meta['cost']:.6f}")

Parallel Processing

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Callable

class ParallelLLM:
    """Process multiple LLM calls in parallel."""
    
    def __init__(self, max_concurrent: int = 10):
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.executor = ThreadPoolExecutor(max_workers=max_concurrent)
    
    async def _call_with_semaphore(
        self,
        prompt: str,
        model: str = "gpt-4o-mini"
    ) -> str:
        """Call LLM with concurrency limit, without blocking the event loop."""
        
        async with self.semaphore:
            # The OpenAI client is synchronous, so hand the call to the thread pool
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                self.executor,
                lambda: client.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": prompt}]
                )
            )
            return response.choices[0].message.content
    
    async def batch_complete(
        self,
        prompts: list[str],
        model: str = "gpt-4o-mini"
    ) -> list[str]:
        """Complete multiple prompts in parallel."""
        
        tasks = [
            self._call_with_semaphore(prompt, model)
            for prompt in prompts
        ]
        
        return await asyncio.gather(*tasks)
    
    async def map_reduce(
        self,
        items: list[str],
        map_prompt_fn: Callable[[str], str],
        reduce_prompt: str,
        model: str = "gpt-4o-mini"
    ) -> str:
        """Map-reduce pattern for processing large datasets."""
        
        # Map phase: process items in parallel
        map_prompts = [map_prompt_fn(item) for item in items]
        map_results = await self.batch_complete(map_prompts, model)
        
        # Reduce phase: combine results
        combined = "\n\n".join([
            f"Item {i+1} result:\n{result}"
            for i, result in enumerate(map_results)
        ])
        
        final_prompt = f"{reduce_prompt}\n\nResults to combine:\n{combined}"
        
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": final_prompt}]
        )
        
        return response.choices[0].message.content

# Usage
async def main():
    parallel = ParallelLLM(max_concurrent=5)
    
    # Batch processing
    prompts = [f"Summarize the concept of {topic}" for topic in [
        "machine learning", "deep learning", "neural networks",
        "natural language processing", "computer vision"
    ]]
    
    results = await parallel.batch_complete(prompts)
    
    # Map-reduce for document analysis
    documents = ["Doc 1 content...", "Doc 2 content...", "Doc 3 content..."]
    
    summary = await parallel.map_reduce(
        items=documents,
        map_prompt_fn=lambda doc: f"Extract key points from: {doc}",
        reduce_prompt="Combine these key points into a unified summary:"
    )
    
    print(summary)

# asyncio.run(main())

Production Optimization Service

from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
from typing import Optional
import time

app = FastAPI()

# Initialize components
cache = SemanticCache(similarity_threshold=0.92)
router = ModelRouter()

class CompletionRequest(BaseModel):
    prompt: str
    model: Optional[str] = None  # Auto-route if not specified
    use_cache: bool = True
    stream: bool = False

class CompletionResponse(BaseModel):
    content: str
    model_used: str
    cached: bool
    latency_ms: float
    tokens: dict
    cost: float

@app.post("/complete", response_model=CompletionResponse)
async def complete(request: CompletionRequest):
    """Optimized completion endpoint."""
    
    start = time.time()
    cached = False
    
    # Check cache first
    if request.use_cache:
        cached_response = cache.get(request.prompt)
        if cached_response:
            return CompletionResponse(
                content=cached_response,
                model_used="cache",
                cached=True,
                latency_ms=(time.time() - start) * 1000,
                tokens={"input": 0, "output": 0},
                cost=0.0
            )
    
    # Route to model
    if request.model:
        model = request.model
        config = router.models.get(TaskComplexity.MODERATE)
    else:
        config = router.route(request.prompt)
        model = config.name
    
    # Call LLM
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": request.prompt}]
    )
    
    content = response.choices[0].message.content
    
    # Cache response
    if request.use_cache:
        cache.set(request.prompt, content)
    
    # Calculate metrics
    input_tokens = response.usage.prompt_tokens
    output_tokens = response.usage.completion_tokens
    cost = (
        (input_tokens / 1000) * config.cost_per_1k_input +
        (output_tokens / 1000) * config.cost_per_1k_output
    )
    
    return CompletionResponse(
        content=content,
        model_used=model,
        cached=False,
        latency_ms=(time.time() - start) * 1000,
        tokens={"input": input_tokens, "output": output_tokens},
        cost=cost
    )

@app.get("/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    return cache.get_stats()

@app.post("/cache/clear")
async def clear_cache():
    """Clear the cache."""
    cache.cache.clear()
    cache.embeddings.clear()
    return {"cleared": True}

Conclusion

LLM inference optimization is essential for production applications. Semantic caching eliminates redundant API calls for similar queries—even 30% cache hit rate significantly reduces costs. Request batching improves throughput for high-volume applications. Streaming reduces perceived latency by showing results as they generate. Smart model routing uses cheaper models for simple tasks and reserves expensive models for complex queries. Parallel processing accelerates batch workloads. Combine these techniques based on your specific requirements: latency-sensitive applications benefit most from caching and streaming, while batch processing benefits from batching and parallelization.
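
To make the cache-hit arithmetic concrete, here is a back-of-the-envelope sketch; the request volume, per-call cost, and embedding overhead below are illustrative assumptions, not measured figures.

# Rough daily cost model: cache hits skip the completion call but still pay for one embedding lookup
requests_per_day = 100_000          # assumed volume
completion_cost = 0.002             # assumed $ per completion call
embedding_cost = 0.00001            # assumed $ per cache-lookup embedding
hit_rate = 0.30

baseline = requests_per_day * completion_cost
with_cache = (
    requests_per_day * embedding_cost                      # every request is embedded for lookup
    + requests_per_day * (1 - hit_rate) * completion_cost  # only misses reach the LLM
)

print(f"Baseline:   ${baseline:,.2f}/day")
print(f"With cache: ${with_cache:,.2f}/day ({1 - with_cache / baseline:.0%} saved)")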

Embedding Models Compared: OpenAI vs Cohere vs Voyage vs Open Source

Introduction: Embedding models convert text into dense vectors that capture semantic meaning. Choosing the right embedding model significantly impacts search quality, retrieval accuracy, and application performance. This guide compares leading embedding models—OpenAI’s text-embedding-3, Cohere’s embed-v3, Voyage AI, and open-source alternatives like BGE and E5. We cover benchmarks, pricing, dimension trade-offs, and practical guidance on selecting the right model for your use case. Whether you’re building semantic search, RAG systems, or recommendation engines, understanding embedding model characteristics is essential.

Embedding Models: From Text to Dense Vectors

OpenAI Embeddings

from openai import OpenAI
import numpy as np
from typing import Union

client = OpenAI()

def get_openai_embedding(
    text: Union[str, list[str]],
    model: str = "text-embedding-3-small",
    dimensions: int = None
) -> Union[list[float], list[list[float]]]:
    """Get embeddings from OpenAI."""
    
    # Handle single string or list
    input_text = [text] if isinstance(text, str) else text
    
    kwargs = {"model": model, "input": input_text}
    if dimensions:
        kwargs["dimensions"] = dimensions
    
    response = client.embeddings.create(**kwargs)
    
    embeddings = [item.embedding for item in response.data]
    
    return embeddings[0] if isinstance(text, str) else embeddings

# OpenAI model comparison
models = {
    "text-embedding-3-small": {
        "dimensions": 1536,
        "max_tokens": 8191,
        "price_per_1m": 0.02
    },
    "text-embedding-3-large": {
        "dimensions": 3072,
        "max_tokens": 8191,
        "price_per_1m": 0.13
    },
    "text-embedding-ada-002": {
        "dimensions": 1536,
        "max_tokens": 8191,
        "price_per_1m": 0.10
    }
}

# Dimension reduction with text-embedding-3
# Smaller dimensions = faster search, less storage
small_embed = get_openai_embedding(
    "What is machine learning?",
    model="text-embedding-3-small",
    dimensions=512  # Reduced from 1536
)

print(f"Reduced dimensions: {len(small_embed)}")

# Batch embedding for efficiency
texts = [
    "Machine learning is a subset of AI",
    "Deep learning uses neural networks",
    "Natural language processing handles text"
]

batch_embeddings = get_openai_embedding(texts, model="text-embedding-3-small")
print(f"Batch size: {len(batch_embeddings)}")

Cohere Embeddings

# pip install cohere

import cohere

co = cohere.Client("your-api-key")

def get_cohere_embedding(
    texts: list[str],
    model: str = "embed-english-v3.0",
    input_type: str = "search_document"
) -> list[list[float]]:
    """Get embeddings from Cohere.
    
    input_type options:
    - search_document: For documents to be searched
    - search_query: For search queries
    - classification: For classification tasks
    - clustering: For clustering tasks
    """
    
    response = co.embed(
        texts=texts,
        model=model,
        input_type=input_type
    )
    
    return response.embeddings

# Cohere model comparison
cohere_models = {
    "embed-english-v3.0": {
        "dimensions": 1024,
        "max_tokens": 512,
        "languages": "English",
        "price_per_1m": 0.10
    },
    "embed-multilingual-v3.0": {
        "dimensions": 1024,
        "max_tokens": 512,
        "languages": "100+",
        "price_per_1m": 0.10
    },
    "embed-english-light-v3.0": {
        "dimensions": 384,
        "max_tokens": 512,
        "languages": "English",
        "price_per_1m": 0.10
    }
}

# Document vs Query embeddings (asymmetric search)
documents = [
    "Python is a programming language",
    "JavaScript runs in browsers"
]

query = "What language is used for web development?"

# Embed documents with document type
doc_embeddings = get_cohere_embedding(
    documents,
    input_type="search_document"
)

# Embed query with query type
query_embedding = get_cohere_embedding(
    [query],
    input_type="search_query"
)[0]

# Calculate similarities
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

for i, doc in enumerate(documents):
    sim = cosine_similarity(query_embedding, doc_embeddings[i])
    print(f"{doc[:40]}: {sim:.3f}")

Voyage AI Embeddings

# pip install voyageai

import voyageai

vo = voyageai.Client(api_key="your-api-key")

def get_voyage_embedding(
    texts: list[str],
    model: str = "voyage-2",
    input_type: str = None
) -> list[list[float]]:
    """Get embeddings from Voyage AI."""
    
    result = vo.embed(
        texts,
        model=model,
        input_type=input_type  # "query" or "document"
    )
    
    return result.embeddings

# Voyage model comparison
voyage_models = {
    "voyage-2": {
        "dimensions": 1024,
        "max_tokens": 4000,
        "specialty": "General purpose",
        "price_per_1m": 0.10
    },
    "voyage-large-2": {
        "dimensions": 1536,
        "max_tokens": 16000,
        "specialty": "Higher quality",
        "price_per_1m": 0.12
    },
    "voyage-code-2": {
        "dimensions": 1536,
        "max_tokens": 16000,
        "specialty": "Code understanding",
        "price_per_1m": 0.12
    },
    "voyage-law-2": {
        "dimensions": 1024,
        "max_tokens": 4000,
        "specialty": "Legal documents",
        "price_per_1m": 0.12
    }
}

# Code embedding example
code_snippets = [
    "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
    "function factorial(n) { return n <= 1 ? 1 : n * factorial(n-1); }",
    "SELECT * FROM users WHERE age > 18"
]

code_embeddings = get_voyage_embedding(
    code_snippets,
    model="voyage-code-2",
    input_type="document"
)

# Search for similar code
query = "recursive function to calculate fibonacci numbers"
query_embed = get_voyage_embedding([query], model="voyage-code-2", input_type="query")[0]

for i, code in enumerate(code_snippets):
    sim = cosine_similarity(query_embed, code_embeddings[i])
    print(f"{code[:50]}: {sim:.3f}")

Open Source: BGE and E5

# pip install sentence-transformers

from sentence_transformers import SentenceTransformer

# BGE (BAAI General Embedding)
bge_model = SentenceTransformer('BAAI/bge-large-en-v1.5')

def get_bge_embedding(texts: list[str], is_query: bool = False) -> list[list[float]]:
    """Get BGE embeddings (free, local)."""
    
    # BGE recommends adding instruction prefix for queries
    if is_query:
        texts = [f"Represent this sentence for searching relevant passages: {t}" for t in texts]
    
    embeddings = bge_model.encode(texts, normalize_embeddings=True)
    return embeddings.tolist()

# E5 (Embeddings from bidirectional Encoder representations)
e5_model = SentenceTransformer('intfloat/e5-large-v2')

def get_e5_embedding(texts: list[str], is_query: bool = False) -> list[list[float]]:
    """Get E5 embeddings (free, local)."""
    
    # E5 requires specific prefixes
    prefix = "query: " if is_query else "passage: "
    texts = [prefix + t for t in texts]
    
    embeddings = e5_model.encode(texts, normalize_embeddings=True)
    return embeddings.tolist()

# Open source model comparison
open_source_models = {
    "bge-large-en-v1.5": {
        "dimensions": 1024,
        "max_tokens": 512,
        "size_mb": 1340,
        "mteb_score": 64.23
    },
    "bge-base-en-v1.5": {
        "dimensions": 768,
        "max_tokens": 512,
        "size_mb": 438,
        "mteb_score": 63.55
    },
    "e5-large-v2": {
        "dimensions": 1024,
        "max_tokens": 512,
        "size_mb": 1340,
        "mteb_score": 62.25
    },
    "all-MiniLM-L6-v2": {
        "dimensions": 384,
        "max_tokens": 256,
        "size_mb": 91,
        "mteb_score": 56.26
    }
}

# Local embedding - no API costs
documents = ["Machine learning basics", "Deep learning fundamentals"]
doc_embeds = get_bge_embedding(documents, is_query=False)

query = "What is ML?"
query_embed = get_bge_embedding([query], is_query=True)[0]

print(f"BGE embedding dimension: {len(doc_embeds[0])}")

Benchmark Comparison

import time
from dataclasses import dataclass

@dataclass
class BenchmarkResult:
    model: str
    avg_latency_ms: float
    throughput_docs_per_sec: float
    dimension: int
    mteb_retrieval_score: float

def benchmark_embedding_model(
    embed_func,
    texts: list[str],
    num_runs: int = 5
) -> dict:
    """Benchmark an embedding model."""
    
    latencies = []
    
    for _ in range(num_runs):
        start = time.time()
        embeddings = embed_func(texts)
        latency = (time.time() - start) * 1000
        latencies.append(latency)
    
    avg_latency = sum(latencies) / len(latencies)
    throughput = len(texts) / (avg_latency / 1000)
    
    return {
        "avg_latency_ms": avg_latency,
        "throughput_docs_per_sec": throughput,
        "dimension": len(embeddings[0])
    }

# MTEB Retrieval Benchmark Scores (as of 2024)
mteb_scores = {
    "text-embedding-3-large": 64.59,
    "text-embedding-3-small": 62.26,
    "voyage-2": 64.83,
    "voyage-large-2": 65.89,
    "embed-english-v3.0": 64.47,
    "bge-large-en-v1.5": 64.23,
    "e5-large-v2": 62.25,
    "all-MiniLM-L6-v2": 56.26
}

# Cost comparison for 1M embeddings
cost_per_million = {
    "text-embedding-3-small": 0.02,
    "text-embedding-3-large": 0.13,
    "text-embedding-ada-002": 0.10,
    "voyage-2": 0.10,
    "voyage-large-2": 0.12,
    "embed-english-v3.0": 0.10,
    "bge-large-en-v1.5": 0.00,  # Free (local)
    "e5-large-v2": 0.00,  # Free (local)
}

# Print comparison table
print("Model Comparison:")
print("-" * 70)
print(f"{'Model':<25} {'MTEB Score':<12} {'Cost/1M':<10} {'Dims':<8}")
print("-" * 70)

for model, score in sorted(mteb_scores.items(), key=lambda x: -x[1]):
    cost = cost_per_million.get(model, "N/A")
    dims = {
        "text-embedding-3-large": 3072,
        "text-embedding-3-small": 1536,
        "voyage-2": 1024,
        "voyage-large-2": 1536,
        "embed-english-v3.0": 1024,
        "bge-large-en-v1.5": 1024,
        "e5-large-v2": 1024,
        "all-MiniLM-L6-v2": 384
    }.get(model, "?")
    
    print(f"{model:<25} {score:<12.2f} ${cost:<9} {dims:<8}")

Unified Embedding Interface

from abc import ABC, abstractmethod
from enum import Enum

class EmbeddingProvider(str, Enum):
    OPENAI = "openai"
    COHERE = "cohere"
    VOYAGE = "voyage"
    BGE = "bge"
    E5 = "e5"

class EmbeddingModel(ABC):
    """Abstract base for embedding models."""
    
    @abstractmethod
    def embed(self, texts: list[str], is_query: bool = False) -> list[list[float]]:
        pass
    
    @property
    @abstractmethod
    def dimension(self) -> int:
        pass

class OpenAIEmbedding(EmbeddingModel):
    def __init__(self, model: str = "text-embedding-3-small", dimensions: int = None):
        self.model = model
        self._dimensions = dimensions or {"text-embedding-3-small": 1536, "text-embedding-3-large": 3072}.get(model, 1536)
    
    def embed(self, texts: list[str], is_query: bool = False) -> list[list[float]]:
        return get_openai_embedding(texts, self.model, self._dimensions)
    
    @property
    def dimension(self) -> int:
        return self._dimensions

class BGEEmbedding(EmbeddingModel):
    def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5"):
        self.model = SentenceTransformer(model_name)
        self._dimension = self.model.get_sentence_embedding_dimension()
    
    def embed(self, texts: list[str], is_query: bool = False) -> list[list[float]]:
        if is_query:
            texts = [f"Represent this sentence for searching relevant passages: {t}" for t in texts]
        return self.model.encode(texts, normalize_embeddings=True).tolist()
    
    @property
    def dimension(self) -> int:
        return self._dimension

class EmbeddingFactory:
    """Factory for creating embedding models."""
    
    @staticmethod
    def create(provider: EmbeddingProvider, **kwargs) -> EmbeddingModel:
        if provider == EmbeddingProvider.OPENAI:
            return OpenAIEmbedding(**kwargs)
        elif provider == EmbeddingProvider.BGE:
            return BGEEmbedding(**kwargs)
        # Add other providers...
        else:
            raise ValueError(f"Unknown provider: {provider}")

# Usage
embedder = EmbeddingFactory.create(EmbeddingProvider.OPENAI, model="text-embedding-3-small")

docs = ["Document 1", "Document 2"]
doc_embeds = embedder.embed(docs, is_query=False)

query_embed = embedder.embed(["Search query"], is_query=True)[0]

print(f"Using {embedder.dimension}-dimensional embeddings")

Choosing the Right Model

The best embedding model depends on your specific requirements. For general-purpose semantic search with good quality and low cost, OpenAI’s text-embedding-3-small offers excellent value at $0.02 per million tokens. If you need the highest retrieval quality and can afford higher costs, Voyage’s voyage-large-2 leads benchmarks. For multilingual applications, Cohere’s embed-multilingual-v3.0 supports 100+ languages. When running locally without API costs is essential, BGE-large-en-v1.5 provides near-commercial quality. For code search specifically, Voyage’s voyage-code-2 is purpose-built for programming languages.
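
As a rough illustration of that decision logic, the sketch below simply restates the guidance above as a helper function; the flags and defaults are assumptions to adapt to your own constraints, not an exhaustive decision tree.

def suggest_embedding_model(
    multilingual: bool = False,
    code_search: bool = False,
    local_only: bool = False,
    max_quality: bool = False
) -> str:
    """Map the guidance above onto a model name (illustrative sketch)."""
    if local_only:
        return "BAAI/bge-large-en-v1.5"      # near-commercial quality, no API cost
    if code_search:
        return "voyage-code-2"               # purpose-built for code
    if multilingual:
        return "embed-multilingual-v3.0"     # Cohere, 100+ languages
    if max_quality:
        return "voyage-large-2"              # leads the retrieval benchmarks cited above
    return "text-embedding-3-small"          # affordable default for general semantic search

print(suggest_embedding_model(code_search=True))  # voyage-code-2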

Conclusion

Embedding model selection significantly impacts your application’s search quality, latency, and costs. Start with OpenAI’s text-embedding-3-small for most use cases—it’s affordable, high-quality, and easy to integrate. Consider Voyage for specialized domains like code or legal documents. Use open-source models like BGE when you need local inference or want to eliminate API costs. Always benchmark on your specific data, as performance varies by domain. The embedding landscape evolves rapidly, so revisit your choice periodically as new models emerge. A well-chosen embedding model is the foundation of effective semantic search and RAG systems.

RAG Optimization: Query Rewriting, Hybrid Search, and Re-ranking

Introduction: Retrieval-Augmented Generation (RAG) grounds LLM responses in factual data, but naive implementations often retrieve irrelevant content or miss important information. Optimizing RAG requires attention to every stage: query understanding, retrieval strategies, re-ranking, and context integration. This guide covers practical optimization techniques: query rewriting and expansion, hybrid search combining dense and sparse retrieval, re-ranking with cross-encoders, chunk optimization, and evaluation frameworks that help you measure and improve retrieval quality systematically.

RAG Pipeline: Query Rewriting, Hybrid Retrieval, Re-ranking

Query Rewriting and Expansion

from dataclasses import dataclass
from typing import Optional

@dataclass
class RewrittenQuery:
    """Result of query rewriting."""
    
    original: str
    rewritten: str
    expansions: list[str]
    hypothetical_answer: Optional[str] = None

class QueryRewriter:
    """Rewrite queries for better retrieval."""
    
    def __init__(self, client):
        self.client = client
    
    def rewrite(self, query: str) -> RewrittenQuery:
        """Rewrite query for better retrieval."""
        
        prompt = f"""Rewrite this search query to be more specific and effective for retrieval.
Also generate 2-3 alternative phrasings.

Original query: {query}

Respond in JSON format:
{{"rewritten": "improved query", "alternatives": ["alt1", "alt2"]}}"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )
        
        import json
        result = json.loads(response.choices[0].message.content)
        
        return RewrittenQuery(
            original=query,
            rewritten=result.get("rewritten", query),
            expansions=result.get("alternatives", [])
        )
    
    def generate_hypothetical_answer(self, query: str) -> str:
        """Generate hypothetical answer for HyDE retrieval."""
        
        prompt = f"""Generate a hypothetical answer to this question.
The answer should be detailed and factual-sounding, even if you're not certain.
This will be used for semantic search.

Question: {query}

Hypothetical answer:"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}]
        )
        
        return response.choices[0].message.content

class MultiQueryRetriever:
    """Generate multiple queries for comprehensive retrieval."""
    
    def __init__(self, client, retriever):
        self.client = client
        self.retriever = retriever
    
    def generate_queries(self, query: str, num_queries: int = 3) -> list[str]:
        """Generate multiple search queries."""
        
        prompt = f"""Generate {num_queries} different search queries that would help answer this question.
Each query should approach the question from a different angle.

Question: {query}

Return as JSON: {{"queries": ["query1", "query2", ...]}}"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )
        
        import json
        result = json.loads(response.choices[0].message.content)
        return result.get("queries", [query])
    
    def retrieve(self, query: str, top_k: int = 5) -> list[dict]:
        """Retrieve using multiple queries."""
        
        queries = self.generate_queries(query)
        queries.append(query)  # Include original
        
        all_results = []
        seen_ids = set()
        
        for q in queries:
            results = self.retriever.search(q, top_k=top_k)
            
            for result in results:
                if result["id"] not in seen_ids:
                    all_results.append(result)
                    seen_ids.add(result["id"])
        
        return all_results[:top_k * 2]

# Step-back prompting for complex queries
class StepBackRetriever:
    """Use step-back prompting for complex queries."""
    
    def __init__(self, client, retriever):
        self.client = client
        self.retriever = retriever
    
    def get_step_back_query(self, query: str) -> str:
        """Generate broader step-back query."""
        
        prompt = f"""Given this specific question, generate a more general "step-back" question
that would help provide background context.

Specific question: {query}

Step-back question (more general):"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}]
        )
        
        return response.choices[0].message.content.strip()
    
    def retrieve(self, query: str, top_k: int = 5) -> dict:
        """Retrieve with step-back context."""
        
        step_back = self.get_step_back_query(query)
        
        # Get background context
        background = self.retriever.search(step_back, top_k=top_k // 2)
        
        # Get specific results
        specific = self.retriever.search(query, top_k=top_k)
        
        return {
            "background": background,
            "specific": specific
        }
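
Both retrievers above need only an OpenAI-style client and a base retriever with a search(query, top_k) method that returns dicts. A usage sketch under those assumptions (client and retriever are placeholders, as elsewhere in this guide):

multi = MultiQueryRetriever(client, retriever)
docs = multi.retrieve("Why did revenue drop in Q3?", top_k=5)

step_back = StepBackRetriever(client, retriever)
context = step_back.retrieve("Why did revenue drop in Q3?", top_k=6)
print(len(context["background"]), len(context["specific"]))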

Hybrid Search

from dataclasses import dataclass
from typing import Callable
import numpy as np

@dataclass
class SearchResult:
    """A search result with scores."""
    
    id: str
    content: str
    dense_score: float = 0.0
    sparse_score: float = 0.0
    combined_score: float = 0.0
    metadata: dict = None

class HybridRetriever:
    """Combine dense and sparse retrieval."""
    
    def __init__(
        self,
        embedding_client,
        embedding_model: str = "text-embedding-3-small",
        alpha: float = 0.5  # Weight for dense vs sparse
    ):
        self.embedding_client = embedding_client
        self.embedding_model = embedding_model
        self.alpha = alpha
        
        # Document storage
        self.documents: list[dict] = []
        self.embeddings: list[list[float]] = []
        self.bm25_index = None
    
    def _embed(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings for texts."""
        
        response = self.embedding_client.embeddings.create(
            model=self.embedding_model,
            input=texts
        )
        
        return [d.embedding for d in response.data]
    
    def add_documents(self, documents: list[dict]):
        """Add documents to index."""
        
        # Store documents
        self.documents.extend(documents)
        
        # Create embeddings
        texts = [d["content"] for d in documents]
        embeddings = self._embed(texts)
        self.embeddings.extend(embeddings)
        
        # Build BM25 index
        self._build_bm25_index()
    
    def _build_bm25_index(self):
        """Build BM25 index for sparse retrieval."""
        
        from rank_bm25 import BM25Okapi
        
        # Tokenize documents
        tokenized = [
            doc["content"].lower().split()
            for doc in self.documents
        ]
        
        self.bm25_index = BM25Okapi(tokenized)
    
    def _dense_search(self, query: str, top_k: int) -> list[tuple[int, float]]:
        """Dense vector search."""
        
        query_embedding = self._embed([query])[0]
        
        # Calculate similarities
        similarities = []
        for i, doc_embedding in enumerate(self.embeddings):
            sim = np.dot(query_embedding, doc_embedding) / (
                np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding)
            )
            similarities.append((i, sim))
        
        # Sort by similarity
        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_k]
    
    def _sparse_search(self, query: str, top_k: int) -> list[tuple[int, float]]:
        """Sparse BM25 search."""
        
        if not self.bm25_index:
            return []
        
        tokenized_query = query.lower().split()
        scores = self.bm25_index.get_scores(tokenized_query)
        
        # Get top results
        indexed_scores = [(i, s) for i, s in enumerate(scores)]
        indexed_scores.sort(key=lambda x: x[1], reverse=True)
        
        return indexed_scores[:top_k]
    
    def search(self, query: str, top_k: int = 10) -> list[SearchResult]:
        """Hybrid search combining dense and sparse."""
        
        # Get results from both methods
        dense_results = self._dense_search(query, top_k * 2)
        sparse_results = self._sparse_search(query, top_k * 2)
        
        # Normalize scores
        dense_max = max(s for _, s in dense_results) if dense_results else 1
        sparse_max = max(s for _, s in sparse_results) if sparse_results else 1
        
        # Combine scores
        score_map = {}
        
        for idx, score in dense_results:
            normalized = score / dense_max if dense_max > 0 else 0
            score_map[idx] = {"dense": normalized, "sparse": 0}
        
        for idx, score in sparse_results:
            normalized = score / sparse_max if sparse_max > 0 else 0
            if idx in score_map:
                score_map[idx]["sparse"] = normalized
            else:
                score_map[idx] = {"dense": 0, "sparse": normalized}
        
        # Calculate combined scores
        results = []
        for idx, scores in score_map.items():
            combined = (
                self.alpha * scores["dense"] +
                (1 - self.alpha) * scores["sparse"]
            )
            
            results.append(SearchResult(
                id=str(idx),
                content=self.documents[idx]["content"],
                dense_score=scores["dense"],
                sparse_score=scores["sparse"],
                combined_score=combined,
                metadata=self.documents[idx].get("metadata")
            ))
        
        # Sort by combined score
        results.sort(key=lambda x: x.combined_score, reverse=True)
        return results[:top_k]
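
A quick usage sketch for the hybrid retriever; the documents are made up for illustration, the rank_bm25 package must be installed for the sparse side, and a valid OpenAI API key is assumed:

from openai import OpenAI

retriever = HybridRetriever(OpenAI(), alpha=0.6)  # lean slightly toward dense scores
retriever.add_documents([
    {"content": "BM25 is a sparse lexical ranking function.", "metadata": {"id": "a"}},
    {"content": "Dense retrieval compares embedding vectors.", "metadata": {"id": "b"}},
])

for r in retriever.search("lexical keyword matching", top_k=2):
    print(round(r.combined_score, 3), r.content)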

# Reciprocal Rank Fusion
class RRFRetriever:
    """Combine multiple retrievers using RRF."""
    
    def __init__(self, retrievers: list, k: int = 60):
        self.retrievers = retrievers
        self.k = k  # RRF constant
    
    def search(self, query: str, top_k: int = 10) -> list[dict]:
        """Search using RRF fusion."""
        
        # Get results from all retrievers
        all_rankings = []
        
        for retriever in self.retrievers:
            results = retriever.search(query, top_k=top_k * 2)
            all_rankings.append(results)
        
        # Calculate RRF scores
        rrf_scores = {}
        
        for ranking in all_rankings:
            for rank, result in enumerate(ranking):
                doc_id = result.get("id") or result.get("content")[:50]
                
                if doc_id not in rrf_scores:
                    rrf_scores[doc_id] = {
                        "score": 0,
                        "content": result.get("content"),
                        "metadata": result.get("metadata")
                    }
                
                # RRF formula: 1 / (k + rank), using a 1-based rank
                rrf_scores[doc_id]["score"] += 1 / (self.k + rank + 1)
        
        # Sort by RRF score
        sorted_results = sorted(
            rrf_scores.items(),
            key=lambda x: x[1]["score"],
            reverse=True
        )
        
        return [
            {
                "id": doc_id,
                "content": data["content"],
                "score": data["score"],
                "metadata": data["metadata"]
            }
            for doc_id, data in sorted_results[:top_k]
        ]
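
Note that RRFRetriever expects each underlying retriever to return dicts with id, content, and metadata keys, while HybridRetriever.search returns SearchResult dataclasses. A thin adapter bridges the two; the adapter and the second keyword_retriever below are illustrative placeholders, not part of the classes above:

class HybridDictAdapter:
    """Wrap HybridRetriever so its results match the dict format RRF expects."""

    def __init__(self, hybrid: HybridRetriever):
        self.hybrid = hybrid

    def search(self, query: str, top_k: int = 10) -> list[dict]:
        return [
            {"id": r.id, "content": r.content, "metadata": r.metadata}
            for r in self.hybrid.search(query, top_k=top_k)
        ]

rrf = RRFRetriever([HybridDictAdapter(retriever), keyword_retriever], k=60)
fused = rrf.search("how does reciprocal rank fusion work", top_k=5)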

Re-ranking

from dataclasses import dataclass

@dataclass
class RankedResult:
    """A re-ranked result."""
    
    content: str
    original_rank: int
    new_rank: int
    relevance_score: float
    metadata: dict = None

class CrossEncoderReranker:
    """Re-rank using cross-encoder model."""
    
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        from sentence_transformers import CrossEncoder
        self.model = CrossEncoder(model_name)
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[RankedResult]:
        """Re-rank documents using cross-encoder."""
        
        # Create query-document pairs
        pairs = [
            [query, doc["content"]]
            for doc in documents
        ]
        
        # Get relevance scores
        scores = self.model.predict(pairs)
        
        # Create ranked results
        results = []
        for i, (doc, score) in enumerate(zip(documents, scores)):
            results.append(RankedResult(
                content=doc["content"],
                original_rank=i,
                new_rank=0,  # Will be set after sorting
                relevance_score=float(score),
                metadata=doc.get("metadata")
            ))
        
        # Sort by relevance score
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        
        # Update ranks
        for i, result in enumerate(results):
            result.new_rank = i
        
        return results[:top_k]

class LLMReranker:
    """Re-rank using LLM."""
    
    def __init__(self, client):
        self.client = client
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[RankedResult]:
        """Re-rank using LLM scoring."""
        
        results = []
        
        for i, doc in enumerate(documents):
            score = self._score_relevance(query, doc["content"])
            
            results.append(RankedResult(
                content=doc["content"],
                original_rank=i,
                new_rank=0,
                relevance_score=score,
                metadata=doc.get("metadata")
            ))
        
        # Sort and update ranks
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        
        for i, result in enumerate(results):
            result.new_rank = i
        
        return results[:top_k]
    
    def _score_relevance(self, query: str, document: str) -> float:
        """Score document relevance to query."""
        
        prompt = f"""Rate how relevant this document is to the query on a scale of 0-10.

Query: {query}

Document: {document[:500]}

Respond with just a number (0-10):"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=5
        )
        
        try:
            return float(response.choices[0].message.content.strip())
        except ValueError:
            return 5.0  # Neutral midpoint score if the model returns non-numeric output

# Cohere reranker
class CohereReranker:
    """Re-rank using Cohere rerank API."""
    
    def __init__(self, api_key: str):
        import cohere
        self.client = cohere.Client(api_key)
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[RankedResult]:
        """Re-rank using Cohere."""
        
        docs = [d["content"] for d in documents]
        
        response = self.client.rerank(
            query=query,
            documents=docs,
            top_n=top_k,
            model="rerank-english-v3.0"
        )
        
        results = []
        for i, result in enumerate(response.results):
            results.append(RankedResult(
                content=documents[result.index]["content"],
                original_rank=result.index,
                new_rank=i,
                relevance_score=result.relevance_score,
                metadata=documents[result.index].get("metadata")
            ))
        
        return results
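
All three rerankers expose the same rerank(query, documents, top_k) interface, so they can be swapped freely behind a first-stage retriever. A usage sketch (the candidate documents come from any retriever in this guide; the Cohere API key is a placeholder):

query = "What is reciprocal rank fusion?"
candidates = [
    {"content": r.content, "metadata": r.metadata}
    for r in retriever.search(query, top_k=20)
]

reranker = CrossEncoderReranker()           # local model, no API cost
# reranker = LLMReranker(client)            # LLM-scored: slower and pricier
# reranker = CohereReranker(api_key="...")  # managed rerank API

for r in reranker.rerank(query, candidates, top_k=5):
    print(r.new_rank, round(r.relevance_score, 3), r.content[:60])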

Chunk Optimization

from dataclasses import dataclass
from typing import Optional

@dataclass
class Chunk:
    """A document chunk."""
    
    content: str
    metadata: dict
    parent_id: Optional[str] = None
    chunk_index: int = 0

class SemanticChunker:
    """Chunk documents based on semantic boundaries."""
    
    def __init__(
        self,
        embedding_client,
        similarity_threshold: float = 0.8
    ):
        self.embedding_client = embedding_client
        self.threshold = similarity_threshold
    
    def chunk(self, text: str, metadata: dict = None) -> list[Chunk]:
        """Chunk text at semantic boundaries."""
        
        # Split into sentences
        import re
        sentences = re.split(r'(?<=[.!?])\s+', text)
        
        if len(sentences) <= 1:
            return [Chunk(content=text, metadata=metadata or {})]
        
        # Get embeddings for sentences
        embeddings = self._embed(sentences)
        
        # Find semantic boundaries
        chunks = []
        current_chunk = [sentences[0]]
        
        for i in range(1, len(sentences)):
            similarity = self._cosine_similarity(
                embeddings[i-1],
                embeddings[i]
            )
            
            if similarity < self.threshold:
                # Semantic boundary - start new chunk
                chunks.append(Chunk(
                    content=" ".join(current_chunk),
                    metadata=metadata or {},
                    chunk_index=len(chunks)
                ))
                current_chunk = [sentences[i]]
            else:
                current_chunk.append(sentences[i])
        
        # Add final chunk
        if current_chunk:
            chunks.append(Chunk(
                content=" ".join(current_chunk),
                metadata=metadata or {},
                chunk_index=len(chunks)
            ))
        
        return chunks
    
    def _embed(self, texts: list[str]) -> list[list[float]]:
        response = self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=texts
        )
        return [d.embedding for d in response.data]
    
    def _cosine_similarity(self, a, b) -> float:
        import numpy as np
        a, b = np.array(a), np.array(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
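
A usage sketch for the semantic chunker; long_text is a placeholder for any document, and a lower threshold produces fewer, larger chunks:

from openai import OpenAI

chunker = SemanticChunker(OpenAI(), similarity_threshold=0.75)
chunks = chunker.chunk(long_text, metadata={"source": "handbook.md"})

for c in chunks:
    print(c.chunk_index, len(c.content), c.content[:60])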

class HierarchicalChunker:
    """Create hierarchical chunks with parent-child relationships."""
    
    def __init__(
        self,
        large_chunk_size: int = 2000,
        small_chunk_size: int = 400,
        overlap: int = 50
    ):
        self.large_size = large_chunk_size
        self.small_size = small_chunk_size
        self.overlap = overlap
    
    def chunk(self, text: str, doc_id: str) -> dict:
        """Create hierarchical chunks."""
        
        # Create large chunks (parents)
        large_chunks = self._split_text(text, self.large_size, self.overlap)
        
        result = {
            "parents": [],
            "children": []
        }
        
        for i, large_chunk in enumerate(large_chunks):
            parent_id = f"{doc_id}_parent_{i}"
            
            result["parents"].append(Chunk(
                content=large_chunk,
                metadata={"doc_id": doc_id, "type": "parent"},
                chunk_index=i
            ))
            
            # Create small chunks (children)
            small_chunks = self._split_text(large_chunk, self.small_size, self.overlap)
            
            for j, small_chunk in enumerate(small_chunks):
                result["children"].append(Chunk(
                    content=small_chunk,
                    metadata={"doc_id": doc_id, "type": "child"},
                    parent_id=parent_id,
                    chunk_index=j
                ))
        
        return result
    
    def _split_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
        """Split text into overlapping chunks."""
        
        chunks = []
        start = 0
        
        while start < len(text):
            end = start + chunk_size
            chunk = text[start:end]
            
            # Try to break at sentence boundary
            if end < len(text):
                last_period = chunk.rfind('.')
                if last_period > chunk_size // 2:
                    chunk = chunk[:last_period + 1]
                    end = start + last_period + 1
            
            chunks.append(chunk.strip())
            start = end - overlap
        
        return chunks

# Parent document retriever
class ParentDocumentRetriever:
    """Retrieve child chunks but return parent context."""
    
    def __init__(self, child_retriever, parent_store: dict):
        self.child_retriever = child_retriever
        self.parent_store = parent_store
    
    def search(self, query: str, top_k: int = 5) -> list[dict]:
        """Search children, return parents."""
        
        # Search child chunks
        child_results = self.child_retriever.search(query, top_k=top_k * 2)
        
        # Get unique parents
        seen_parents = set()
        results = []
        
        for child in child_results:
            parent_id = child.get("parent_id")
            
            if parent_id and parent_id not in seen_parents:
                seen_parents.add(parent_id)
                
                parent = self.parent_store.get(parent_id)
                if parent:
                    results.append({
                        "content": parent["content"],
                        "metadata": parent.get("metadata"),
                        "matched_child": child["content"]
                    })
        
        return results[:top_k]
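
Wiring the hierarchical chunker to the parent document retriever means building parent_store keyed by the same parent IDs the chunker generates, and indexing the child chunks so each search result carries its parent_id. A sketch under those assumptions (the child_retriever and its indexing step depend on your vector store and are elided):

chunker = HierarchicalChunker()
hierarchy = chunker.chunk(document_text, doc_id="doc-42")

# Parent store: parent_id -> full parent chunk
parent_store = {
    f"doc-42_parent_{p.chunk_index}": {"content": p.content, "metadata": p.metadata}
    for p in hierarchy["parents"]
}

# Index hierarchy["children"] in your child retriever, keeping parent_id on each
# record, then retrieve on child matches but return their parents for fuller context.
retriever = ParentDocumentRetriever(child_retriever, parent_store)
results = retriever.search("what does the refund policy say?", top_k=3)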

Production RAG Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

# Initialize components
from openai import OpenAI
client = OpenAI()

hybrid_retriever = HybridRetriever(client)
query_rewriter = QueryRewriter(client)
reranker = LLMReranker(client)

class RAGRequest(BaseModel):
    query: str
    top_k: int = 5
    use_rewriting: bool = True
    use_reranking: bool = True

class IndexRequest(BaseModel):
    documents: list[dict]

@app.post("/v1/index")
async def index_documents(request: IndexRequest):
    """Index documents for retrieval."""
    
    hybrid_retriever.add_documents(request.documents)
    
    return {
        "indexed": len(request.documents),
        "total": len(hybrid_retriever.documents)
    }

@app.post("/v1/retrieve")
async def retrieve(request: RAGRequest):
    """Retrieve relevant documents."""
    
    query = request.query
    
    # Query rewriting
    if request.use_rewriting:
        rewritten = query_rewriter.rewrite(query)
        query = rewritten.rewritten
    
    # Hybrid search
    results = hybrid_retriever.search(query, top_k=request.top_k * 2)
    
    # Re-ranking
    if request.use_reranking and results:
        documents = [{"content": r.content, "metadata": r.metadata} for r in results]
        ranked = reranker.rerank(request.query, documents, top_k=request.top_k)
        
        return {
            "results": [
                {
                    "content": r.content,
                    "score": r.relevance_score,
                    "original_rank": r.original_rank,
                    "new_rank": r.new_rank
                }
                for r in ranked
            ]
        }
    
    return {
        "results": [
            {
                "content": r.content,
                "score": r.combined_score,
                "dense_score": r.dense_score,
                "sparse_score": r.sparse_score
            }
            for r in results[:request.top_k]
        ]
    }

@app.post("/v1/rag")
async def rag_query(request: RAGRequest):
    """Full RAG pipeline with generation."""
    
    # Retrieve
    retrieval_response = await retrieve(request)
    results = retrieval_response["results"]
    
    if not results:
        return {"answer": "No relevant information found.", "sources": []}
    
    # Build context
    context = "\n\n".join([
        f"Source {i+1}: {r['content']}"
        for i, r in enumerate(results)
    ])
    
    # Generate answer
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": "Answer based on the provided context. Cite sources."
            },
            {
                "role": "user",
                "content": f"Context:\n{context}\n\nQuestion: {request.query}"
            }
        ]
    )
    
    return {
        "answer": response.choices[0].message.content,
        "sources": results
    }

@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "documents_indexed": len(hybrid_retriever.documents)
    }

Conclusion

RAG optimization is iterative—measure retrieval quality, identify failure modes, and apply targeted improvements. Start with query rewriting to handle ambiguous or poorly formed queries. Implement hybrid search that combines dense embeddings with sparse BM25 to capture both semantic similarity and keyword matches. Add re-ranking with cross-encoders or LLMs to improve precision on the final results. Optimize your chunking strategy based on your content—semantic chunking for varied documents, hierarchical chunking for long documents where surrounding context matters. Use evaluation frameworks to measure recall, precision, and answer quality systematically. The goal is to retrieve the most relevant context while staying within token limits—every improvement in retrieval quality directly improves generation quality.

LLM Routing and Model Selection: Optimizing Cost and Quality in Production

Introduction: Not every query needs GPT-4. Routing simple questions to cheaper, faster models while reserving expensive models for complex tasks can cut costs by 70% or more without sacrificing quality. Smart LLM routing is the difference between a $10,000/month AI bill and a $3,000 one. This guide covers implementing intelligent model selection: classifying query complexity, building routing logic, handling fallbacks, and optimizing for cost-quality tradeoffs. Whether you’re using a single provider with multiple tiers or orchestrating across OpenAI, Anthropic, and local models, these patterns will help you build efficient, cost-effective LLM applications.

LLM Routing
Intelligent LLM Routing: Task Classification to Model Selection

Basic Model Router

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Callable
from openai import OpenAI

client = OpenAI()

class TaskComplexity(Enum):
    SIMPLE = "simple"      # Factual Q&A, simple formatting
    MEDIUM = "medium"      # Summarization, basic analysis
    COMPLEX = "complex"    # Reasoning, coding, creative writing
    EXPERT = "expert"      # Multi-step reasoning, specialized knowledge

@dataclass
class ModelConfig:
    name: str
    cost_per_1k_input: float
    cost_per_1k_output: float
    max_tokens: int
    strengths: list[str]

class ModelRouter:
    """Route requests to appropriate models based on complexity."""
    
    MODELS = {
        TaskComplexity.SIMPLE: ModelConfig(
            name="gpt-4o-mini",
            cost_per_1k_input=0.00015,
            cost_per_1k_output=0.0006,
            max_tokens=128000,
            strengths=["fast", "cheap", "simple_qa"]
        ),
        TaskComplexity.MEDIUM: ModelConfig(
            name="gpt-4o-mini",
            cost_per_1k_input=0.00015,
            cost_per_1k_output=0.0006,
            max_tokens=128000,
            strengths=["summarization", "formatting"]
        ),
        TaskComplexity.COMPLEX: ModelConfig(
            name="gpt-4o",
            cost_per_1k_input=0.005,
            cost_per_1k_output=0.015,
            max_tokens=128000,
            strengths=["reasoning", "coding", "analysis"]
        ),
        TaskComplexity.EXPERT: ModelConfig(
            name="gpt-4o",
            cost_per_1k_input=0.005,
            cost_per_1k_output=0.015,
            max_tokens=128000,
            strengths=["complex_reasoning", "specialized"]
        ),
    }
    
    def __init__(self):
        self.classifier_model = "gpt-4o-mini"
    
    def classify_complexity(self, query: str) -> TaskComplexity:
        """Classify query complexity using a small model."""
        
        response = client.chat.completions.create(
            model=self.classifier_model,
            messages=[
                {
                    "role": "system",
                    "content": """Classify the complexity of this query:
- SIMPLE: Factual questions, definitions, simple formatting
- MEDIUM: Summarization, basic analysis, straightforward tasks
- COMPLEX: Multi-step reasoning, coding, creative writing, detailed analysis
- EXPERT: Specialized knowledge, complex problem-solving, research-level questions

Return only: SIMPLE, MEDIUM, COMPLEX, or EXPERT"""
                },
                {"role": "user", "content": query}
            ],
            max_tokens=10,
            temperature=0
        )
        
        result = response.choices[0].message.content.strip().upper()
        
        try:
            return TaskComplexity(result.lower())
        except ValueError:
            return TaskComplexity.MEDIUM  # Default to medium
    
    def route(self, query: str, force_model: Optional[str] = None) -> tuple[str, ModelConfig]:
        """Route query to appropriate model."""
        
        if force_model:
            # Find config for forced model
            for config in self.MODELS.values():
                if config.name == force_model:
                    return force_model, config
        
        complexity = self.classify_complexity(query)
        config = self.MODELS[complexity]
        
        return config.name, config
    
    def complete(self, query: str, system_prompt: str = "", **kwargs) -> str:
        """Complete query with automatic routing."""
        
        model, config = self.route(query)
        
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": query})
        
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs
        )
        
        return response.choices[0].message.content

# Usage
router = ModelRouter()

# Simple query -> routes to gpt-4o-mini
result = router.complete("What is the capital of France?")

# Complex query -> routes to gpt-4o
result = router.complete("Write a Python function to implement a red-black tree with all standard operations.")

Multi-Provider Router

from abc import ABC, abstractmethod
from openai import OpenAI
import anthropic

class LLMProvider(ABC):
    """Abstract base for LLM providers."""
    
    @abstractmethod
    def complete(self, messages: list[dict], **kwargs) -> str:
        pass
    
    @abstractmethod
    def get_cost(self, input_tokens: int, output_tokens: int) -> float:
        pass

class OpenAIProvider(LLMProvider):
    def __init__(self, model: str = "gpt-4o"):
        self.client = OpenAI()
        self.model = model
        self.costs = {
            "gpt-4o": (0.005, 0.015),
            "gpt-4o-mini": (0.00015, 0.0006),
            "gpt-4-turbo-preview": (0.01, 0.03),
        }
    
    def complete(self, messages: list[dict], **kwargs) -> str:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **kwargs
        )
        return response.choices[0].message.content
    
    def get_cost(self, input_tokens: int, output_tokens: int) -> float:
        input_cost, output_cost = self.costs.get(self.model, (0.01, 0.03))
        return (input_tokens / 1000) * input_cost + (output_tokens / 1000) * output_cost

class AnthropicProvider(LLMProvider):
    def __init__(self, model: str = "claude-3-5-sonnet-20241022"):
        self.client = anthropic.Anthropic()
        self.model = model
        self.costs = {
            "claude-3-5-sonnet-20241022": (0.003, 0.015),
            "claude-3-haiku-20240307": (0.00025, 0.00125),
            "claude-3-opus-20240229": (0.015, 0.075),
        }
    
    def complete(self, messages: list[dict], **kwargs) -> str:
        # Convert OpenAI format to Anthropic
        system = ""
        anthropic_messages = []
        
        for msg in messages:
            if msg["role"] == "system":
                system = msg["content"]
            else:
                anthropic_messages.append(msg)
        
        response = self.client.messages.create(
            model=self.model,
            system=system,
            messages=anthropic_messages,
            max_tokens=kwargs.get("max_tokens", 1024)
        )
        return response.content[0].text
    
    def get_cost(self, input_tokens: int, output_tokens: int) -> float:
        input_cost, output_cost = self.costs.get(self.model, (0.003, 0.015))
        return (input_tokens / 1000) * input_cost + (output_tokens / 1000) * output_cost

class MultiProviderRouter:
    """Route across multiple LLM providers."""
    
    def __init__(self):
        self.providers = {
            "openai_fast": OpenAIProvider("gpt-4o-mini"),
            "openai_smart": OpenAIProvider("gpt-4o"),
            "claude_balanced": AnthropicProvider("claude-3-5-sonnet-20241022"),
            "claude_fast": AnthropicProvider("claude-3-haiku-20240307"),
        }
        
        self.task_routing = {
            "coding": "openai_smart",
            "reasoning": "claude_balanced",
            "simple_qa": "openai_fast",
            "creative": "claude_balanced",
            "analysis": "openai_smart",
            "summarization": "claude_fast",
        }
    
    def classify_task(self, query: str) -> str:
        """Classify task type."""
        # Use fast model for classification
        provider = self.providers["openai_fast"]
        
        result = provider.complete([
            {
                "role": "system",
                "content": "Classify this task. Return only: coding, reasoning, simple_qa, creative, analysis, or summarization"
            },
            {"role": "user", "content": query}
        ], max_tokens=20)
        
        task_type = result.strip().lower()
        return task_type if task_type in self.task_routing else "simple_qa"
    
    def route(self, query: str) -> tuple[str, LLMProvider]:
        """Route to best provider for task."""
        task_type = self.classify_task(query)
        provider_key = self.task_routing[task_type]
        return provider_key, self.providers[provider_key]
    
    def complete(self, query: str, system_prompt: str = "") -> str:
        """Complete with automatic routing."""
        provider_key, provider = self.route(query)
        
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": query})
        
        return provider.complete(messages)
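
A usage sketch; which provider actually handles each request depends on the classifier's output, so the routes shown in the comments are the expected ones rather than guarantees:

router = MultiProviderRouter()

# Expected: summarization -> claude_fast (Haiku)
provider_key, _ = router.route("Summarize this meeting transcript in five bullets")
print(provider_key)

# Expected: coding -> openai_smart (gpt-4o)
answer = router.complete("Refactor this function to run in O(n log n): ...")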

Fallback and Retry Logic

import time
from typing import Optional

class ResilientRouter:
    """Router with fallback and retry logic."""
    
    def __init__(self):
        self.providers = {
            "primary": OpenAIProvider("gpt-4o"),
            "secondary": AnthropicProvider("claude-3-5-sonnet-20241022"),
            "fallback": OpenAIProvider("gpt-4o-mini"),
        }
        
        self.provider_order = ["primary", "secondary", "fallback"]
        self.max_retries = 3
        self.retry_delay = 1.0
    
    def complete_with_fallback(
        self,
        messages: list[dict],
        timeout: float = 30.0,
        **kwargs
    ) -> tuple[str, str]:  # Returns (response, provider_used)
        """Complete with automatic fallback on failure."""
        
        last_error = None
        
        for provider_key in self.provider_order:
            provider = self.providers[provider_key]
            
            for attempt in range(self.max_retries):
                try:
                    response = provider.complete(messages, **kwargs)
                    return response, provider_key
                    
                except Exception as e:
                    last_error = e
                    
                    # Check if retryable
                    if self._is_retryable(e):
                        time.sleep(self.retry_delay * (attempt + 1))
                        continue
                    else:
                        break  # Move to next provider
            
            # Log provider failure
            print(f"Provider {provider_key} failed: {last_error}")
        
        raise RuntimeError(f"All providers failed. Last error: {last_error}")
    
    def _is_retryable(self, error: Exception) -> bool:
        """Check if error is retryable."""
        retryable_messages = [
            "rate limit",
            "timeout",
            "overloaded",
            "503",
            "529",
        ]
        
        error_str = str(error).lower()
        return any(msg in error_str for msg in retryable_messages)
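
A usage sketch for the resilient router; the returned provider key shows which backend ultimately answered after any retries or fallbacks:

router = ResilientRouter()

answer, provider_used = router.complete_with_fallback(
    [{"role": "user", "content": "Explain vector quantization in two sentences."}],
    max_tokens=200,
)
print(f"Answered by {provider_used}: {answer[:80]}")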

# Load balancing router
class LoadBalancingRouter:
    """Route with load balancing across providers."""
    
    def __init__(self):
        self.providers = [
            {"provider": OpenAIProvider("gpt-4o"), "weight": 0.5, "healthy": True},
            {"provider": AnthropicProvider("claude-3-5-sonnet-20241022"), "weight": 0.5, "healthy": True},
        ]
        
        self.health_check_interval = 60
        self.last_health_check = 0
    
    def select_provider(self) -> LLMProvider:
        """Select provider based on weights and health."""
        import random
        
        healthy = [p for p in self.providers if p["healthy"]]
        
        if not healthy:
            # All unhealthy, try first one anyway
            return self.providers[0]["provider"]
        
        # Weighted random selection
        total_weight = sum(p["weight"] for p in healthy)
        r = random.uniform(0, total_weight)
        
        cumulative = 0
        for p in healthy:
            cumulative += p["weight"]
            if r <= cumulative:
                return p["provider"]
        
        return healthy[-1]["provider"]
    
    def complete(self, messages: list[dict], **kwargs) -> str:
        """Complete with load balancing."""
        provider = self.select_provider()
        return provider.complete(messages, **kwargs)
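
The load balancer tracks a healthy flag per provider but leaves the health-check wiring to you. A minimal sketch of flipping that flag by hand after a failure so traffic shifts to the remaining provider:

lb = LoadBalancingRouter()

# Mark the OpenAI entry unhealthy (e.g. after repeated errors); all traffic
# now goes to the Anthropic entry until the flag is reset.
lb.providers[0]["healthy"] = False

reply = lb.complete([{"role": "user", "content": "One-line summary of RAG?"}])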

Cost-Optimized Routing

from dataclasses import dataclass
from typing import Optional

@dataclass
class RoutingDecision:
    provider: str
    model: str
    estimated_cost: float
    quality_score: float
    reasoning: str

class CostOptimizedRouter:
    """Route to minimize cost while meeting quality requirements."""
    
    def __init__(self, max_cost_per_request: float = 0.10):
        self.max_cost = max_cost_per_request
        
        self.models = [
            {
                "provider": "openai",
                "model": "gpt-4o-mini",
                "cost_per_1k": 0.00015 + 0.0006,  # avg input + output
                "quality_score": 0.7,
                "best_for": ["simple", "formatting", "extraction"]
            },
            {
                "provider": "anthropic",
                "model": "claude-3-haiku-20240307",
                "cost_per_1k": 0.00025 + 0.00125,
                "quality_score": 0.75,
                "best_for": ["simple", "summarization"]
            },
            {
                "provider": "openai",
                "model": "gpt-4o",
                "cost_per_1k": 0.005 + 0.015,
                "quality_score": 0.95,
                "best_for": ["complex", "coding", "reasoning"]
            },
            {
                "provider": "anthropic",
                "model": "claude-3-5-sonnet-20241022",
                "cost_per_1k": 0.003 + 0.015,
                "quality_score": 0.93,
                "best_for": ["complex", "creative", "analysis"]
            },
        ]
    
    def estimate_tokens(self, query: str) -> int:
        """Estimate token count."""
        return len(query) // 4 + 500  # Query + estimated response
    
    def route(
        self,
        query: str,
        min_quality: float = 0.7,
        task_type: Optional[str] = None
    ) -> RoutingDecision:
        """Find cheapest model meeting quality requirements."""
        
        estimated_tokens = self.estimate_tokens(query)
        
        candidates = []
        
        for model in self.models:
            # Check quality threshold
            if model["quality_score"] < min_quality:
                continue
            
            # Estimate cost
            estimated_cost = (estimated_tokens / 1000) * model["cost_per_1k"]
            
            # Check cost limit
            if estimated_cost > self.max_cost:
                continue
            
            # Bonus for task match
            task_bonus = 0.1 if task_type and task_type in model["best_for"] else 0
            
            candidates.append({
                **model,
                "estimated_cost": estimated_cost,
                "effective_score": model["quality_score"] + task_bonus
            })
        
        if not candidates:
            # Fall back to cheapest
            candidates = sorted(self.models, key=lambda x: x["cost_per_1k"])
            best = candidates[0]
            estimated_cost = (estimated_tokens / 1000) * best["cost_per_1k"]
        else:
            # Sort by cost (cheapest first)
            candidates.sort(key=lambda x: x["estimated_cost"])
            best = candidates[0]
            estimated_cost = best["estimated_cost"]
        
        return RoutingDecision(
            provider=best["provider"],
            model=best["model"],
            estimated_cost=estimated_cost,
            quality_score=best["quality_score"],
            reasoning=f"Selected for cost efficiency with quality >= {min_quality}"
        )

# Usage
router = CostOptimizedRouter(max_cost_per_request=0.05)

# Simple query -> routes to cheapest
decision = router.route("What is 2+2?", min_quality=0.6)
print(f"Model: {decision.model}, Est. cost: ${decision.estimated_cost:.4f}")

# Complex query with high quality requirement -> routes to premium
decision = router.route(
    "Implement a distributed consensus algorithm in Python",
    min_quality=0.9,
    task_type="coding"
)
print(f"Model: {decision.model}, Est. cost: ${decision.estimated_cost:.4f}")

Conclusion

Smart LLM routing is essential for production applications. The naive approach of using GPT-4 for everything works but bleeds money. A well-designed router classifies queries, matches them to appropriate models, handles failures gracefully, and optimizes for your specific cost-quality tradeoffs. Start simple with complexity-based routing to a single provider’s model tiers. Add multi-provider support for resilience and to leverage each model’s strengths. Implement fallback logic to handle outages. Finally, add cost optimization to stay within budget while meeting quality requirements. The investment in routing infrastructure pays for itself quickly—often reducing LLM costs by 50-80% while maintaining or improving response quality. Monitor your routing decisions, track costs per route, and continuously tune your classification and selection logic based on real usage patterns.