Embedding Strategies: Model Selection, Batching, and Long Document Handling

Introduction

Embeddings are the foundation of semantic search, RAG systems, and similarity-based applications. Choosing the right embedding model and strategy significantly impacts retrieval quality, latency, and cost. Different models excel at different tasks: some optimize for semantic similarity, others for retrieval, and some for specific domains. This guide covers practical embedding strategies: model selection based on benchmarks, text preprocessing for better embeddings, batching and caching for efficiency, dimensionality reduction, and techniques for handling long documents that exceed model context limits.

Figure: the embedding pipeline (text preprocessing → embedding model → normalization).

Embedding Model Selection

from dataclasses import dataclass
from typing import Optional
from enum import Enum

class EmbeddingTask(str, Enum):
    RETRIEVAL = "retrieval"
    SIMILARITY = "similarity"
    CLUSTERING = "clustering"
    CLASSIFICATION = "classification"

@dataclass
class EmbeddingModel:
    """Embedding model configuration."""
    
    name: str
    provider: str
    dimensions: int
    max_tokens: int
    cost_per_1m_tokens: float
    best_for: list[EmbeddingTask]

# Popular embedding models
EMBEDDING_MODELS = {
    "text-embedding-3-small": EmbeddingModel(
        name="text-embedding-3-small",
        provider="openai",
        dimensions=1536,
        max_tokens=8191,
        cost_per_1m_tokens=0.02,
        best_for=[EmbeddingTask.RETRIEVAL, EmbeddingTask.SIMILARITY]
    ),
    "text-embedding-3-large": EmbeddingModel(
        name="text-embedding-3-large",
        provider="openai",
        dimensions=3072,
        max_tokens=8191,
        cost_per_1m_tokens=0.13,
        best_for=[EmbeddingTask.RETRIEVAL, EmbeddingTask.CLUSTERING]
    ),
    "voyage-3": EmbeddingModel(
        name="voyage-3",
        provider="voyage",
        dimensions=1024,
        max_tokens=32000,
        cost_per_1m_tokens=0.06,
        best_for=[EmbeddingTask.RETRIEVAL]
    ),
    "embed-english-v3.0": EmbeddingModel(
        name="embed-english-v3.0",
        provider="cohere",
        dimensions=1024,
        max_tokens=512,
        cost_per_1m_tokens=0.10,
        best_for=[EmbeddingTask.RETRIEVAL, EmbeddingTask.CLASSIFICATION]
    ),
    "all-MiniLM-L6-v2": EmbeddingModel(
        name="all-MiniLM-L6-v2",
        provider="sentence-transformers",
        dimensions=384,
        max_tokens=256,
        cost_per_1m_tokens=0.0,  # Free, local
        best_for=[EmbeddingTask.SIMILARITY, EmbeddingTask.CLUSTERING]
    ),
}

def select_model(
    task: EmbeddingTask,
    max_cost: Optional[float] = None,
    min_dimensions: Optional[int] = None,
    prefer_local: bool = False
) -> EmbeddingModel:
    """Select best model for task and constraints."""
    
    candidates = []
    
    for model in EMBEDDING_MODELS.values():
        # Filter by task
        if task not in model.best_for:
            continue
        
        # Filter by cost
        if max_cost and model.cost_per_1m_tokens > max_cost:
            continue
        
        # Filter by dimensions
        if min_dimensions and model.dimensions < min_dimensions:
            continue
        
        # Prefer local if requested
        if prefer_local and model.provider != "sentence-transformers":
            continue
        
        candidates.append(model)
    
    if not candidates:
        # Fall back to a general-purpose default when no model satisfies the constraints
        return EMBEDDING_MODELS["text-embedding-3-small"]
    
    # Sort by dimensions (quality proxy)
    candidates.sort(key=lambda m: m.dimensions, reverse=True)
    return candidates[0]

# Unified embedding client
class EmbeddingClient:
    """Unified client for multiple embedding providers."""
    
    def __init__(self, model: EmbeddingModel):
        self.model = model
        self._init_client()
    
    def _init_client(self):
        """Initialize provider-specific client."""
        
        if self.model.provider == "openai":
            from openai import OpenAI
            self.client = OpenAI()
        
        elif self.model.provider == "cohere":
            import cohere
            self.client = cohere.Client()
        
        elif self.model.provider == "voyage":
            import voyageai
            self.client = voyageai.Client()
        
        elif self.model.provider == "sentence-transformers":
            from sentence_transformers import SentenceTransformer
            self.client = SentenceTransformer(self.model.name)
    
    def embed(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings for texts."""
        
        if self.model.provider == "openai":
            response = self.client.embeddings.create(
                model=self.model.name,
                input=texts
            )
            return [d.embedding for d in response.data]
        
        elif self.model.provider == "cohere":
            response = self.client.embed(
                texts=texts,
                model=self.model.name,
                input_type="search_document"
            )
            return response.embeddings
        
        elif self.model.provider == "voyage":
            response = self.client.embed(
                texts,
                model=self.model.name
            )
            return response.embeddings
        
        elif self.model.provider == "sentence-transformers":
            return self.client.encode(texts).tolist()
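
A minimal usage sketch tying model selection to the client; it assumes the selected provider's SDK is installed and its API key is set in the environment.

# Pick a retrieval-focused model under a cost ceiling, then embed some documents.
model = select_model(EmbeddingTask.RETRIEVAL, max_cost=0.05)
client = EmbeddingClient(model)

vectors = client.embed([
    "Refunds are processed within 5 business days.",
    "Contact support via the in-app chat.",
])
print(model.name, len(vectors), len(vectors[0]))  # model, count, dimensions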

Text Preprocessing

import re
from dataclasses import dataclass
from typing import Optional

@dataclass
class PreprocessingConfig:
    """Configuration for text preprocessing."""
    
    lowercase: bool = False
    remove_urls: bool = True
    remove_emails: bool = True
    normalize_whitespace: bool = True
    remove_special_chars: bool = False
    max_length: Optional[int] = None

class TextPreprocessor:
    """Preprocess text for better embeddings."""
    
    def __init__(self, config: Optional[PreprocessingConfig] = None):
        self.config = config or PreprocessingConfig()
    
    def preprocess(self, text: str) -> str:
        """Preprocess text."""
        
        result = text
        
        # Remove URLs
        if self.config.remove_urls:
            result = re.sub(
                r'https?://\S+|www\.\S+',
                '',
                result
            )
        
        # Remove emails
        if self.config.remove_emails:
            result = re.sub(
                r'\S+@\S+\.\S+',
                '',
                result
            )
        
        # Normalize whitespace
        if self.config.normalize_whitespace:
            result = ' '.join(result.split())
        
        # Remove special characters
        if self.config.remove_special_chars:
            result = re.sub(r'[^\w\s]', '', result)
        
        # Lowercase
        if self.config.lowercase:
            result = result.lower()
        
        # Truncate
        if self.config.max_length:
            result = result[:self.config.max_length]
        
        return result.strip()
    
    def preprocess_batch(self, texts: list[str]) -> list[str]:
        """Preprocess multiple texts."""
        return [self.preprocess(t) for t in texts]

# Query-specific preprocessing
class QueryPreprocessor:
    """Preprocess queries differently from documents."""
    
    def __init__(self, add_instruction: bool = True):
        self.add_instruction = add_instruction
        
        # Instructions for different embedding models
        self.instructions = {
            "text-embedding-3-small": "",
            "text-embedding-3-large": "",
            "voyage-3": "",
            "embed-english-v3.0": "",  # Cohere uses input_type
            "e5-large-v2": "query: ",
            "bge-large-en-v1.5": "Represent this sentence for searching: ",
        }
    
    def preprocess_query(self, query: str, model_name: str) -> str:
        """Preprocess query for specific model."""
        
        # Clean query
        query = ' '.join(query.split())
        
        # Add instruction prefix if needed
        if self.add_instruction:
            prefix = self.instructions.get(model_name, "")
            query = prefix + query
        
        return query
    
    def preprocess_document(self, doc: str, model_name: str) -> str:
        """Preprocess document for specific model."""
        
        # Clean document
        doc = ' '.join(doc.split())
        
        # Some models need document prefix
        doc_prefixes = {
            "e5-large-v2": "passage: ",
            "bge-large-en-v1.5": "",
        }
        
        prefix = doc_prefixes.get(model_name, "")
        return prefix + doc
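
A quick sketch of the asymmetric handling, using the E5 prefixes listed above:

# Queries and documents receive different prefixes for models like E5.
query_prep = QueryPreprocessor()

query = query_prep.preprocess_query("how do refunds work", "e5-large-v2")
doc = query_prep.preprocess_document("Refunds are processed within 5 business days.", "e5-large-v2")

print(query)  # "query: how do refunds work"
print(doc)    # "passage: Refunds are processed within 5 business days."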

Batching and Caching

from typing import Optional
import hashlib

class BatchedEmbedder:
    """Embed texts in efficient batches."""
    
    def __init__(
        self,
        client: EmbeddingClient,
        batch_size: int = 100,
        max_tokens_per_batch: int = 8000
    ):
        self.client = client
        self.batch_size = batch_size
        self.max_tokens = max_tokens_per_batch
    
    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts in batches."""
        
        all_embeddings = []
        
        for batch in self._create_batches(texts):
            embeddings = self.client.embed(batch)
            all_embeddings.extend(embeddings)
        
        return all_embeddings
    
    def _create_batches(self, texts: list[str]) -> list[list[str]]:
        """Create batches respecting size and token limits."""
        
        batches = []
        current_batch = []
        current_tokens = 0
        
        for text in texts:
            # Rough token estimate
            tokens = len(text.split()) * 1.3
            
            if (len(current_batch) >= self.batch_size or
                current_tokens + tokens > self.max_tokens):
                
                if current_batch:
                    batches.append(current_batch)
                
                current_batch = [text]
                current_tokens = tokens
            else:
                current_batch.append(text)
                current_tokens += tokens
        
        if current_batch:
            batches.append(current_batch)
        
        return batches
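
The word-count heuristic above is deliberately rough. When tiktoken is installed (and recognizes the model name), a more accurate count for the OpenAI models can be swapped in; a sketch:

# Optional: exact token counts via tiktoken, falling back to the rough estimate.
def count_tokens(text: str, model_name: str = "text-embedding-3-small") -> int:
    try:
        import tiktoken
        encoding = tiktoken.encoding_for_model(model_name)
        return len(encoding.encode(text))
    except Exception:
        return int(len(text.split()) * 1.3)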

class EmbeddingCache:
    """Cache embeddings to avoid recomputation."""
    
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.cache: dict[str, list[float]] = {}
    
    def _hash_text(self, text: str) -> str:
        """Create hash key for text."""
        
        key = f"{self.model_name}:{text}"
        return hashlib.md5(key.encode()).hexdigest()
    
    def get(self, text: str) -> Optional[list[float]]:
        """Get cached embedding."""
        
        key = self._hash_text(text)
        return self.cache.get(key)
    
    def set(self, text: str, embedding: list[float]):
        """Cache embedding."""
        
        key = self._hash_text(text)
        self.cache[key] = embedding
    
    def get_batch(self, texts: list[str]) -> tuple[list[list[float]], list[str]]:
        """Get cached embeddings, return uncached texts."""
        
        cached = []
        uncached = []
        
        for text in texts:
            embedding = self.get(text)
            if embedding:
                cached.append(embedding)
            else:
                uncached.append(text)
        
        return cached, uncached

class CachedEmbedder:
    """Embedder with caching layer."""
    
    def __init__(self, client: EmbeddingClient):
        self.client = client
        self.cache = EmbeddingCache(client.model.name)
        self.batcher = BatchedEmbedder(client)
        
        self.stats = {"hits": 0, "misses": 0}
    
    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed with caching."""
        
        results = [None] * len(texts)
        uncached_indices = []
        uncached_texts = []
        
        # Check cache
        for i, text in enumerate(texts):
            cached = self.cache.get(text)
            if cached:
                results[i] = cached
                self.stats["hits"] += 1
            else:
                uncached_indices.append(i)
                uncached_texts.append(text)
                self.stats["misses"] += 1
        
        # Embed uncached
        if uncached_texts:
            new_embeddings = self.batcher.embed(uncached_texts)
            
            for i, (idx, text) in enumerate(zip(uncached_indices, uncached_texts)):
                embedding = new_embeddings[i]
                results[idx] = embedding
                self.cache.set(text, embedding)
        
        return results
    
    def get_stats(self) -> dict:
        """Get cache statistics."""
        
        total = self.stats["hits"] + self.stats["misses"]
        hit_rate = self.stats["hits"] / total if total > 0 else 0
        
        return {
            "hits": self.stats["hits"],
            "misses": self.stats["misses"],
            "hit_rate": hit_rate
        }
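
A quick sketch of the cache behavior, assuming an OpenAI API key is configured; the second call is served entirely from memory.

# Repeated texts hit the in-memory cache instead of the API.
base_client = EmbeddingClient(EMBEDDING_MODELS["text-embedding-3-small"])
cached = CachedEmbedder(base_client)

docs = ["alpha release notes", "beta release notes"]
cached.embed(docs)         # 2 misses (one API call)
cached.embed(docs)         # 2 hits (no API call)
print(cached.get_stats())  # {'hits': 2, 'misses': 2, 'hit_rate': 0.5}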

Long Document Handling

from dataclasses import dataclass
import numpy as np

@dataclass
class ChunkEmbedding:
    """Embedding for a document chunk."""
    
    chunk_index: int
    text: str
    embedding: list[float]

class LongDocumentEmbedder:
    """Handle documents exceeding model context."""
    
    def __init__(
        self,
        client: EmbeddingClient,
        chunk_size: int = 500,
        overlap: int = 50
    ):
        self.client = client
        self.chunk_size = chunk_size
        self.overlap = overlap
    
    def embed_document(
        self,
        document: str,
        strategy: str = "mean"
    ) -> list[float]:
        """Embed long document using specified strategy."""
        
        # Split into chunks
        chunks = self._split_document(document)
        
        if len(chunks) == 1:
            return self.client.embed([chunks[0]])[0]
        
        # Embed all chunks
        chunk_embeddings = self.client.embed(chunks)
        
        # Combine based on strategy
        if strategy == "mean":
            return self._mean_pooling(chunk_embeddings)
        elif strategy == "weighted":
            return self._weighted_pooling(chunk_embeddings, chunks)
        elif strategy == "first":
            return chunk_embeddings[0]
        elif strategy == "max":
            return self._max_pooling(chunk_embeddings)
        else:
            return self._mean_pooling(chunk_embeddings)
    
    def _split_document(self, document: str) -> list[str]:
        """Split document into overlapping chunks."""
        
        words = document.split()
        chunks = []
        
        start = 0
        while start < len(words):
            end = start + self.chunk_size
            chunk = ' '.join(words[start:end])
            chunks.append(chunk)
            
            # Stop once the end of the document is reached to avoid a
            # redundant trailing chunk that is fully contained in the previous one
            if end >= len(words):
                break
            start = end - self.overlap
        
        return chunks
    
    def _mean_pooling(self, embeddings: list[list[float]]) -> list[float]:
        """Average all chunk embeddings."""
        
        arr = np.array(embeddings)
        return np.mean(arr, axis=0).tolist()
    
    def _weighted_pooling(
        self,
        embeddings: list[list[float]],
        chunks: list[str]
    ) -> list[float]:
        """Weight by chunk length."""
        
        weights = np.array([len(c.split()) for c in chunks])
        weights = weights / weights.sum()
        
        arr = np.array(embeddings)
        weighted = np.average(arr, axis=0, weights=weights)
        
        return weighted.tolist()
    
    def _max_pooling(self, embeddings: list[list[float]]) -> list[float]:
        """Take max across dimensions."""
        
        arr = np.array(embeddings)
        return np.max(arr, axis=0).tolist()
    
    def embed_with_chunks(self, document: str) -> list[ChunkEmbedding]:
        """Return individual chunk embeddings."""
        
        chunks = self._split_document(document)
        embeddings = self.client.embed(chunks)
        
        return [
            ChunkEmbedding(
                chunk_index=i,
                text=chunk,
                embedding=embedding
            )
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings))
        ]
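
A small usage sketch; `embedding_client` is an `EmbeddingClient` from the first section and `long_text` is a placeholder for any document longer than one chunk.

# Pool chunks into one document vector, or keep per-chunk vectors for retrieval.
long_doc_embedder = LongDocumentEmbedder(embedding_client, chunk_size=500, overlap=50)

doc_vector = long_doc_embedder.embed_document(long_text, strategy="weighted")
chunk_vectors = long_doc_embedder.embed_with_chunks(long_text)

print(len(doc_vector))  # embedding dimensions
print(len(chunk_vectors), chunk_vectors[0].chunk_index)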

# Late chunking (embed first, then chunk)
class LateChunker:
    """Approximate late chunking: embed sentences first, then group them into chunks."""
    
    def __init__(self, client: EmbeddingClient):
        self.client = client
    
    def embed_and_chunk(
        self,
        document: str,
        chunk_size: int = 100
    ) -> list[ChunkEmbedding]:
        """Embed full document, then extract chunk embeddings."""
        
        # This requires a model that returns token-level embeddings
        # For demonstration, we'll use sentence-level
        
        sentences = self._split_sentences(document)
        embeddings = self.client.embed(sentences)
        
        # Group into chunks
        chunks = []
        current_chunk = []
        current_embeddings = []
        current_length = 0
        
        for sent, emb in zip(sentences, embeddings):
            sent_length = len(sent.split())
            
            if current_length + sent_length > chunk_size and current_chunk:
                # Create chunk embedding (mean of sentence embeddings)
                chunk_text = ' '.join(current_chunk)
                chunk_emb = np.mean(current_embeddings, axis=0).tolist()
                
                chunks.append(ChunkEmbedding(
                    chunk_index=len(chunks),
                    text=chunk_text,
                    embedding=chunk_emb
                ))
                
                current_chunk = [sent]
                current_embeddings = [emb]
                current_length = sent_length
            else:
                current_chunk.append(sent)
                current_embeddings.append(emb)
                current_length += sent_length
        
        # Add final chunk
        if current_chunk:
            chunk_text = ' '.join(current_chunk)
            chunk_emb = np.mean(current_embeddings, axis=0).tolist()
            
            chunks.append(ChunkEmbedding(
                chunk_index=len(chunks),
                text=chunk_text,
                embedding=chunk_emb
            ))
        
        return chunks
    
    def _split_sentences(self, text: str) -> list[str]:
        """Split text into sentences."""
        
        import re
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]
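
Usage mirrors the previous sketch; each returned chunk carries the mean of its sentence embeddings (same placeholder variables as above).

# Sentence-level approximation of late chunking.
late_chunker = LateChunker(embedding_client)
late_chunks = late_chunker.embed_and_chunk(long_text, chunk_size=100)

for c in late_chunks[:3]:
    print(c.chunk_index, len(c.text.split()), len(c.embedding))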

Dimensionality Reduction

import numpy as np
from typing import Optional

class DimensionalityReducer:
    """Reduce embedding dimensions for efficiency."""
    
    def __init__(self, target_dims: int = 256):
        self.target_dims = target_dims
        self.pca = None
        self.fitted = False
    
    def fit(self, embeddings: list[list[float]]):
        """Fit PCA on embeddings."""
        
        from sklearn.decomposition import PCA
        
        arr = np.array(embeddings)
        
        self.pca = PCA(n_components=self.target_dims)
        self.pca.fit(arr)
        self.fitted = True
    
    def transform(self, embeddings: list[list[float]]) -> list[list[float]]:
        """Reduce dimensions."""
        
        if not self.fitted:
            raise ValueError("Must fit before transform")
        
        arr = np.array(embeddings)
        reduced = self.pca.transform(arr)
        
        return reduced.tolist()
    
    def fit_transform(self, embeddings: list[list[float]]) -> list[list[float]]:
        """Fit and transform."""
        
        self.fit(embeddings)
        return self.transform(embeddings)
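
A short sketch; note that scikit-learn's PCA requires at least `target_dims` fitting vectors, so fit on a reasonably large, representative sample (`corpus_embeddings` and `query_embedding` are placeholders here).

# Fit PCA once on existing embeddings, then reuse it for queries.
reducer = DimensionalityReducer(target_dims=256)
reduced_corpus = reducer.fit_transform(corpus_embeddings)  # e.g. 1536-d -> 256-d
reduced_query = reducer.transform([query_embedding])[0]    # must use the same fitted PCA

print(len(corpus_embeddings[0]), "->", len(reduced_corpus[0]))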

# Matryoshka embeddings (OpenAI text-embedding-3)
class MatryoshkaEmbedder:
    """Use Matryoshka embeddings for flexible dimensions."""
    
    def __init__(self, client, full_dims: int = 3072):
        # Expects a raw OpenAI SDK client (its embeddings API accepts a `dimensions` argument)
        self.client = client
        self.full_dims = full_dims
    
    def embed(
        self,
        texts: list[str],
        dimensions: Optional[int] = None
    ) -> list[list[float]]:
        """Embed with optional dimension reduction."""
        
        # OpenAI text-embedding-3 models support native (Matryoshka) dimension reduction
        kwargs = {"model": "text-embedding-3-large", "input": texts}
        if dimensions is not None:
            kwargs["dimensions"] = dimensions  # any value up to full_dims (e.g. 256, 1024, 3072)
        
        response = self.client.embeddings.create(**kwargs)
        
        return [d.embedding for d in response.data]
    
    def embed_multi_resolution(
        self,
        texts: list[str],
        dimensions: tuple[int, ...] = (256, 1024, 3072)
    ) -> dict[int, list[list[float]]]:
        """Get embeddings at multiple resolutions."""
        
        results = {}
        
        for dim in dimensions:
            embeddings = self.embed(texts, dimensions=dim)
            results[dim] = embeddings
        
        return results
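
A sketch assuming an OpenAI SDK client with an API key configured; smaller dimensions trade a little quality for cheaper storage and faster search.

# Request natively truncated embeddings at smaller dimensions.
from openai import OpenAI

matryoshka = MatryoshkaEmbedder(OpenAI())
small_vectors = matryoshka.embed(["refund policy details"], dimensions=256)
multi = matryoshka.embed_multi_resolution(["refund policy details"])

print(len(small_vectors[0]))  # 256
print(sorted(multi.keys()))   # [256, 1024, 3072]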

# Binary quantization
class BinaryQuantizer:
    """Quantize embeddings to binary for fast search."""
    
    def quantize(self, embeddings: list[list[float]]) -> list[bytes]:
        """Convert to binary embeddings."""
        
        binary_embeddings = []
        
        for embedding in embeddings:
            # Convert to binary (1 if > 0, else 0)
            binary = np.array(embedding) > 0
            
            # Pack into bytes
            packed = np.packbits(binary)
            binary_embeddings.append(packed.tobytes())
        
        return binary_embeddings
    
    def hamming_distance(self, a: bytes, b: bytes) -> int:
        """Calculate Hamming distance between binary embeddings."""
        
        a_arr = np.frombuffer(a, dtype=np.uint8)
        b_arr = np.frombuffer(b, dtype=np.uint8)
        
        xor = np.bitwise_xor(a_arr, b_arr)
        return np.unpackbits(xor).sum()
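
A brute-force search sketch over binary codes; in practice the top Hamming-distance candidates are re-scored with the full-precision vectors (`corpus_embeddings` and `query_embedding` are placeholders).

# Quantize a corpus, then rank documents by Hamming distance to the query.
quantizer = BinaryQuantizer()

corpus_binary = quantizer.quantize(corpus_embeddings)
query_binary = quantizer.quantize([query_embedding])[0]

ranked = sorted(
    range(len(corpus_binary)),
    key=lambda i: quantizer.hamming_distance(query_binary, corpus_binary[i])
)
print(ranked[:5])  # indices of the 5 nearest documents; re-rank these with full vectors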

Production Embedding Service

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

# Initialize components
from openai import OpenAI
client = OpenAI()

model = EMBEDDING_MODELS["text-embedding-3-small"]
embedding_client = EmbeddingClient(model)
cached_embedder = CachedEmbedder(embedding_client)
long_doc_embedder = LongDocumentEmbedder(embedding_client)
preprocessor = TextPreprocessor()

class EmbedRequest(BaseModel):
    texts: list[str]
    preprocess: bool = True
    dimensions: Optional[int] = None

class LongDocRequest(BaseModel):
    document: str
    strategy: str = "mean"
    return_chunks: bool = False

@app.post("/v1/embed")
async def embed_texts(request: EmbedRequest):
    """Embed texts."""
    
    texts = request.texts
    
    # Preprocess
    if request.preprocess:
        texts = preprocessor.preprocess_batch(texts)
    
    # Embed (use the cached embedder unless a smaller dimension is requested)
    if request.dimensions and request.dimensions < model.dimensions:
        # Native Matryoshka-style reduction; bypasses the cache
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=texts,
            dimensions=request.dimensions
        )
        embeddings = [d.embedding for d in response.data]
    else:
        embeddings = cached_embedder.embed(texts)
    
    return {
        "embeddings": embeddings,
        "model": model.name,
        "dimensions": request.dimensions or model.dimensions
    }

@app.post("/v1/embed/document")
async def embed_long_document(request: LongDocRequest):
    """Embed long document."""
    
    if request.return_chunks:
        chunks = long_doc_embedder.embed_with_chunks(request.document)
        
        return {
            "chunks": [
                {
                    "index": c.chunk_index,
                    "text": c.text,
                    "embedding": c.embedding
                }
                for c in chunks
            ]
        }
    
    embedding = long_doc_embedder.embed_document(
        request.document,
        strategy=request.strategy
    )
    
    return {
        "embedding": embedding,
        "strategy": request.strategy
    }

@app.get("/v1/models")
async def list_models():
    """List available embedding models."""
    
    return {
        "models": [
            {
                "name": m.name,
                "provider": m.provider,
                "dimensions": m.dimensions,
                "max_tokens": m.max_tokens,
                "cost_per_1m": m.cost_per_1m_tokens
            }
            for m in EMBEDDING_MODELS.values()
        ]
    }

@app.get("/v1/stats")
async def get_stats():
    """Get embedding cache stats."""
    
    return cached_embedder.get_stats()

@app.get("/health")
async def health():
    return {"status": "healthy", "model": model.name}

Conclusion

Embedding strategy significantly impacts retrieval quality and cost. Choose models based on your specific task: retrieval-optimized models like Voyage or Cohere often outperform general-purpose models for search. Preprocess text consistently between indexing and querying. Implement batching to reduce API calls and caching to avoid recomputing embeddings for repeated content. For long documents, use chunking with mean pooling or hierarchical approaches. Consider dimensionality reduction for storage efficiency; Matryoshka embeddings from OpenAI allow flexible dimension selection with only modest quality loss at smaller sizes. Binary quantization enables fast approximate search at scale. The key is matching your embedding strategy to your use case: high-quality embeddings for precision-critical applications, efficient embeddings for cost-sensitive or high-volume scenarios.


Discover more from Code, Cloud & Context

Subscribe to get the latest posts sent to your email.

Leave a Reply

Your email address will not be published. Required fields are marked *

This site uses Akismet to reduce spam. Learn how your comment data is processed.