
RAG Optimization: Query Rewriting, Hybrid Search, and Re-ranking

Introduction

Retrieval-Augmented Generation (RAG) grounds LLM responses in factual data, but naive implementations often retrieve irrelevant content or miss important information. Optimizing RAG requires attention to every stage: query understanding, retrieval strategy, re-ranking, and context integration. This guide covers practical optimization techniques: query rewriting and expansion, hybrid search combining dense and sparse retrieval, re-ranking with cross-encoders, chunk optimization, and evaluation methods that help you measure and improve retrieval quality systematically.

[Figure: RAG pipeline with query rewriting, hybrid retrieval, and re-ranking]

Query Rewriting and Expansion

import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class RewrittenQuery:
    """Result of query rewriting."""
    
    original: str
    rewritten: str
    expansions: list[str]
    hypothetical_answer: Optional[str] = None

class QueryRewriter:
    """Rewrite queries for better retrieval."""
    
    def __init__(self, client):
        self.client = client
    
    def rewrite(self, query: str) -> RewrittenQuery:
        """Rewrite query for better retrieval."""
        
        prompt = f"""Rewrite this search query to be more specific and effective for retrieval.
Also generate 2-3 alternative phrasings.

Original query: {query}

Respond in JSON format:
{{"rewritten": "improved query", "alternatives": ["alt1", "alt2"]}}"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )
        
        result = json.loads(response.choices[0].message.content)
        
        return RewrittenQuery(
            original=query,
            rewritten=result.get("rewritten", query),
            expansions=result.get("alternatives", [])
        )
    
    def generate_hypothetical_answer(self, query: str) -> str:
        """Generate hypothetical answer for HyDE retrieval."""
        
        prompt = f"""Generate a hypothetical answer to this question.
The answer should be detailed and factual-sounding, even if you're not certain.
This will be used for semantic search.

Question: {query}

Hypothetical answer:"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}]
        )
        
        return response.choices[0].message.content

class MultiQueryRetriever:
    """Generate multiple queries for comprehensive retrieval."""
    
    def __init__(self, client, retriever):
        self.client = client
        self.retriever = retriever
    
    def generate_queries(self, query: str, num_queries: int = 3) -> list[str]:
        """Generate multiple search queries."""
        
        prompt = f"""Generate {num_queries} different search queries that would help answer this question.
Each query should approach the question from a different angle.

Question: {query}

Return as JSON: {{"queries": ["query1", "query2", ...]}}"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )
        
        result = json.loads(response.choices[0].message.content)
        return result.get("queries", [query])
    
    def retrieve(self, query: str, top_k: int = 5) -> list[dict]:
        """Retrieve using multiple queries."""
        
        queries = self.generate_queries(query)
        queries.append(query)  # Include original
        
        all_results = []
        seen_ids = set()
        
        for q in queries:
            results = self.retriever.search(q, top_k=top_k)
            
            for result in results:
                if result["id"] not in seen_ids:
                    all_results.append(result)
                    seen_ids.add(result["id"])
        
        return all_results[:top_k * 2]

# Step-back prompting for complex queries
class StepBackRetriever:
    """Use step-back prompting for complex queries."""
    
    def __init__(self, client, retriever):
        self.client = client
        self.retriever = retriever
    
    def get_step_back_query(self, query: str) -> str:
        """Generate broader step-back query."""
        
        prompt = f"""Given this specific question, generate a more general "step-back" question
that would help provide background context.

Specific question: {query}

Step-back question (more general):"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}]
        )
        
        return response.choices[0].message.content.strip()
    
    def retrieve(self, query: str, top_k: int = 5) -> dict:
        """Retrieve with step-back context."""
        
        step_back = self.get_step_back_query(query)
        
        # Get background context
        background = self.retriever.search(step_back, top_k=top_k // 2)
        
        # Get specific results
        specific = self.retriever.search(query, top_k=top_k)
        
        return {
            "background": background,
            "specific": specific
        }
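
A quick usage sketch for the classes above, assuming an OpenAI client; the example queries and the vector_retriever used for the HyDE search are illustrative placeholders:

from openai import OpenAI

client = OpenAI()
rewriter = QueryRewriter(client)

# Rewrite and expand the raw user query before retrieval
rewritten = rewriter.rewrite("how do i make rag better")
print(rewritten.rewritten)    # a more specific phrasing
print(rewritten.expansions)   # alternative phrasings to search with

# HyDE: search with the embedding of a hypothetical answer instead of the query
hyde_text = rewriter.generate_hypothetical_answer("What is reciprocal rank fusion?")
# hits = vector_retriever.search(hyde_text, top_k=5)  # vector_retriever is hypothetical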

Hybrid Search

from dataclasses import dataclass
from typing import Optional
import numpy as np

@dataclass
class SearchResult:
    """A search result with scores."""
    
    id: str
    content: str
    dense_score: float = 0.0
    sparse_score: float = 0.0
    combined_score: float = 0.0
    metadata: Optional[dict] = None

class HybridRetriever:
    """Combine dense and sparse retrieval."""
    
    def __init__(
        self,
        embedding_client,
        embedding_model: str = "text-embedding-3-small",
        alpha: float = 0.5  # Weight for dense vs sparse
    ):
        self.embedding_client = embedding_client
        self.embedding_model = embedding_model
        self.alpha = alpha
        
        # Document storage
        self.documents: list[dict] = []
        self.embeddings: list[list[float]] = []
        self.bm25_index = None
    
    def _embed(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings for texts."""
        
        response = self.embedding_client.embeddings.create(
            model=self.embedding_model,
            input=texts
        )
        
        return [d.embedding for d in response.data]
    
    def add_documents(self, documents: list[dict]):
        """Add documents to index."""
        
        # Store documents
        self.documents.extend(documents)
        
        # Create embeddings
        texts = [d["content"] for d in documents]
        embeddings = self._embed(texts)
        self.embeddings.extend(embeddings)
        
        # Build BM25 index
        self._build_bm25_index()
    
    def _build_bm25_index(self):
        """Build BM25 index for sparse retrieval."""
        
        from rank_bm25 import BM25Okapi
        
        # Tokenize documents
        tokenized = [
            doc["content"].lower().split()
            for doc in self.documents
        ]
        
        self.bm25_index = BM25Okapi(tokenized)
    
    def _dense_search(self, query: str, top_k: int) -> list[tuple[int, float]]:
        """Dense vector search."""
        
        query_embedding = self._embed([query])[0]
        
        # Calculate similarities
        similarities = []
        for i, doc_embedding in enumerate(self.embeddings):
            sim = np.dot(query_embedding, doc_embedding) / (
                np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding)
            )
            similarities.append((i, sim))
        
        # Sort by similarity
        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_k]
    
    def _sparse_search(self, query: str, top_k: int) -> list[tuple[int, float]]:
        """Sparse BM25 search."""
        
        if not self.bm25_index:
            return []
        
        tokenized_query = query.lower().split()
        scores = self.bm25_index.get_scores(tokenized_query)
        
        # Get top results
        indexed_scores = [(i, s) for i, s in enumerate(scores)]
        indexed_scores.sort(key=lambda x: x[1], reverse=True)
        
        return indexed_scores[:top_k]
    
    def search(self, query: str, top_k: int = 10) -> list[SearchResult]:
        """Hybrid search combining dense and sparse."""
        
        # Get results from both methods
        dense_results = self._dense_search(query, top_k * 2)
        sparse_results = self._sparse_search(query, top_k * 2)
        
        # Normalize scores
        dense_max = max(s for _, s in dense_results) if dense_results else 1
        sparse_max = max(s for _, s in sparse_results) if sparse_results else 1
        
        # Combine scores
        score_map = {}
        
        for idx, score in dense_results:
            normalized = score / dense_max if dense_max > 0 else 0
            score_map[idx] = {"dense": normalized, "sparse": 0}
        
        for idx, score in sparse_results:
            normalized = score / sparse_max if sparse_max > 0 else 0
            if idx in score_map:
                score_map[idx]["sparse"] = normalized
            else:
                score_map[idx] = {"dense": 0, "sparse": normalized}
        
        # Calculate combined scores
        results = []
        for idx, scores in score_map.items():
            combined = (
                self.alpha * scores["dense"] +
                (1 - self.alpha) * scores["sparse"]
            )
            
            results.append(SearchResult(
                id=str(idx),
                content=self.documents[idx]["content"],
                dense_score=scores["dense"],
                sparse_score=scores["sparse"],
                combined_score=combined,
                metadata=self.documents[idx].get("metadata")
            ))
        
        # Sort by combined score
        results.sort(key=lambda x: x.combined_score, reverse=True)
        return results[:top_k]

# Reciprocal Rank Fusion
class RRFRetriever:
    """Combine multiple retrievers using RRF."""
    
    def __init__(self, retrievers: list, k: int = 60):
        self.retrievers = retrievers
        self.k = k  # RRF constant
    
    def search(self, query: str, top_k: int = 10) -> list[dict]:
        """Search using RRF fusion."""
        
        # Get results from all retrievers
        all_rankings = []
        
        for retriever in self.retrievers:
            results = retriever.search(query, top_k=top_k * 2)
            all_rankings.append(results)
        
        # Calculate RRF scores
        rrf_scores = {}
        
        for ranking in all_rankings:
            for rank, result in enumerate(ranking):
                doc_id = result.get("id") or result.get("content")[:50]
                
                if doc_id not in rrf_scores:
                    rrf_scores[doc_id] = {
                        "score": 0,
                        "content": result.get("content"),
                        "metadata": result.get("metadata")
                    }
                
                # RRF formula: 1 / (k + rank), using 1-based ranks
                rrf_scores[doc_id]["score"] += 1 / (self.k + rank + 1)
        
        # Sort by RRF score
        sorted_results = sorted(
            rrf_scores.items(),
            key=lambda x: x[1]["score"],
            reverse=True
        )
        
        return [
            {
                "id": doc_id,
                "content": data["content"],
                "score": data["score"],
                "metadata": data["metadata"]
            }
            for doc_id, data in sorted_results[:top_k]
        ]
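
A minimal wiring sketch for the hybrid retriever, assuming an OpenAI client for embeddings and the rank_bm25 package installed; the documents and query are illustrative:

from openai import OpenAI

client = OpenAI()
retriever = HybridRetriever(client, alpha=0.5)  # alpha=0.5 weights dense and sparse equally

retriever.add_documents([
    {"content": "BM25 is a sparse, keyword-based ranking function.", "metadata": {"id": "doc-1"}},
    {"content": "Dense retrieval ranks documents by embedding similarity.", "metadata": {"id": "doc-2"}},
])

for r in retriever.search("keyword ranking function", top_k=2):
    print(round(r.combined_score, 3), r.content)

Pushing alpha toward 1.0 favors semantic matches; pushing it toward 0.0 favors exact keyword matches, which is often what you want for IDs, error codes, and rare terms.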

Re-ranking

from dataclasses import dataclass
from typing import Optional

@dataclass
class RankedResult:
    """A re-ranked result."""
    
    content: str
    original_rank: int
    new_rank: int
    relevance_score: float
    metadata: Optional[dict] = None

class CrossEncoderReranker:
    """Re-rank using cross-encoder model."""
    
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        from sentence_transformers import CrossEncoder
        self.model = CrossEncoder(model_name)
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[RankedResult]:
        """Re-rank documents using cross-encoder."""
        
        # Create query-document pairs
        pairs = [
            [query, doc["content"]]
            for doc in documents
        ]
        
        # Get relevance scores
        scores = self.model.predict(pairs)
        
        # Create ranked results
        results = []
        for i, (doc, score) in enumerate(zip(documents, scores)):
            results.append(RankedResult(
                content=doc["content"],
                original_rank=i,
                new_rank=0,  # Will be set after sorting
                relevance_score=float(score),
                metadata=doc.get("metadata")
            ))
        
        # Sort by relevance score
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        
        # Update ranks
        for i, result in enumerate(results):
            result.new_rank = i
        
        return results[:top_k]

class LLMReranker:
    """Re-rank using LLM."""
    
    def __init__(self, client):
        self.client = client
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[RankedResult]:
        """Re-rank using LLM scoring."""
        
        results = []
        
        for i, doc in enumerate(documents):
            score = self._score_relevance(query, doc["content"])
            
            results.append(RankedResult(
                content=doc["content"],
                original_rank=i,
                new_rank=0,
                relevance_score=score,
                metadata=doc.get("metadata")
            ))
        
        # Sort and update ranks
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        
        for i, result in enumerate(results):
            result.new_rank = i
        
        return results[:top_k]
    
    def _score_relevance(self, query: str, document: str) -> float:
        """Score document relevance to query."""
        
        prompt = f"""Rate how relevant this document is to the query on a scale of 0-10.

Query: {query}

Document: {document[:500]}

Respond with just a number (0-10):"""
        
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=5
        )
        
        try:
            return float(response.choices[0].message.content.strip())
        except ValueError:
            # Fall back to a neutral score if the model returns non-numeric text
            return 5.0

# Cohere reranker
class CohereReranker:
    """Re-rank using Cohere rerank API."""
    
    def __init__(self, api_key: str):
        import cohere
        self.client = cohere.Client(api_key)
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[RankedResult]:
        """Re-rank using Cohere."""
        
        docs = [d["content"] for d in documents]
        
        response = self.client.rerank(
            query=query,
            documents=docs,
            top_n=top_k,
            model="rerank-english-v3.0"
        )
        
        results = []
        for i, result in enumerate(response.results):
            results.append(RankedResult(
                content=documents[result.index]["content"],
                original_rank=result.index,
                new_rank=i,
                relevance_score=result.relevance_score,
                metadata=documents[result.index].get("metadata")
            ))
        
        return results
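
The usual pattern is to over-retrieve with the cheap first stage and let the re-ranker pick the final few. A sketch using the hybrid retriever from the previous section and the cross-encoder re-ranker (assumes sentence-transformers is installed; the query is illustrative):

query = "how does reciprocal rank fusion combine rankings?"

# First stage: pull more candidates than you ultimately need
candidates = retriever.search(query, top_k=20)
docs = [{"content": c.content, "metadata": c.metadata} for c in candidates]

# Second stage: the cross-encoder scores each (query, document) pair jointly
reranker = CrossEncoderReranker()
top = reranker.rerank(query, docs, top_k=5)

for r in top:
    print(r.new_rank, round(r.relevance_score, 3), r.content[:80])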

Chunk Optimization

from dataclasses import dataclass
from typing import Optional

@dataclass
class Chunk:
    """A document chunk."""
    
    content: str
    metadata: dict
    parent_id: Optional[str] = None
    chunk_index: int = 0

class SemanticChunker:
    """Chunk documents based on semantic boundaries."""
    
    def __init__(
        self,
        embedding_client,
        similarity_threshold: float = 0.8
    ):
        self.embedding_client = embedding_client
        self.threshold = similarity_threshold
    
    def chunk(self, text: str, metadata: dict = None) -> list[Chunk]:
        """Chunk text at semantic boundaries."""
        
        # Split into sentences
        import re
        sentences = re.split(r'(?<=[.!?])\s+', text)
        
        if len(sentences) <= 1:
            return [Chunk(content=text, metadata=metadata or {})]
        
        # Get embeddings for sentences
        embeddings = self._embed(sentences)
        
        # Find semantic boundaries
        chunks = []
        current_chunk = [sentences[0]]
        
        for i in range(1, len(sentences)):
            similarity = self._cosine_similarity(
                embeddings[i-1],
                embeddings[i]
            )
            
            if similarity < self.threshold:
                # Semantic boundary - start new chunk
                chunks.append(Chunk(
                    content=" ".join(current_chunk),
                    metadata=metadata or {},
                    chunk_index=len(chunks)
                ))
                current_chunk = [sentences[i]]
            else:
                current_chunk.append(sentences[i])
        
        # Add final chunk
        if current_chunk:
            chunks.append(Chunk(
                content=" ".join(current_chunk),
                metadata=metadata or {},
                chunk_index=len(chunks)
            ))
        
        return chunks
    
    def _embed(self, texts: list[str]) -> list[list[float]]:
        response = self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=texts
        )
        return [d.embedding for d in response.data]
    
    def _cosine_similarity(self, a, b) -> float:
        import numpy as np
        a, b = np.array(a), np.array(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

class HierarchicalChunker:
    """Create hierarchical chunks with parent-child relationships."""
    
    def __init__(
        self,
        large_chunk_size: int = 2000,
        small_chunk_size: int = 400,
        overlap: int = 50
    ):
        self.large_size = large_chunk_size
        self.small_size = small_chunk_size
        self.overlap = overlap
    
    def chunk(self, text: str, doc_id: str) -> dict:
        """Create hierarchical chunks."""
        
        # Create large chunks (parents)
        large_chunks = self._split_text(text, self.large_size, self.overlap)
        
        result = {
            "parents": [],
            "children": []
        }
        
        for i, large_chunk in enumerate(large_chunks):
            parent_id = f"{doc_id}_parent_{i}"
            
            result["parents"].append(Chunk(
                content=large_chunk,
                metadata={"doc_id": doc_id, "type": "parent"},
                chunk_index=i
            ))
            
            # Create small chunks (children)
            small_chunks = self._split_text(large_chunk, self.small_size, self.overlap)
            
            for j, small_chunk in enumerate(small_chunks):
                result["children"].append(Chunk(
                    content=small_chunk,
                    metadata={"doc_id": doc_id, "type": "child"},
                    parent_id=parent_id,
                    chunk_index=j
                ))
        
        return result
    
    def _split_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
        """Split text into overlapping chunks."""
        
        chunks = []
        start = 0
        
        while start < len(text):
            end = start + chunk_size
            chunk = text[start:end]
            
            # Try to break at sentence boundary
            if end < len(text):
                last_period = chunk.rfind('.')
                if last_period > chunk_size // 2:
                    chunk = chunk[:last_period + 1]
                    end = start + last_period + 1
            
            chunks.append(chunk.strip())
            start = end - overlap
        
        return chunks

# Parent document retriever
class ParentDocumentRetriever:
    """Retrieve child chunks but return parent context."""
    
    def __init__(self, child_retriever, parent_store: dict):
        self.child_retriever = child_retriever
        self.parent_store = parent_store
    
    def search(self, query: str, top_k: int = 5) -> list[dict]:
        """Search children, return parents."""
        
        # Search child chunks
        child_results = self.child_retriever.search(query, top_k=top_k * 2)
        
        # Get unique parents
        seen_parents = set()
        results = []
        
        for child in child_results:
            parent_id = child.get("parent_id")
            
            if parent_id and parent_id not in seen_parents:
                seen_parents.add(parent_id)
                
                parent = self.parent_store.get(parent_id)
                if parent:
                    results.append({
                        "content": parent["content"],
                        "metadata": parent.get("metadata"),
                        "matched_child": child["content"]
                    })
        
        return results[:top_k]
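
Wiring the hierarchical chunks into parent-document retrieval looks roughly like this; child_index is a hypothetical stand-in for whatever vector index you use, as long as its results carry the parent_id of each child chunk:

long_text = "..."  # your raw document text goes here
chunker = HierarchicalChunker(large_chunk_size=2000, small_chunk_size=400)
hierarchy = chunker.chunk(long_text, doc_id="doc-1")

# Keep parents in a simple lookup keyed by the same IDs the children reference
parent_store = {
    f"doc-1_parent_{p.chunk_index}": {"content": p.content, "metadata": p.metadata}
    for p in hierarchy["parents"]
}

# Index only the small child chunks for search, making sure each stored child
# keeps its parent_id (child_index is a hypothetical vector index)
# child_index.add(hierarchy["children"])
# retriever = ParentDocumentRetriever(child_retriever=child_index, parent_store=parent_store)
# results = retriever.search("what does the policy say about chunk overlap?", top_k=3)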

Production RAG Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

# Initialize components
from openai import OpenAI
client = OpenAI()

hybrid_retriever = HybridRetriever(client)
query_rewriter = QueryRewriter(client)
reranker = LLMReranker(client)

class RAGRequest(BaseModel):
    query: str
    top_k: int = 5
    use_rewriting: bool = True
    use_reranking: bool = True

class IndexRequest(BaseModel):
    documents: list[dict]

@app.post("/v1/index")
async def index_documents(request: IndexRequest):
    """Index documents for retrieval."""
    
    hybrid_retriever.add_documents(request.documents)
    
    return {
        "indexed": len(request.documents),
        "total": len(hybrid_retriever.documents)
    }

@app.post("/v1/retrieve")
async def retrieve(request: RAGRequest):
    """Retrieve relevant documents."""
    
    query = request.query
    
    # Query rewriting
    if request.use_rewriting:
        rewritten = query_rewriter.rewrite(query)
        query = rewritten.rewritten
    
    # Hybrid search
    results = hybrid_retriever.search(query, top_k=request.top_k * 2)
    
    # Re-ranking
    if request.use_reranking and results:
        documents = [{"content": r.content, "metadata": r.metadata} for r in results]
        ranked = reranker.rerank(request.query, documents, top_k=request.top_k)
        
        return {
            "results": [
                {
                    "content": r.content,
                    "score": r.relevance_score,
                    "original_rank": r.original_rank,
                    "new_rank": r.new_rank
                }
                for r in ranked
            ]
        }
    
    return {
        "results": [
            {
                "content": r.content,
                "score": r.combined_score,
                "dense_score": r.dense_score,
                "sparse_score": r.sparse_score
            }
            for r in results[:request.top_k]
        ]
    }

@app.post("/v1/rag")
async def rag_query(request: RAGRequest):
    """Full RAG pipeline with generation."""
    
    # Retrieve
    retrieval_response = await retrieve(request)
    results = retrieval_response["results"]
    
    if not results:
        return {"answer": "No relevant information found.", "sources": []}
    
    # Build context
    context = "\n\n".join([
        f"Source {i+1}: {r['content']}"
        for i, r in enumerate(results)
    ])
    
    # Generate answer
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": "Answer based on the provided context. Cite sources."
            },
            {
                "role": "user",
                "content": f"Context:\n{context}\n\nQuestion: {request.query}"
            }
        ]
    )
    
    return {
        "answer": response.choices[0].message.content,
        "sources": results
    }

@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "documents_indexed": len(hybrid_retriever.documents)
    }
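
With the service running (for example via uvicorn your_module:app), the endpoints can be exercised like this; the host, port, and payloads are illustrative:

import requests

base = "http://localhost:8000"

# Index a couple of documents
requests.post(f"{base}/v1/index", json={"documents": [
    {"content": "RRF fuses rankings with the formula 1 / (k + rank).", "metadata": {"id": "doc-1"}},
    {"content": "Cross-encoders score query-document pairs jointly.", "metadata": {"id": "doc-2"}},
]})

# Ask a question through the full pipeline
resp = requests.post(f"{base}/v1/rag", json={
    "query": "How does reciprocal rank fusion combine rankings?",
    "top_k": 3,
    "use_rewriting": True,
    "use_reranking": True,
})
print(resp.json()["answer"])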


Conclusion

RAG optimization is iterative—measure retrieval quality, identify failure modes, and apply targeted improvements. Start with query rewriting to handle ambiguous or poorly-formed queries. Implement hybrid search combining dense embeddings with sparse BM25 to capture both semantic similarity and keyword matches. Add re-ranking with cross-encoders or LLMs to improve precision on the final results. Optimize chunking strategy based on your content—semantic chunking for varied documents, hierarchical chunking for long documents where context matters. Use evaluation frameworks to measure recall, precision, and answer quality systematically. The goal is retrieving the most relevant context while staying within token limits—every improvement in retrieval quality directly improves generation quality.
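
As a starting point on the evaluation side, here is a small sketch of recall@k and MRR over a hand-labeled query set; the labeled_queries structure and the retriever's result shape are assumptions about your setup:

def recall_at_k(retrieved_ids: list[str], relevant_ids: set[str], k: int) -> float:
    """Fraction of the relevant documents that appear in the top-k results."""
    hits = sum(1 for doc_id in retrieved_ids[:k] if doc_id in relevant_ids)
    return hits / len(relevant_ids) if relevant_ids else 0.0

def mrr(retrieved_ids: list[str], relevant_ids: set[str]) -> float:
    """Reciprocal rank of the first relevant document; 0 if none is retrieved."""
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0

# labeled_queries: [{"query": "...", "relevant_ids": {"doc-1", ...}}, ...] built by hand
# avg_recall = sum(
#     recall_at_k([r.id for r in retriever.search(q["query"], top_k=10)], q["relevant_ids"], k=10)
#     for q in labeled_queries
# ) / len(labeled_queries)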

