Introduction
Off-the-shelf embedding models work well for general text, but domain-specific applications often need better performance. Fine-tuning embeddings on your own data can dramatically improve retrieval quality, turning 70% recall into 90%+ for your specific use case. The key is creating high-quality training data that teaches the model what “similar” means in your domain. This guide covers practical embedding fine-tuning: generating contrastive pairs from your data, training with sentence transformers, evaluating with retrieval metrics, and deploying fine-tuned models. Whether you’re building legal document search, a medical knowledge base, or a code retrieval system, fine-tuned embeddings can be the difference between a useful system and one that frustrates users with irrelevant results.
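Before diving into the from-scratch implementation below, it helps to see the shape of the workflow end to end. Here is a minimal sketch using the sentence-transformers library (v2.x-style fit API); the model name, example pairs, and output path are placeholders, not recommendations. The rest of this guide builds the same pieces by hand for finer control over data generation, losses, and evaluation.

# Minimal sketch of the overall workflow with the sentence-transformers library
# (v2.x-style fit API). Model name, example pairs, and output path are placeholders.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

train_examples = [
    InputExample(texts=["how do I rotate api keys", "To rotate an API key, open ..."]),
    InputExample(texts=["reset a forgotten password", "Password resets are handled ..."]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.MultipleNegativesRankingLoss(model)  # in-batch negatives

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=10,
)
model.save("finetuned-embeddings")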

Training Data Generation
from dataclasses import dataclass, field
from typing import Any, Optional
from abc import ABC, abstractmethod
import random
@dataclass
class TrainingPair:
"""A training pair for contrastive learning."""
anchor: str
positive: str
    negative: Optional[str] = None
score: float = 1.0 # Similarity score for soft labels
@dataclass
class TrainingTriplet:
"""A training triplet."""
anchor: str
positive: str
negative: str
class PairGenerator(ABC):
"""Abstract pair generator."""
@abstractmethod
def generate(self, documents: list[str]) -> list[TrainingPair]:
"""Generate training pairs from documents."""
pass
class QueryDocumentPairGenerator(PairGenerator):
"""Generate pairs from query-document relevance."""
def __init__(
self,
llm_client: Any,
queries_per_doc: int = 3
):
self.llm = llm_client
self.queries_per_doc = queries_per_doc
async def generate_queries(self, document: str) -> list[str]:
"""Generate queries that this document answers."""
prompt = f"""Generate {self.queries_per_doc} different search queries that
this document would be a good result for. Make queries diverse - some specific,
some general, some using different terminology.
Document:
{document[:2000]}
Output only the queries, one per line:"""
response = await self.llm.complete(prompt)
queries = [q.strip() for q in response.content.split('\n') if q.strip()]
return queries[:self.queries_per_doc]
def generate(self, documents: list[str]) -> list[TrainingPair]:
"""Generate query-document pairs."""
import asyncio
async def generate_all():
pairs = []
for doc in documents:
queries = await self.generate_queries(doc)
for query in queries:
pairs.append(TrainingPair(
anchor=query,
positive=doc
))
return pairs
return asyncio.run(generate_all())
class ChunkPairGenerator(PairGenerator):
"""Generate pairs from document chunks."""
def __init__(
self,
chunk_size: int = 512,
overlap: int = 128
):
self.chunk_size = chunk_size
self.overlap = overlap
def _chunk_document(self, document: str) -> list[str]:
"""Split document into chunks."""
chunks = []
start = 0
while start < len(document):
end = start + self.chunk_size
chunks.append(document[start:end])
start = end - self.overlap
return chunks
def generate(self, documents: list[str]) -> list[TrainingPair]:
"""Generate pairs from adjacent chunks."""
pairs = []
for doc in documents:
chunks = self._chunk_document(doc)
for i in range(len(chunks) - 1):
# Adjacent chunks are positive pairs
pairs.append(TrainingPair(
anchor=chunks[i],
positive=chunks[i + 1],
score=0.9 # High but not perfect similarity
))
# Same document, further apart
if i + 2 < len(chunks):
pairs.append(TrainingPair(
anchor=chunks[i],
positive=chunks[i + 2],
score=0.7
))
return pairs
class TitleContentPairGenerator(PairGenerator):
"""Generate pairs from titles and content."""
def generate(
self,
documents: list[dict] # {"title": str, "content": str}
) -> list[TrainingPair]:
"""Generate title-content pairs."""
pairs = []
for doc in documents:
title = doc.get("title", "")
content = doc.get("content", "")
if title and content:
pairs.append(TrainingPair(
anchor=title,
positive=content[:1000] # First part of content
))
return pairs
class HardNegativeGenerator:
"""Generate hard negatives for training."""
def __init__(
self,
embedding_model: Any,
num_negatives: int = 5
):
self.embedding_model = embedding_model
self.num_negatives = num_negatives
async def add_hard_negatives(
self,
pairs: list[TrainingPair],
corpus: list[str]
) -> list[TrainingTriplet]:
"""Add hard negatives to pairs."""
# Embed all corpus documents
corpus_embeddings = await self.embedding_model.embed(corpus)
triplets = []
for pair in pairs:
# Embed anchor
anchor_embedding = await self.embedding_model.embed([pair.anchor])
# Find similar but not positive documents
similarities = self._compute_similarities(
anchor_embedding[0],
corpus_embeddings
)
# Sort by similarity
sorted_indices = sorted(
range(len(similarities)),
key=lambda i: similarities[i],
reverse=True
)
# Get hard negatives (similar but not the positive)
negatives = []
for idx in sorted_indices:
if corpus[idx] != pair.positive:
negatives.append(corpus[idx])
if len(negatives) >= self.num_negatives:
break
# Create triplets
for negative in negatives:
triplets.append(TrainingTriplet(
anchor=pair.anchor,
positive=pair.positive,
negative=negative
))
return triplets
def _compute_similarities(
self,
query_embedding: list[float],
corpus_embeddings: list[list[float]]
) -> list[float]:
"""Compute cosine similarities."""
import numpy as np
query = np.array(query_embedding)
corpus = np.array(corpus_embeddings)
# Normalize
query_norm = query / np.linalg.norm(query)
corpus_norm = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
return (corpus_norm @ query_norm).tolist()
class InBatchNegativeGenerator:
"""Use other samples in batch as negatives."""
def __init__(self, batch_size: int = 32):
self.batch_size = batch_size
def create_batches(
self,
pairs: list[TrainingPair]
) -> list[list[TrainingPair]]:
"""Create batches where other positives serve as negatives."""
# Shuffle pairs
shuffled = pairs.copy()
random.shuffle(shuffled)
# Create batches
batches = []
for i in range(0, len(shuffled), self.batch_size):
batch = shuffled[i:i + self.batch_size]
batches.append(batch)
return batches
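A sketch of how these generators fit together: synthesize queries per document with an LLM, then mine hard negatives from the same corpus to build triplets. The llm_client and embedding_model objects are assumed stand-ins for your own async clients; nothing here depends on a specific provider.

# Hypothetical wiring of the generators above. llm_client and embedding_model are
# assumed async clients with .complete() and .embed() methods, as used earlier.
import asyncio

async def build_training_data(documents: list[str], llm_client, embedding_model):
    # Query -> document pairs generated by an LLM
    query_gen = QueryDocumentPairGenerator(llm_client, queries_per_doc=3)
    pairs = []
    for doc in documents:
        for query in await query_gen.generate_queries(doc):
            pairs.append(TrainingPair(anchor=query, positive=doc))

    # Turn pairs into triplets by mining hard negatives from the corpus
    negative_gen = HardNegativeGenerator(embedding_model, num_negatives=5)
    triplets = await negative_gen.add_hard_negatives(pairs, corpus=documents)
    return pairs, triplets

# pairs, triplets = asyncio.run(build_training_data(docs, llm_client, embedding_model))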
Contrastive Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from dataclasses import dataclass
from typing import Any, Optional
@dataclass
class ContrastiveConfig:
"""Contrastive learning configuration."""
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
max_length: int = 512
temperature: float = 0.05
learning_rate: float = 2e-5
batch_size: int = 32
epochs: int = 3
warmup_steps: int = 100
class ContrastiveDataset(Dataset):
"""Dataset for contrastive learning."""
def __init__(
self,
pairs: list[TrainingPair],
tokenizer: Any,
max_length: int = 512
):
self.pairs = pairs
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.pairs)
def __getitem__(self, idx):
pair = self.pairs[idx]
anchor = self.tokenizer(
pair.anchor,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
positive = self.tokenizer(
pair.positive,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'anchor_input_ids': anchor['input_ids'].squeeze(),
'anchor_attention_mask': anchor['attention_mask'].squeeze(),
'positive_input_ids': positive['input_ids'].squeeze(),
'positive_attention_mask': positive['attention_mask'].squeeze(),
'score': torch.tensor(pair.score, dtype=torch.float)
}
class TripletDataset(Dataset):
"""Dataset for triplet loss training."""
def __init__(
self,
triplets: list[TrainingTriplet],
tokenizer: Any,
max_length: int = 512
):
self.triplets = triplets
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.triplets)
def __getitem__(self, idx):
triplet = self.triplets[idx]
anchor = self.tokenizer(
triplet.anchor,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
positive = self.tokenizer(
triplet.positive,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
negative = self.tokenizer(
triplet.negative,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'anchor_input_ids': anchor['input_ids'].squeeze(),
'anchor_attention_mask': anchor['attention_mask'].squeeze(),
'positive_input_ids': positive['input_ids'].squeeze(),
'positive_attention_mask': positive['attention_mask'].squeeze(),
'negative_input_ids': negative['input_ids'].squeeze(),
'negative_attention_mask': negative['attention_mask'].squeeze()
}
class EmbeddingModel(nn.Module):
"""Embedding model for fine-tuning."""
def __init__(self, model_name: str):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
self.pooling = 'mean' # mean pooling
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor:
"""Generate embeddings."""
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
# Mean pooling
token_embeddings = outputs.last_hidden_state
input_mask_expanded = attention_mask.unsqueeze(-1).expand(
token_embeddings.size()
).float()
sum_embeddings = torch.sum(
token_embeddings * input_mask_expanded, 1
)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask
class InfoNCELoss(nn.Module):
"""InfoNCE contrastive loss."""
def __init__(self, temperature: float = 0.05):
super().__init__()
self.temperature = temperature
def forward(
self,
anchor_embeddings: torch.Tensor,
positive_embeddings: torch.Tensor
) -> torch.Tensor:
"""Compute InfoNCE loss with in-batch negatives."""
# Normalize embeddings
anchor_norm = F.normalize(anchor_embeddings, p=2, dim=1)
positive_norm = F.normalize(positive_embeddings, p=2, dim=1)
# Compute similarity matrix
similarity = torch.matmul(anchor_norm, positive_norm.T) / self.temperature
# Labels: diagonal elements are positives
batch_size = anchor_embeddings.size(0)
labels = torch.arange(batch_size, device=anchor_embeddings.device)
# Cross entropy loss
loss = F.cross_entropy(similarity, labels)
return loss
class TripletLoss(nn.Module):
"""Triplet margin loss."""
def __init__(self, margin: float = 0.5):
super().__init__()
self.margin = margin
def forward(
self,
anchor: torch.Tensor,
positive: torch.Tensor,
negative: torch.Tensor
) -> torch.Tensor:
"""Compute triplet loss."""
# Normalize
anchor_norm = F.normalize(anchor, p=2, dim=1)
positive_norm = F.normalize(positive, p=2, dim=1)
negative_norm = F.normalize(negative, p=2, dim=1)
# Compute distances
pos_dist = 1 - (anchor_norm * positive_norm).sum(dim=1)
neg_dist = 1 - (anchor_norm * negative_norm).sum(dim=1)
# Triplet loss
loss = F.relu(pos_dist - neg_dist + self.margin)
return loss.mean()
class MultipleNegativesRankingLoss(nn.Module):
"""Multiple negatives ranking loss (MNRL)."""
def __init__(self, scale: float = 20.0):
super().__init__()
self.scale = scale
def forward(
self,
anchor_embeddings: torch.Tensor,
positive_embeddings: torch.Tensor,
scores: torch.Tensor = None
) -> torch.Tensor:
"""Compute MNRL loss."""
# Normalize
anchor_norm = F.normalize(anchor_embeddings, p=2, dim=1)
positive_norm = F.normalize(positive_embeddings, p=2, dim=1)
# Similarity matrix
similarity = torch.matmul(anchor_norm, positive_norm.T) * self.scale
# Labels
batch_size = anchor_embeddings.size(0)
labels = torch.arange(batch_size, device=anchor_embeddings.device)
# If we have soft labels, use them
if scores is not None:
# Weighted cross entropy
log_probs = F.log_softmax(similarity, dim=1)
loss = -scores * log_probs.diag()
return loss.mean()
return F.cross_entropy(similarity, labels)
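As a quick sanity check, the three losses can be exercised on random embeddings; each should return a positive scalar tensor. The batch size and dimension below are arbitrary.

# Smoke test of the loss modules on random embeddings (batch and dim are arbitrary).
batch_size, dim = 8, 384
anchors = torch.randn(batch_size, dim)
positives = torch.randn(batch_size, dim)
negatives = torch.randn(batch_size, dim)

print(InfoNCELoss(temperature=0.05)(anchors, positives))              # scalar loss
print(TripletLoss(margin=0.5)(anchors, positives, negatives))         # scalar loss
print(MultipleNegativesRankingLoss(scale=20.0)(anchors, positives))   # scalar loss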
Training Pipeline
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from dataclasses import dataclass
from typing import Any, Optional
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class TrainingMetrics:
"""Training metrics."""
epoch: int
step: int
loss: float
learning_rate: float
class EmbeddingTrainer:
"""Trainer for embedding fine-tuning."""
def __init__(
self,
model: EmbeddingModel,
config: ContrastiveConfig,
device: str = None
):
self.model = model
self.config = config
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
# Loss function
self.loss_fn = MultipleNegativesRankingLoss()
# Optimizer
self.optimizer = AdamW(
self.model.parameters(),
lr=config.learning_rate
)
def train(
self,
train_dataloader: DataLoader,
val_dataloader: DataLoader = None,
callback: callable = None
) -> list[TrainingMetrics]:
"""Train the model."""
# Learning rate scheduler
total_steps = len(train_dataloader) * self.config.epochs
scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=self.config.warmup_steps,
num_training_steps=total_steps
)
metrics_history = []
global_step = 0
for epoch in range(self.config.epochs):
self.model.train()
epoch_loss = 0
for batch_idx, batch in enumerate(train_dataloader):
# Move to device
anchor_ids = batch['anchor_input_ids'].to(self.device)
anchor_mask = batch['anchor_attention_mask'].to(self.device)
positive_ids = batch['positive_input_ids'].to(self.device)
positive_mask = batch['positive_attention_mask'].to(self.device)
scores = batch['score'].to(self.device)
# Forward pass
anchor_embeddings = self.model(anchor_ids, anchor_mask)
positive_embeddings = self.model(positive_ids, positive_mask)
# Compute loss
loss = self.loss_fn(
anchor_embeddings,
positive_embeddings,
scores
)
# Backward pass
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=1.0
)
self.optimizer.step()
scheduler.step()
epoch_loss += loss.item()
global_step += 1
# Log metrics
if global_step % 100 == 0:
metrics = TrainingMetrics(
epoch=epoch,
step=global_step,
loss=loss.item(),
learning_rate=scheduler.get_last_lr()[0]
)
metrics_history.append(metrics)
logger.info(
f"Epoch {epoch}, Step {global_step}, "
f"Loss: {loss.item():.4f}, "
f"LR: {scheduler.get_last_lr()[0]:.2e}"
)
if callback:
callback(metrics)
# Validation
if val_dataloader:
val_loss = self._validate(val_dataloader)
logger.info(f"Epoch {epoch} validation loss: {val_loss:.4f}")
return metrics_history
def _validate(self, dataloader: DataLoader) -> float:
"""Validate the model."""
self.model.eval()
total_loss = 0
num_batches = 0
with torch.no_grad():
for batch in dataloader:
anchor_ids = batch['anchor_input_ids'].to(self.device)
anchor_mask = batch['anchor_attention_mask'].to(self.device)
positive_ids = batch['positive_input_ids'].to(self.device)
positive_mask = batch['positive_attention_mask'].to(self.device)
scores = batch['score'].to(self.device)
anchor_embeddings = self.model(anchor_ids, anchor_mask)
positive_embeddings = self.model(positive_ids, positive_mask)
loss = self.loss_fn(
anchor_embeddings,
positive_embeddings,
scores
)
total_loss += loss.item()
num_batches += 1
return total_loss / num_batches
def save(self, path: str):
"""Save model."""
torch.save({
'model_state_dict': self.model.state_dict(),
'config': self.config
}, path)
def load(self, path: str):
"""Load model."""
checkpoint = torch.load(path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
class TripletTrainer(EmbeddingTrainer):
"""Trainer using triplet loss."""
def __init__(
self,
model: EmbeddingModel,
config: ContrastiveConfig,
margin: float = 0.5,
device: str = None
):
super().__init__(model, config, device)
self.loss_fn = TripletLoss(margin=margin)
def train(
self,
train_dataloader: DataLoader,
val_dataloader: DataLoader = None,
callback: callable = None
) -> list[TrainingMetrics]:
"""Train with triplet loss."""
total_steps = len(train_dataloader) * self.config.epochs
scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=self.config.warmup_steps,
num_training_steps=total_steps
)
metrics_history = []
global_step = 0
for epoch in range(self.config.epochs):
self.model.train()
for batch in train_dataloader:
anchor_ids = batch['anchor_input_ids'].to(self.device)
anchor_mask = batch['anchor_attention_mask'].to(self.device)
positive_ids = batch['positive_input_ids'].to(self.device)
positive_mask = batch['positive_attention_mask'].to(self.device)
negative_ids = batch['negative_input_ids'].to(self.device)
negative_mask = batch['negative_attention_mask'].to(self.device)
anchor_emb = self.model(anchor_ids, anchor_mask)
positive_emb = self.model(positive_ids, positive_mask)
negative_emb = self.model(negative_ids, negative_mask)
loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
scheduler.step()
global_step += 1
if global_step % 100 == 0:
metrics = TrainingMetrics(
epoch=epoch,
step=global_step,
loss=loss.item(),
learning_rate=scheduler.get_last_lr()[0]
)
metrics_history.append(metrics)
logger.info(f"Step {global_step}, Loss: {loss.item():.4f}")
return metrics_history
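Wiring the dataset, model, and trainer together looks roughly like this. The single hand-written pair is a placeholder; in practice the pairs come from the generators in the training-data section, and a held-out split would be passed as val_dataloader.

# End-to-end training sketch. The pairs list is a placeholder; real pairs come from
# the generators above, with a held-out split used for validation.
config = ContrastiveConfig(epochs=3, batch_size=32)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = EmbeddingModel(config.model_name)

pairs = [TrainingPair(anchor="how do I rotate api keys", positive="To rotate an API key, open ...")]
dataset = ContrastiveDataset(pairs, tokenizer, max_length=config.max_length)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

trainer = EmbeddingTrainer(model, config)
history = trainer.train(dataloader)
trainer.save("finetuned_embeddings.pt")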
Evaluation Metrics
import torch
import numpy as np
from dataclasses import dataclass
from typing import Any, Optional
@dataclass
class RetrievalMetrics:
"""Retrieval evaluation metrics."""
mrr: float # Mean Reciprocal Rank
recall_at_1: float
recall_at_5: float
recall_at_10: float
ndcg_at_10: float
map_score: float # Mean Average Precision
class EmbeddingEvaluator:
"""Evaluate embedding model quality."""
def __init__(self, model: EmbeddingModel, tokenizer: Any):
self.model = model
self.tokenizer = tokenizer
self.device = next(model.parameters()).device
def embed(self, texts: list[str]) -> np.ndarray:
"""Generate embeddings."""
self.model.eval()
embeddings = []
with torch.no_grad():
for text in texts:
inputs = self.tokenizer(
text,
max_length=512,
padding='max_length',
truncation=True,
return_tensors='pt'
)
input_ids = inputs['input_ids'].to(self.device)
attention_mask = inputs['attention_mask'].to(self.device)
embedding = self.model(input_ids, attention_mask)
embeddings.append(embedding.cpu().numpy())
return np.vstack(embeddings)
def evaluate(
self,
queries: list[str],
corpus: list[str],
relevance: dict[int, list[int]] # query_idx -> relevant_doc_indices
) -> RetrievalMetrics:
"""Evaluate retrieval performance."""
# Embed queries and corpus
query_embeddings = self.embed(queries)
corpus_embeddings = self.embed(corpus)
# Normalize
query_norm = query_embeddings / np.linalg.norm(
query_embeddings, axis=1, keepdims=True
)
corpus_norm = corpus_embeddings / np.linalg.norm(
corpus_embeddings, axis=1, keepdims=True
)
# Compute similarities
similarities = query_norm @ corpus_norm.T
# Compute metrics
mrr_scores = []
recall_1 = []
recall_5 = []
recall_10 = []
ndcg_10 = []
ap_scores = []
for query_idx, relevant_docs in relevance.items():
scores = similarities[query_idx]
ranked_indices = np.argsort(scores)[::-1]
# MRR
for rank, doc_idx in enumerate(ranked_indices):
if doc_idx in relevant_docs:
mrr_scores.append(1.0 / (rank + 1))
break
else:
mrr_scores.append(0.0)
# Recall@k
top_1 = set(ranked_indices[:1])
top_5 = set(ranked_indices[:5])
top_10 = set(ranked_indices[:10])
recall_1.append(
len(top_1 & set(relevant_docs)) / len(relevant_docs)
)
recall_5.append(
len(top_5 & set(relevant_docs)) / len(relevant_docs)
)
recall_10.append(
len(top_10 & set(relevant_docs)) / len(relevant_docs)
)
# NDCG@10
ndcg_10.append(
self._compute_ndcg(ranked_indices[:10], relevant_docs)
)
# Average Precision
ap_scores.append(
self._compute_ap(ranked_indices, relevant_docs)
)
return RetrievalMetrics(
mrr=np.mean(mrr_scores),
recall_at_1=np.mean(recall_1),
recall_at_5=np.mean(recall_5),
recall_at_10=np.mean(recall_10),
ndcg_at_10=np.mean(ndcg_10),
map_score=np.mean(ap_scores)
)
def _compute_ndcg(
self,
ranked_indices: np.ndarray,
relevant_docs: list[int]
) -> float:
"""Compute NDCG."""
# DCG
dcg = 0.0
for rank, doc_idx in enumerate(ranked_indices):
if doc_idx in relevant_docs:
dcg += 1.0 / np.log2(rank + 2)
# Ideal DCG
idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(relevant_docs), len(ranked_indices))))
return dcg / idcg if idcg > 0 else 0.0
def _compute_ap(
self,
ranked_indices: np.ndarray,
relevant_docs: list[int]
) -> float:
"""Compute Average Precision."""
relevant_set = set(relevant_docs)
num_relevant = 0
precision_sum = 0.0
for rank, doc_idx in enumerate(ranked_indices):
if doc_idx in relevant_set:
num_relevant += 1
precision_sum += num_relevant / (rank + 1)
return precision_sum / len(relevant_docs) if relevant_docs else 0.0
class SimilarityEvaluator:
"""Evaluate embedding similarity quality."""
def __init__(self, model: EmbeddingModel, tokenizer: Any):
self.model = model
self.tokenizer = tokenizer
self.device = next(model.parameters()).device
def evaluate_sts(
self,
sentence_pairs: list[tuple[str, str]],
gold_scores: list[float]
) -> dict:
"""Evaluate on semantic textual similarity."""
from scipy.stats import spearmanr, pearsonr
predicted_scores = []
self.model.eval()
with torch.no_grad():
for sent1, sent2 in sentence_pairs:
# Embed both sentences
emb1 = self._embed_single(sent1)
emb2 = self._embed_single(sent2)
# Cosine similarity
similarity = np.dot(emb1, emb2) / (
np.linalg.norm(emb1) * np.linalg.norm(emb2)
)
predicted_scores.append(similarity)
# Compute correlations
spearman = spearmanr(gold_scores, predicted_scores)[0]
pearson = pearsonr(gold_scores, predicted_scores)[0]
return {
'spearman': spearman,
'pearson': pearson,
'predicted_scores': predicted_scores
}
def _embed_single(self, text: str) -> np.ndarray:
"""Embed single text."""
inputs = self.tokenizer(
text,
max_length=512,
padding='max_length',
truncation=True,
return_tensors='pt'
)
input_ids = inputs['input_ids'].to(self.device)
attention_mask = inputs['attention_mask'].to(self.device)
embedding = self.model(input_ids, attention_mask)
return embedding.cpu().numpy().flatten()
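Running the evaluator over a labeled query set looks like this. The queries, corpus, and relevance judgments below are toy placeholders; a real evaluation set should be held out from training and mirror production queries.

# Evaluation sketch with toy data. relevance maps query index -> relevant corpus indices.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = EmbeddingModel("sentence-transformers/all-MiniLM-L6-v2")  # or load a fine-tuned checkpoint

queries = ["how do I rotate api keys", "reset a forgotten password"]
corpus = [
    "To rotate an API key, open ...",
    "Password resets are handled ...",
    "Billing is invoiced monthly ...",
]
relevance = {0: [0], 1: [1]}

evaluator = EmbeddingEvaluator(model, tokenizer)
metrics = evaluator.evaluate(queries, corpus, relevance)
print(f"MRR={metrics.mrr:.3f} R@10={metrics.recall_at_10:.3f} NDCG@10={metrics.ndcg_at_10:.3f}")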
Production Fine-Tuning Service
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import Optional
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
import os
import uuid

# TrainingPair, ContrastiveConfig, ContrastiveDataset, EmbeddingModel, EmbeddingTrainer,
# TrainingMetrics, and EmbeddingEvaluator are the classes defined in the sections above.
app = FastAPI()
class FineTuneRequest(BaseModel):
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
training_data: list[dict] # [{"anchor": str, "positive": str}]
epochs: int = 3
batch_size: int = 32
learning_rate: float = 2e-5
class FineTuneStatus(BaseModel):
job_id: str
status: str
progress: float
metrics: Optional[dict] = None
# Job storage
jobs: dict[str, dict] = {}
@app.post("/v1/finetune")
async def create_finetune_job(
request: FineTuneRequest,
background_tasks: BackgroundTasks
) -> dict:
"""Create fine-tuning job."""
job_id = str(uuid.uuid4())
jobs[job_id] = {
"status": "pending",
"progress": 0.0,
"metrics": None
}
# Start training in background
background_tasks.add_task(
run_finetune_job,
job_id,
request
)
return {"job_id": job_id}
def run_finetune_job(job_id: str, request: FineTuneRequest):
    """Run fine-tuning job.

    Defined as a plain (sync) function so BackgroundTasks executes it in a
    worker thread instead of blocking the event loop during training.
    """
try:
jobs[job_id]["status"] = "running"
# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(request.model_name)
model = EmbeddingModel(request.model_name)
# Create training pairs
pairs = [
TrainingPair(
anchor=item["anchor"],
positive=item["positive"],
score=item.get("score", 1.0)
)
for item in request.training_data
]
# Create dataset
dataset = ContrastiveDataset(pairs, tokenizer)
dataloader = DataLoader(
dataset,
batch_size=request.batch_size,
shuffle=True
)
# Create trainer
config = ContrastiveConfig(
model_name=request.model_name,
epochs=request.epochs,
batch_size=request.batch_size,
learning_rate=request.learning_rate
)
trainer = EmbeddingTrainer(model, config)
# Train with progress callback
def progress_callback(metrics: TrainingMetrics):
total_steps = len(dataloader) * request.epochs
jobs[job_id]["progress"] = metrics.step / total_steps
jobs[job_id]["metrics"] = {
"loss": metrics.loss,
"learning_rate": metrics.learning_rate
}
trainer.train(dataloader, callback=progress_callback)
        # Save the fine-tuned model (create the output directory first)
        os.makedirs("/tmp/models", exist_ok=True)
        model_path = f"/tmp/models/{job_id}"
        trainer.save(model_path)
jobs[job_id]["status"] = "completed"
jobs[job_id]["progress"] = 1.0
jobs[job_id]["model_path"] = model_path
except Exception as e:
jobs[job_id]["status"] = "failed"
jobs[job_id]["error"] = str(e)
@app.get("/v1/finetune/{job_id}")
async def get_finetune_status(job_id: str) -> FineTuneStatus:
"""Get fine-tuning job status."""
if job_id not in jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = jobs[job_id]
return FineTuneStatus(
job_id=job_id,
status=job["status"],
progress=job["progress"],
metrics=job.get("metrics")
)
@app.post("/v1/evaluate")
async def evaluate_model(
model_path: str,
queries: list[str],
corpus: list[str],
relevance: dict
) -> dict:
"""Evaluate fine-tuned model."""
    # Load the checkpoint first so the base model name comes from the saved config;
    # weights_only=False because the checkpoint also stores the config dataclass.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    base_model = checkpoint['config'].model_name
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = EmbeddingModel(base_model)
    model.load_state_dict(checkpoint['model_state_dict'])
# Evaluate
evaluator = EmbeddingEvaluator(model, tokenizer)
# Convert relevance keys to int
relevance_int = {int(k): v for k, v in relevance.items()}
metrics = evaluator.evaluate(queries, corpus, relevance_int)
return {
"mrr": metrics.mrr,
"recall_at_1": metrics.recall_at_1,
"recall_at_5": metrics.recall_at_5,
"recall_at_10": metrics.recall_at_10,
"ndcg_at_10": metrics.ndcg_at_10,
"map": metrics.map_score
}
@app.get("/health")
async def health():
return {"status": "healthy"}
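A client can exercise the service with a couple of HTTP calls. The sketch below assumes the app is served locally on port 8000 (for example via uvicorn) and uses httpx; adjust the URL and payload for your deployment.

# Client-side sketch: submit a job, then poll until it finishes. Assumes the service
# above is running at http://localhost:8000.
import time
import httpx

payload = {
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    "training_data": [
        {"anchor": "how do I rotate api keys", "positive": "To rotate an API key, open ..."},
    ],
    "epochs": 1,
    "batch_size": 8,
}
job = httpx.post("http://localhost:8000/v1/finetune", json=payload, timeout=30.0).json()

while True:
    status = httpx.get(f"http://localhost:8000/v1/finetune/{job['job_id']}").json()
    print(status["status"], f"{status['progress']:.0%}")
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(5)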
References
- Sentence Transformers: https://www.sbert.net/
- MTEB Benchmark: https://huggingface.co/spaces/mteb/leaderboard
- Contrastive Learning: https://arxiv.org/abs/2002.05709
- Hard Negative Mining: https://arxiv.org/abs/2007.00808
Conclusion
Fine-tuning embeddings can dramatically improve retrieval quality for domain-specific applications. The key is high-quality training data: invest time in generating diverse query-document pairs that represent real user queries. Use LLMs to generate synthetic queries, then mine hard negatives to teach the model subtle distinctions. Start with a strong base model like all-MiniLM-L6-v2 or bge-base-en, then fine-tune with contrastive learning using in-batch negatives.

Evaluate with retrieval metrics (MRR, Recall@k, NDCG) on a held-out test set that reflects your actual use case. Monitor for overfitting: if training loss drops but validation metrics plateau, the model is memorizing rather than generalizing. Consider the trade-off between model size and latency: smaller models fine-tuned on your data often outperform larger general-purpose models while being faster to serve.

The investment in fine-tuning pays off when you see retrieval quality jump from “sometimes useful” to “consistently finding the right documents.” For production, version your models and track metrics over time to catch regressions as your data distribution shifts.