Embedding Fine-Tuning: Training Custom Embeddings for Domain-Specific Retrieval

Introduction: Off-the-shelf embedding models work well for general text, but domain-specific applications often need better performance. Fine-tuning embeddings on your data can dramatically improve retrieval quality—turning a 70% recall into 90%+ for your specific use case. The key is creating high-quality training data that teaches the model what “similar” means in your domain. This guide covers practical embedding fine-tuning: generating contrastive pairs from your data, training with sentence transformers, evaluating with retrieval metrics, and deploying fine-tuned models. Whether you’re building a legal document search, medical knowledge base, or code retrieval system, fine-tuned embeddings can be the difference between a useful system and one that frustrates users with irrelevant results.

Embedding Fine-Tuning: Contrastive Pairs, Model Training, Evaluation

Training Data Generation

from dataclasses import dataclass
from typing import Any, Optional
from abc import ABC, abstractmethod
import random

@dataclass
class TrainingPair:
    """A training pair for contrastive learning."""
    
    anchor: str
    positive: str
    negative: Optional[str] = None
    score: float = 1.0  # Similarity score for soft labels

@dataclass
class TrainingTriplet:
    """A training triplet."""
    
    anchor: str
    positive: str
    negative: str

class PairGenerator(ABC):
    """Abstract pair generator."""
    
    @abstractmethod
    def generate(self, documents: list[str]) -> list[TrainingPair]:
        """Generate training pairs from documents."""
        pass

class QueryDocumentPairGenerator(PairGenerator):
    """Generate pairs from query-document relevance."""
    
    def __init__(
        self,
        llm_client: Any,
        queries_per_doc: int = 3
    ):
        self.llm = llm_client
        self.queries_per_doc = queries_per_doc
    
    async def generate_queries(self, document: str) -> list[str]:
        """Generate queries that this document answers."""
        
        prompt = f"""Generate {self.queries_per_doc} different search queries that 
this document would be a good result for. Make queries diverse - some specific, 
some general, some using different terminology.

Document:
{document[:2000]}

Output only the queries, one per line:"""
        
        response = await self.llm.complete(prompt)
        queries = [q.strip() for q in response.content.split('\n') if q.strip()]
        
        return queries[:self.queries_per_doc]
    
    def generate(self, documents: list[str]) -> list[TrainingPair]:
        """Generate query-document pairs."""
        
        import asyncio
        
        async def generate_all():
            pairs = []
            
            for doc in documents:
                queries = await self.generate_queries(doc)
                
                for query in queries:
                    pairs.append(TrainingPair(
                        anchor=query,
                        positive=doc
                    ))
            
            return pairs
        
        # Note: asyncio.run() fails inside an already-running event loop;
        # await generate_all() directly from async contexts instead.
        return asyncio.run(generate_all())

class ChunkPairGenerator(PairGenerator):
    """Generate pairs from document chunks."""
    
    def __init__(
        self,
        chunk_size: int = 512,
        overlap: int = 128
    ):
        self.chunk_size = chunk_size
        self.overlap = overlap
    
    def _chunk_document(self, document: str) -> list[str]:
        """Split document into chunks."""
        
        chunks = []
        start = 0
        
        while start < len(document):
            end = start + self.chunk_size
            chunks.append(document[start:end])
            start = end - self.overlap
        
        return chunks
    
    def generate(self, documents: list[str]) -> list[TrainingPair]:
        """Generate pairs from adjacent chunks."""
        
        pairs = []
        
        for doc in documents:
            chunks = self._chunk_document(doc)
            
            for i in range(len(chunks) - 1):
                # Adjacent chunks are positive pairs
                pairs.append(TrainingPair(
                    anchor=chunks[i],
                    positive=chunks[i + 1],
                    score=0.9  # High but not perfect similarity
                ))
                
                # Same document, further apart
                if i + 2 < len(chunks):
                    pairs.append(TrainingPair(
                        anchor=chunks[i],
                        positive=chunks[i + 2],
                        score=0.7
                    ))
        
        return pairs

class TitleContentPairGenerator(PairGenerator):
    """Generate pairs from titles and content."""
    
    def generate(
        self,
        documents: list[dict]  # {"title": str, "content": str}
    ) -> list[TrainingPair]:
        """Generate title-content pairs."""
        
        pairs = []
        
        for doc in documents:
            title = doc.get("title", "")
            content = doc.get("content", "")
            
            if title and content:
                pairs.append(TrainingPair(
                    anchor=title,
                    positive=content[:1000]  # First part of content
                ))
        
        return pairs

class HardNegativeGenerator:
    """Generate hard negatives for training."""
    
    def __init__(
        self,
        embedding_model: Any,
        num_negatives: int = 5
    ):
        self.embedding_model = embedding_model
        self.num_negatives = num_negatives
    
    async def add_hard_negatives(
        self,
        pairs: list[TrainingPair],
        corpus: list[str]
    ) -> list[TrainingTriplet]:
        """Add hard negatives to pairs."""
        
        # Embed all corpus documents
        corpus_embeddings = await self.embedding_model.embed(corpus)
        
        triplets = []
        
        for pair in pairs:
            # Embed anchor
            anchor_embedding = await self.embedding_model.embed([pair.anchor])
            
            # Find similar but not positive documents
            similarities = self._compute_similarities(
                anchor_embedding[0],
                corpus_embeddings
            )
            
            # Sort by similarity
            sorted_indices = sorted(
                range(len(similarities)),
                key=lambda i: similarities[i],
                reverse=True
            )
            
            # Get hard negatives (similar but not the positive)
            negatives = []
            for idx in sorted_indices:
                if corpus[idx] != pair.positive:
                    negatives.append(corpus[idx])
                    if len(negatives) >= self.num_negatives:
                        break
            
            # Create triplets
            for negative in negatives:
                triplets.append(TrainingTriplet(
                    anchor=pair.anchor,
                    positive=pair.positive,
                    negative=negative
                ))
        
        return triplets
    
    def _compute_similarities(
        self,
        query_embedding: list[float],
        corpus_embeddings: list[list[float]]
    ) -> list[float]:
        """Compute cosine similarities."""
        
        import numpy as np
        
        query = np.array(query_embedding)
        corpus = np.array(corpus_embeddings)
        
        # Normalize
        query_norm = query / np.linalg.norm(query)
        corpus_norm = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
        
        return (corpus_norm @ query_norm).tolist()

class InBatchNegativeGenerator:
    """Use other samples in batch as negatives."""
    
    def __init__(self, batch_size: int = 32):
        self.batch_size = batch_size
    
    def create_batches(
        self,
        pairs: list[TrainingPair]
    ) -> list[list[TrainingPair]]:
        """Create batches where other positives serve as negatives."""
        
        # Shuffle pairs
        shuffled = pairs.copy()
        random.shuffle(shuffled)
        
        # Create batches
        batches = []
        for i in range(0, len(shuffled), self.batch_size):
            batch = shuffled[i:i + self.batch_size]
            batches.append(batch)
        
        return batches
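
Putting the generators together: the sketch below is one way to wire them into a data-generation pipeline. It assumes an llm_client with an async complete() method returning an object with a .content attribute, and an embedding client with an async embed() method, matching the interfaces the classes above already rely on.

import asyncio

async def build_training_data(documents, corpus, llm_client, embedding_client):
    """Generate query-document pairs, then mine hard negatives."""
    
    # Synthetic queries from an LLM: one pair per (query, document)
    query_gen = QueryDocumentPairGenerator(llm_client, queries_per_doc=3)
    pairs = []
    for doc in documents:
        for query in await query_gen.generate_queries(doc):
            pairs.append(TrainingPair(anchor=query, positive=doc))
    
    # Structural pairs (adjacent chunks) need no LLM calls
    pairs += ChunkPairGenerator(chunk_size=512, overlap=128).generate(documents)
    
    # Upgrade pairs to triplets by mining hard negatives from the corpus
    negative_gen = HardNegativeGenerator(embedding_client, num_negatives=5)
    triplets = await negative_gen.add_hard_negatives(pairs, corpus)
    
    return pairs, triplets

# Usage: pairs, triplets = asyncio.run(build_training_data(docs, docs, llm, embedder))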

Contrastive Learning

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ContrastiveConfig:
    """Contrastive learning configuration."""
    
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    max_length: int = 512
    temperature: float = 0.05
    learning_rate: float = 2e-5
    batch_size: int = 32
    epochs: int = 3
    warmup_steps: int = 100

class ContrastiveDataset(Dataset):
    """Dataset for contrastive learning."""
    
    def __init__(
        self,
        pairs: list[TrainingPair],
        tokenizer: Any,
        max_length: int = 512
    ):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        
        anchor = self.tokenizer(
            pair.anchor,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        positive = self.tokenizer(
            pair.positive,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(),
            'positive_input_ids': positive['input_ids'].squeeze(),
            'positive_attention_mask': positive['attention_mask'].squeeze(),
            'score': torch.tensor(pair.score, dtype=torch.float)
        }

class TripletDataset(Dataset):
    """Dataset for triplet loss training."""
    
    def __init__(
        self,
        triplets: list[TrainingTriplet],
        tokenizer: Any,
        max_length: int = 512
    ):
        self.triplets = triplets
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.triplets)
    
    def __getitem__(self, idx):
        triplet = self.triplets[idx]
        
        anchor = self.tokenizer(
            triplet.anchor,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        positive = self.tokenizer(
            triplet.positive,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        negative = self.tokenizer(
            triplet.negative,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(),
            'positive_input_ids': positive['input_ids'].squeeze(),
            'positive_attention_mask': positive['attention_mask'].squeeze(),
            'negative_input_ids': negative['input_ids'].squeeze(),
            'negative_attention_mask': negative['attention_mask'].squeeze()
        }

class EmbeddingModel(nn.Module):
    """Embedding model for fine-tuning."""
    
    def __init__(self, model_name: str):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.pooling = 'mean'  # mean pooling
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Generate embeddings."""
        
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Mean pooling
        token_embeddings = outputs.last_hidden_state
        
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
            token_embeddings.size()
        ).float()
        
        sum_embeddings = torch.sum(
            token_embeddings * input_mask_expanded, 1
        )
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        
        return sum_embeddings / sum_mask

class InfoNCELoss(nn.Module):
    """InfoNCE contrastive loss."""
    
    def __init__(self, temperature: float = 0.05):
        super().__init__()
        self.temperature = temperature
    
    def forward(
        self,
        anchor_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """Compute InfoNCE loss with in-batch negatives."""
        
        # Normalize embeddings
        anchor_norm = F.normalize(anchor_embeddings, p=2, dim=1)
        positive_norm = F.normalize(positive_embeddings, p=2, dim=1)
        
        # Compute similarity matrix
        similarity = torch.matmul(anchor_norm, positive_norm.T) / self.temperature
        
        # Labels: diagonal elements are positives
        batch_size = anchor_embeddings.size(0)
        labels = torch.arange(batch_size, device=anchor_embeddings.device)
        
        # Cross entropy loss
        loss = F.cross_entropy(similarity, labels)
        
        return loss

class TripletLoss(nn.Module):
    """Triplet margin loss."""
    
    def __init__(self, margin: float = 0.5):
        super().__init__()
        self.margin = margin
    
    def forward(
        self,
        anchor: torch.Tensor,
        positive: torch.Tensor,
        negative: torch.Tensor
    ) -> torch.Tensor:
        """Compute triplet loss."""
        
        # Normalize
        anchor_norm = F.normalize(anchor, p=2, dim=1)
        positive_norm = F.normalize(positive, p=2, dim=1)
        negative_norm = F.normalize(negative, p=2, dim=1)
        
        # Compute distances
        pos_dist = 1 - (anchor_norm * positive_norm).sum(dim=1)
        neg_dist = 1 - (anchor_norm * negative_norm).sum(dim=1)
        
        # Triplet loss
        loss = F.relu(pos_dist - neg_dist + self.margin)
        
        return loss.mean()

class MultipleNegativesRankingLoss(nn.Module):
    """Multiple negatives ranking loss (MNRL)."""
    
    def __init__(self, scale: float = 20.0):
        super().__init__()
        self.scale = scale
    
    def forward(
        self,
        anchor_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor,
        scores: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute MNRL loss."""
        
        # Normalize
        anchor_norm = F.normalize(anchor_embeddings, p=2, dim=1)
        positive_norm = F.normalize(positive_embeddings, p=2, dim=1)
        
        # Similarity matrix
        similarity = torch.matmul(anchor_norm, positive_norm.T) * self.scale
        
        # Labels
        batch_size = anchor_embeddings.size(0)
        labels = torch.arange(batch_size, device=anchor_embeddings.device)
        
        # If we have soft labels, use them
        if scores is not None:
            # Weighted cross entropy
            log_probs = F.log_softmax(similarity, dim=1)
            loss = -scores * log_probs.diag()
            return loss.mean()
        
        return F.cross_entropy(similarity, labels)
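
InfoNCE treats each in-batch positive as the correct class in a softmax over the batch: for anchor i the loss is -log(exp(sim(a_i, p_i)/τ) / Σ_j exp(sim(a_i, p_j)/τ)), so every other positive in the batch acts as a negative for free. Before wiring the losses into a trainer, a quick sanity check on random tensors confirms shapes and gradient flow:

# Shape/gradient sanity check on random embeddings (384 dims, batch of 8)
batch_size, dim = 8, 384
anchor = torch.randn(batch_size, dim, requires_grad=True)
positive = torch.randn(batch_size, dim)
negative = torch.randn(batch_size, dim)

loss = InfoNCELoss(temperature=0.05)(anchor, positive)
loss.backward()  # gradients flow back to the anchor embeddings
print(f"InfoNCE: {loss.item():.4f}")

print(f"MNRL: {MultipleNegativesRankingLoss(scale=20.0)(anchor.detach(), positive).item():.4f}")
print(f"Triplet: {TripletLoss(margin=0.5)(anchor.detach(), positive, negative).item():.4f}")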

Training Pipeline

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from dataclasses import dataclass
from typing import Callable, Optional
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class TrainingMetrics:
    """Training metrics."""
    
    epoch: int
    step: int
    loss: float
    learning_rate: float

class EmbeddingTrainer:
    """Trainer for embedding fine-tuning."""
    
    def __init__(
        self,
        model: EmbeddingModel,
        config: ContrastiveConfig,
        device: Optional[str] = None
    ):
        self.model = model
        self.config = config
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.model.to(self.device)
        
        # Loss function
        self.loss_fn = MultipleNegativesRankingLoss()
        
        # Optimizer
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config.learning_rate
        )
    
    def train(
        self,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader = None,
        callback: Optional[Callable] = None
    ) -> list[TrainingMetrics]:
        """Train the model."""
        
        # Learning rate scheduler
        total_steps = len(train_dataloader) * self.config.epochs
        
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=total_steps
        )
        
        metrics_history = []
        global_step = 0
        
        for epoch in range(self.config.epochs):
            self.model.train()
            epoch_loss = 0
            
            for batch_idx, batch in enumerate(train_dataloader):
                # Move to device
                anchor_ids = batch['anchor_input_ids'].to(self.device)
                anchor_mask = batch['anchor_attention_mask'].to(self.device)
                positive_ids = batch['positive_input_ids'].to(self.device)
                positive_mask = batch['positive_attention_mask'].to(self.device)
                scores = batch['score'].to(self.device)
                
                # Forward pass
                anchor_embeddings = self.model(anchor_ids, anchor_mask)
                positive_embeddings = self.model(positive_ids, positive_mask)
                
                # Compute loss
                loss = self.loss_fn(
                    anchor_embeddings,
                    positive_embeddings,
                    scores
                )
                
                # Backward pass
                self.optimizer.zero_grad()
                loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=1.0
                )
                
                self.optimizer.step()
                scheduler.step()
                
                epoch_loss += loss.item()
                global_step += 1
                
                # Log metrics
                if global_step % 100 == 0:
                    metrics = TrainingMetrics(
                        epoch=epoch,
                        step=global_step,
                        loss=loss.item(),
                        learning_rate=scheduler.get_last_lr()[0]
                    )
                    metrics_history.append(metrics)
                    
                    logger.info(
                        f"Epoch {epoch}, Step {global_step}, "
                        f"Loss: {loss.item():.4f}, "
                        f"LR: {scheduler.get_last_lr()[0]:.2e}"
                    )
                    
                    if callback:
                        callback(metrics)
            
            # Validation
            if val_dataloader:
                val_loss = self._validate(val_dataloader)
                logger.info(f"Epoch {epoch} validation loss: {val_loss:.4f}")
        
        return metrics_history
    
    def _validate(self, dataloader: DataLoader) -> float:
        """Validate the model."""
        
        self.model.eval()
        total_loss = 0
        num_batches = 0
        
        with torch.no_grad():
            for batch in dataloader:
                anchor_ids = batch['anchor_input_ids'].to(self.device)
                anchor_mask = batch['anchor_attention_mask'].to(self.device)
                positive_ids = batch['positive_input_ids'].to(self.device)
                positive_mask = batch['positive_attention_mask'].to(self.device)
                scores = batch['score'].to(self.device)
                
                anchor_embeddings = self.model(anchor_ids, anchor_mask)
                positive_embeddings = self.model(positive_ids, positive_mask)
                
                loss = self.loss_fn(
                    anchor_embeddings,
                    positive_embeddings,
                    scores
                )
                
                total_loss += loss.item()
                num_batches += 1
        
        return total_loss / num_batches
    
    def save(self, path: str):
        """Save model."""
        
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': self.config
        }, path)
    
    def load(self, path: str):
        """Load model."""
        
        checkpoint = torch.load(
            path,
            map_location=self.device,
            weights_only=False  # checkpoint also stores the dataclass config
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])

class TripletTrainer(EmbeddingTrainer):
    """Trainer using triplet loss."""
    
    def __init__(
        self,
        model: EmbeddingModel,
        config: ContrastiveConfig,
        margin: float = 0.5,
        device: Optional[str] = None
    ):
        super().__init__(model, config, device)
        self.loss_fn = TripletLoss(margin=margin)
    
    def train(
        self,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader = None,
        callback: Optional[Callable] = None
    ) -> list[TrainingMetrics]:
        """Train with triplet loss."""
        
        total_steps = len(train_dataloader) * self.config.epochs
        
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=total_steps
        )
        
        metrics_history = []
        global_step = 0
        
        for epoch in range(self.config.epochs):
            self.model.train()
            
            for batch in train_dataloader:
                anchor_ids = batch['anchor_input_ids'].to(self.device)
                anchor_mask = batch['anchor_attention_mask'].to(self.device)
                positive_ids = batch['positive_input_ids'].to(self.device)
                positive_mask = batch['positive_attention_mask'].to(self.device)
                negative_ids = batch['negative_input_ids'].to(self.device)
                negative_mask = batch['negative_attention_mask'].to(self.device)
                
                anchor_emb = self.model(anchor_ids, anchor_mask)
                positive_emb = self.model(positive_ids, positive_mask)
                negative_emb = self.model(negative_ids, negative_mask)
                
                loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
                
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                scheduler.step()
                
                global_step += 1
                
                if global_step % 100 == 0:
                    metrics = TrainingMetrics(
                        epoch=epoch,
                        step=global_step,
                        loss=loss.item(),
                        learning_rate=scheduler.get_last_lr()[0]
                    )
                    metrics_history.append(metrics)
                    logger.info(f"Step {global_step}, Loss: {loss.item():.4f}")
        
        return metrics_history
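
End to end, a training run is just config, dataset, and trainer wired together. A minimal sketch, assuming pairs produced by the generators in the first section (the output filename is arbitrary):

config = ContrastiveConfig(epochs=3, batch_size=32)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = EmbeddingModel(config.model_name)

dataset = ContrastiveDataset(pairs, tokenizer, max_length=config.max_length)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

trainer = EmbeddingTrainer(model, config)
history = trainer.train(dataloader)  # list[TrainingMetrics], logged every 100 steps
trainer.save("finetuned_embeddings.pt")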

Evaluation Metrics

import numpy as np
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class RetrievalMetrics:
    """Retrieval evaluation metrics."""
    
    mrr: float  # Mean Reciprocal Rank
    recall_at_1: float
    recall_at_5: float
    recall_at_10: float
    ndcg_at_10: float
    map_score: float  # Mean Average Precision

class EmbeddingEvaluator:
    """Evaluate embedding model quality."""
    
    def __init__(self, model: EmbeddingModel, tokenizer: Any):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
    
    def embed(self, texts: list[str]) -> np.ndarray:
        """Generate embeddings."""
        
        self.model.eval()
        embeddings = []
        
        with torch.no_grad():
            for text in texts:
                inputs = self.tokenizer(
                    text,
                    max_length=512,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                input_ids = inputs['input_ids'].to(self.device)
                attention_mask = inputs['attention_mask'].to(self.device)
                
                embedding = self.model(input_ids, attention_mask)
                embeddings.append(embedding.cpu().numpy())
        
        return np.vstack(embeddings)
    
    def evaluate(
        self,
        queries: list[str],
        corpus: list[str],
        relevance: dict[int, list[int]]  # query_idx -> relevant_doc_indices
    ) -> RetrievalMetrics:
        """Evaluate retrieval performance."""
        
        # Embed queries and corpus
        query_embeddings = self.embed(queries)
        corpus_embeddings = self.embed(corpus)
        
        # Normalize
        query_norm = query_embeddings / np.linalg.norm(
            query_embeddings, axis=1, keepdims=True
        )
        corpus_norm = corpus_embeddings / np.linalg.norm(
            corpus_embeddings, axis=1, keepdims=True
        )
        
        # Compute similarities
        similarities = query_norm @ corpus_norm.T
        
        # Compute metrics
        mrr_scores = []
        recall_1 = []
        recall_5 = []
        recall_10 = []
        ndcg_10 = []
        ap_scores = []
        
        for query_idx, relevant_docs in relevance.items():
            scores = similarities[query_idx]
            ranked_indices = np.argsort(scores)[::-1]
            
            # MRR
            for rank, doc_idx in enumerate(ranked_indices):
                if doc_idx in relevant_docs:
                    mrr_scores.append(1.0 / (rank + 1))
                    break
            else:
                mrr_scores.append(0.0)
            
            # Recall@k
            top_1 = set(ranked_indices[:1])
            top_5 = set(ranked_indices[:5])
            top_10 = set(ranked_indices[:10])
            
            recall_1.append(
                len(top_1 & set(relevant_docs)) / len(relevant_docs)
            )
            recall_5.append(
                len(top_5 & set(relevant_docs)) / len(relevant_docs)
            )
            recall_10.append(
                len(top_10 & set(relevant_docs)) / len(relevant_docs)
            )
            
            # NDCG@10
            ndcg_10.append(
                self._compute_ndcg(ranked_indices[:10], relevant_docs)
            )
            
            # Average Precision
            ap_scores.append(
                self._compute_ap(ranked_indices, relevant_docs)
            )
        
        return RetrievalMetrics(
            mrr=np.mean(mrr_scores),
            recall_at_1=np.mean(recall_1),
            recall_at_5=np.mean(recall_5),
            recall_at_10=np.mean(recall_10),
            ndcg_at_10=np.mean(ndcg_10),
            map_score=np.mean(ap_scores)
        )
    
    def _compute_ndcg(
        self,
        ranked_indices: np.ndarray,
        relevant_docs: list[int]
    ) -> float:
        """Compute NDCG."""
        
        # DCG
        dcg = 0.0
        for rank, doc_idx in enumerate(ranked_indices):
            if doc_idx in relevant_docs:
                dcg += 1.0 / np.log2(rank + 2)
        
        # Ideal DCG
        idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(relevant_docs), len(ranked_indices))))
        
        return dcg / idcg if idcg > 0 else 0.0
    
    def _compute_ap(
        self,
        ranked_indices: np.ndarray,
        relevant_docs: list[int]
    ) -> float:
        """Compute Average Precision."""
        
        relevant_set = set(relevant_docs)
        num_relevant = 0
        precision_sum = 0.0
        
        for rank, doc_idx in enumerate(ranked_indices):
            if doc_idx in relevant_set:
                num_relevant += 1
                precision_sum += num_relevant / (rank + 1)
        
        return precision_sum / len(relevant_docs) if relevant_docs else 0.0

class SimilarityEvaluator:
    """Evaluate embedding similarity quality."""
    
    def __init__(self, model: EmbeddingModel, tokenizer: Any):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
    
    def evaluate_sts(
        self,
        sentence_pairs: list[tuple[str, str]],
        gold_scores: list[float]
    ) -> dict:
        """Evaluate on semantic textual similarity."""
        
        from scipy.stats import spearmanr, pearsonr
        
        predicted_scores = []
        
        self.model.eval()
        
        with torch.no_grad():
            for sent1, sent2 in sentence_pairs:
                # Embed both sentences
                emb1 = self._embed_single(sent1)
                emb2 = self._embed_single(sent2)
                
                # Cosine similarity
                similarity = np.dot(emb1, emb2) / (
                    np.linalg.norm(emb1) * np.linalg.norm(emb2)
                )
                predicted_scores.append(similarity)
        
        # Compute correlations
        spearman = spearmanr(gold_scores, predicted_scores)[0]
        pearson = pearsonr(gold_scores, predicted_scores)[0]
        
        return {
            'spearman': spearman,
            'pearson': pearson,
            'predicted_scores': predicted_scores
        }
    
    def _embed_single(self, text: str) -> np.ndarray:
        """Embed single text."""
        
        inputs = self.tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        
        embedding = self.model(input_ids, attention_mask)
        
        return embedding.cpu().numpy().flatten()
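
Usage sketch: running the same held-out queries through the evaluator before and after fine-tuning makes the gain concrete. The toy corpus and relevance judgments below are illustrative only; model and tokenizer are the fine-tuned objects from the training section.

queries = ["how do I reset my password?", "enterprise plan pricing"]
corpus = [
    "To reset your password, open Settings and choose Security.",
    "Our enterprise tier is priced per seat with volume discounts.",
    "Release notes for version 2.4 of the desktop client.",
]
relevance = {0: [0], 1: [1]}  # query index -> relevant corpus indices

evaluator = EmbeddingEvaluator(model, tokenizer)
metrics = evaluator.evaluate(queries, corpus, relevance)
print(f"MRR: {metrics.mrr:.3f}, Recall@5: {metrics.recall_at_5:.3f}, "
      f"NDCG@10: {metrics.ndcg_at_10:.3f}")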

Production Fine-Tuning Service

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import Optional
import os
import uuid

app = FastAPI()

class FineTuneRequest(BaseModel):
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    training_data: list[dict]  # [{"anchor": str, "positive": str}]
    epochs: int = 3
    batch_size: int = 32
    learning_rate: float = 2e-5

class FineTuneStatus(BaseModel):
    job_id: str
    status: str
    progress: float
    metrics: Optional[dict] = None

# Job storage
jobs: dict[str, dict] = {}

@app.post("/v1/finetune")
async def create_finetune_job(
    request: FineTuneRequest,
    background_tasks: BackgroundTasks
) -> dict:
    """Create fine-tuning job."""
    
    job_id = str(uuid.uuid4())
    
    jobs[job_id] = {
        "status": "pending",
        "progress": 0.0,
        "metrics": None
    }
    
    # Start training in background
    background_tasks.add_task(
        run_finetune_job,
        job_id,
        request
    )
    
    return {"job_id": job_id}

async def run_finetune_job(job_id: str, request: FineTuneRequest):
    """Run fine-tuning job.
    
    Note: BackgroundTasks runs this in-process and trainer.train() is
    blocking; route training to a worker queue for real deployments.
    """
    
    try:
        jobs[job_id]["status"] = "running"
        
        # Initialize model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(request.model_name)
        model = EmbeddingModel(request.model_name)
        
        # Create training pairs
        pairs = [
            TrainingPair(
                anchor=item["anchor"],
                positive=item["positive"],
                score=item.get("score", 1.0)
            )
            for item in request.training_data
        ]
        
        # Create dataset
        dataset = ContrastiveDataset(pairs, tokenizer)
        dataloader = DataLoader(
            dataset,
            batch_size=request.batch_size,
            shuffle=True
        )
        
        # Create trainer
        config = ContrastiveConfig(
            model_name=request.model_name,
            epochs=request.epochs,
            batch_size=request.batch_size,
            learning_rate=request.learning_rate
        )
        
        trainer = EmbeddingTrainer(model, config)
        
        # Train with progress callback
        def progress_callback(metrics: TrainingMetrics):
            total_steps = len(dataloader) * request.epochs
            jobs[job_id]["progress"] = metrics.step / total_steps
            jobs[job_id]["metrics"] = {
                "loss": metrics.loss,
                "learning_rate": metrics.learning_rate
            }
        
        trainer.train(dataloader, callback=progress_callback)
        
        # Save model (create the output directory first)
        os.makedirs("/tmp/models", exist_ok=True)
        model_path = f"/tmp/models/{job_id}.pt"
        trainer.save(model_path)
        
        jobs[job_id]["status"] = "completed"
        jobs[job_id]["progress"] = 1.0
        jobs[job_id]["model_path"] = model_path
        
    except Exception as e:
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)

@app.get("/v1/finetune/{job_id}")
async def get_finetune_status(job_id: str) -> FineTuneStatus:
    """Get fine-tuning job status."""
    
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    
    job = jobs[job_id]
    
    return FineTuneStatus(
        job_id=job_id,
        status=job["status"],
        progress=job["progress"],
        metrics=job.get("metrics")
    )

@app.post("/v1/evaluate")
async def evaluate_model(
    model_path: str,
    queries: list[str],
    corpus: list[str],
    relevance: dict
) -> dict:
    """Evaluate fine-tuned model."""
    
    # Load model
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    model = EmbeddingModel("sentence-transformers/all-MiniLM-L6-v2")
    
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Evaluate
    evaluator = EmbeddingEvaluator(model, tokenizer)
    
    # Convert relevance keys to int
    relevance_int = {int(k): v for k, v in relevance.items()}
    
    metrics = evaluator.evaluate(queries, corpus, relevance_int)
    
    return {
        "mrr": metrics.mrr,
        "recall_at_1": metrics.recall_at_1,
        "recall_at_5": metrics.recall_at_5,
        "recall_at_10": metrics.recall_at_10,
        "ndcg_at_10": metrics.ndcg_at_10,
        "map": metrics.map_score
    }

@app.get("/health")
async def health():
    return {"status": "healthy"}

Conclusion

Fine-tuning embeddings can dramatically improve retrieval quality for domain-specific applications. The key is high-quality training data—invest time in generating diverse query-document pairs that represent real user queries. Use LLMs to generate synthetic queries, then mine hard negatives to teach the model subtle distinctions. Start with a strong base model like all-MiniLM-L6-v2 or bge-base-en, then fine-tune with contrastive learning using in-batch negatives. Evaluate with retrieval metrics (MRR, Recall@k, NDCG) on a held-out test set that reflects your actual use case. Monitor for overfitting—if training loss drops but validation metrics plateau, you’re memorizing rather than generalizing. Consider the trade-off between model size and latency: smaller models fine-tuned on your data often outperform larger general-purpose models while being faster to serve. The investment in fine-tuning pays off when you see retrieval quality jump from “sometimes useful” to “consistently finding the right documents.” For production, version your models and track metrics over time to catch regressions as your data distribution shifts.

