Introduction: Knowledge distillation transfers the capabilities of large, expensive models into smaller, faster ones that can run efficiently in production. Instead of training a small model from scratch, distillation leverages the “dark knowledge” encoded in a teacher model’s soft probability distributions—information that hard labels alone cannot capture. This guide covers the techniques that make distillation effective: response-based distillation using soft targets, feature-based distillation that matches intermediate representations, relation-based distillation that preserves structural relationships, and self-distillation for iterative improvement. Whether you’re compressing a 70B model into a 7B deployment target or creating specialized models from general-purpose teachers, these patterns will help you achieve the best quality-efficiency tradeoff.
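For reference, the standard response-based objective (from Hinton et al.'s distillation paper, linked in the references) combines a temperature-softened KL term with the usual cross-entropy on hard labels. Written with the weights the configs below expose (a soft-target weight α and a hard-label weight β), a common formulation is:

$$
\mathcal{L} \;=\; \alpha \, T^{2}\, \mathrm{KL}\!\big(\mathrm{softmax}(z_t / T)\,\big\|\,\mathrm{softmax}(z_s / T)\big) \;+\; \beta \,\mathrm{CE}(z_s, y)
$$

where $z_t$ and $z_s$ are the teacher and student logits, $y$ the hard labels, and $T$ the temperature; the $T^{2}$ factor keeps the soft-target gradients on a comparable scale as $T$ varies.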

Response-Based Distillation
```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class DistillationConfig:
    """Configuration for knowledge distillation."""
    temperature: float = 2.0
    alpha: float = 0.5              # Weight for distillation loss
    hard_label_weight: float = 0.5  # Weight for hard label loss


class SoftTargetDistiller:
    """Distillation using soft targets."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: DistillationConfig
    ):
        self.teacher = teacher
        self.student = student
        self.config = config

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Calculate distillation loss."""
        T = self.config.temperature

        # Soft target loss (KL divergence)
        soft_targets = F.softmax(teacher_logits / T, dim=-1)
        soft_predictions = F.log_softmax(student_logits / T, dim=-1)
        distill_loss = F.kl_div(
            soft_predictions,
            soft_targets,
            reduction='batchmean'
        ) * (T * T)

        # Hard label loss (if labels provided)
        if labels is not None:
            hard_loss = F.cross_entropy(student_logits, labels)
            total_loss = (
                self.config.alpha * distill_loss +
                self.config.hard_label_weight * hard_loss
            )
        else:
            total_loss = distill_loss

        return total_loss

    def train_step(
        self,
        inputs: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> dict:
        """Single training step."""
        # Get teacher predictions
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        # Get student predictions
        student_logits = self.student(inputs)

        # Calculate loss
        loss = self.distillation_loss(student_logits, teacher_logits, labels)

        return {
            "loss": loss,
            "student_logits": student_logits,
            "teacher_logits": teacher_logits
        }


class SequenceDistiller:
    """Distillation for sequence models (LLMs)."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: DistillationConfig
    ):
        self.teacher = teacher
        self.student = student
        self.config = config

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Calculate sequence distillation loss."""
        T = self.config.temperature

        # Flatten the batch and sequence dimensions
        batch_size, seq_len, vocab_size = student_logits.shape
        student_flat = student_logits.view(-1, vocab_size)
        teacher_flat = teacher_logits.view(-1, vocab_size)

        # Per-token KL divergence
        soft_targets = F.softmax(teacher_flat / T, dim=-1)
        soft_predictions = F.log_softmax(student_flat / T, dim=-1)
        kl_loss = F.kl_div(
            soft_predictions,
            soft_targets,
            reduction='none'
        ).sum(dim=-1)

        # Apply attention mask so padding tokens do not contribute
        if attention_mask is not None:
            mask_flat = attention_mask.view(-1)
            kl_loss = kl_loss * mask_flat
            kl_loss = kl_loss.sum() / mask_flat.sum()
        else:
            kl_loss = kl_loss.mean()

        return kl_loss * (T * T)

    def train_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> dict:
        """Training step for sequence distillation."""
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_ids,
                attention_mask=attention_mask
            )
            teacher_logits = teacher_outputs.logits

        student_outputs = self.student(
            input_ids,
            attention_mask=attention_mask
        )
        student_logits = student_outputs.logits

        loss = self.distillation_loss(
            student_logits,
            teacher_logits,
            attention_mask
        )

        return {"loss": loss}


class TopKDistiller:
    """Distillation using only top-k logits."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        k: int = 100,
        temperature: float = 2.0
    ):
        self.teacher = teacher
        self.student = student
        self.k = k
        self.temperature = temperature

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor
    ) -> torch.Tensor:
        """Calculate top-k distillation loss."""
        T = self.temperature

        # Get top-k from teacher
        top_k_values, top_k_indices = torch.topk(teacher_logits, self.k, dim=-1)

        # Gather corresponding student logits
        student_top_k = torch.gather(student_logits, -1, top_k_indices)

        # Softmax over top-k only
        teacher_probs = F.softmax(top_k_values / T, dim=-1)
        student_log_probs = F.log_softmax(student_top_k / T, dim=-1)

        # KL divergence
        loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean'
        ) * (T * T)

        return loss
```
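To make the pieces above concrete, here is a minimal sketch of how `SoftTargetDistiller` might be wired into a training loop. The `TeacherNet`/`StudentNet` models, the data loader, and the optimizer settings are illustrative assumptions, not part of the classes above:

```python
import torch

# Hypothetical wiring of SoftTargetDistiller into a standard training loop.
# TeacherNet, StudentNet, and train_loader are placeholders for your own models and data.
teacher = TeacherNet()    # assumed to be pre-trained
student = StudentNet()
config = DistillationConfig(temperature=4.0, alpha=0.7, hard_label_weight=0.3)
distiller = SoftTargetDistiller(teacher, student, config)

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

for inputs, labels in train_loader:
    optimizer.zero_grad()
    out = distiller.train_step(inputs, labels)
    out["loss"].backward()    # gradients flow only into the student; the teacher is frozen
    optimizer.step()
```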
Feature-Based Distillation
```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class FeatureDistillConfig:
    """Configuration for feature distillation."""
    layer_mapping: dict      # teacher_layer -> student_layer
    loss_type: str = "mse"   # "mse", "cosine", "attention"
    projection: bool = True  # Project student features to teacher dim


class FeatureDistiller:
    """Distillation using intermediate features."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: FeatureDistillConfig
    ):
        self.teacher = teacher
        self.student = student
        self.config = config

        # Feature storage populated by forward hooks
        self.teacher_features = {}
        self.student_features = {}
        self._register_hooks()

        # Projection layers
        if config.projection:
            self.projectors = self._create_projectors()

    def _register_hooks(self):
        """Register forward hooks to capture features."""
        def make_hook(storage, name):
            def hook(module, input, output):
                storage[name] = output
            return hook

        for teacher_layer, student_layer in self.config.layer_mapping.items():
            # Register teacher hook
            teacher_module = self._get_module(self.teacher, teacher_layer)
            teacher_module.register_forward_hook(
                make_hook(self.teacher_features, teacher_layer)
            )

            # Register student hook
            student_module = self._get_module(self.student, student_layer)
            student_module.register_forward_hook(
                make_hook(self.student_features, student_layer)
            )

    def _get_module(self, model: nn.Module, name: str) -> nn.Module:
        """Get a submodule by its dotted name."""
        parts = name.split('.')
        module = model
        for part in parts:
            module = getattr(module, part)
        return module

    def _create_projectors(self) -> nn.ModuleDict:
        """Create projection layers for dimension matching."""
        projectors = {}
        for teacher_layer, student_layer in self.config.layer_mapping.items():
            # In a real setup the dimensions would be inferred from the model configs.
            # Simplified: assume we know the student/teacher hidden sizes.
            projectors[student_layer] = nn.Linear(768, 1024)  # Example dims
        return nn.ModuleDict(projectors)

    def feature_loss(
        self,
        teacher_feat: torch.Tensor,
        student_feat: torch.Tensor,
        layer_name: str
    ) -> torch.Tensor:
        """Calculate feature matching loss."""
        # Project student features if needed
        if self.config.projection and layer_name in self.projectors:
            student_feat = self.projectors[layer_name](student_feat)

        if self.config.loss_type == "mse":
            return F.mse_loss(student_feat, teacher_feat)
        elif self.config.loss_type == "cosine":
            # Cosine similarity loss
            cos_sim = F.cosine_similarity(
                student_feat.view(-1, student_feat.size(-1)),
                teacher_feat.view(-1, teacher_feat.size(-1)),
                dim=-1
            )
            return (1 - cos_sim).mean()
        elif self.config.loss_type == "attention":
            # Attention transfer
            student_att = self._attention_map(student_feat)
            teacher_att = self._attention_map(teacher_feat)
            return F.mse_loss(student_att, teacher_att)

        raise ValueError(f"Unknown loss type: {self.config.loss_type}")

    def _attention_map(self, features: torch.Tensor) -> torch.Tensor:
        """Convert features to an attention map."""
        # Sum of squared activations across channels
        return features.pow(2).sum(dim=-1)

    def train_step(self, inputs: torch.Tensor) -> dict:
        """Training step with feature distillation."""
        # Forward pass through both models; the hooks capture the features.
        # Only the teacher runs under no_grad: the student needs gradients.
        with torch.no_grad():
            _ = self.teacher(inputs)
        _ = self.student(inputs)

        # Calculate feature losses
        total_loss = 0
        feature_losses = {}

        for teacher_layer, student_layer in self.config.layer_mapping.items():
            teacher_feat = self.teacher_features[teacher_layer]
            student_feat = self.student_features[student_layer]

            loss = self.feature_loss(teacher_feat, student_feat, student_layer)
            feature_losses[student_layer] = loss.item()
            total_loss += loss

        return {
            "loss": total_loss,
            "feature_losses": feature_losses
        }


class AttentionDistiller:
    """Distillation of attention patterns."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        layer_mapping: dict
    ):
        self.teacher = teacher
        self.student = student
        self.layer_mapping = layer_mapping
        self.teacher_attentions = {}
        self.student_attentions = {}

    def attention_loss(
        self,
        student_att: torch.Tensor,
        teacher_att: torch.Tensor
    ) -> torch.Tensor:
        """Calculate attention distillation loss."""
        # Handle a different number of heads
        if student_att.shape[1] != teacher_att.shape[1]:
            # Average teacher heads, then broadcast to match the student
            teacher_att = teacher_att.mean(dim=1, keepdim=True)
            teacher_att = teacher_att.expand_as(student_att)

        # MSE on attention weights
        return F.mse_loss(student_att, teacher_att)

    def train_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> dict:
        """Training step for attention distillation."""
        # Get teacher attentions
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=True
            )

        # Get student attentions
        student_out = self.student(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=True
        )

        # Calculate attention losses
        total_loss = 0
        for t_layer, s_layer in self.layer_mapping.items():
            t_att = teacher_out.attentions[t_layer]
            s_att = student_out.attentions[s_layer]
            total_loss += self.attention_loss(s_att, t_att)

        return {"loss": total_loss}
```
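A hedged sketch of how `FeatureDistiller` might be configured, assuming `teacher` and `student` are already-instantiated `nn.Module` models. The layer names below follow a BERT-style module naming scheme and are purely illustrative; the important design detail is that the projection layers are trainable, so they must be included in the optimizer along with the student:

```python
import torch

# Hypothetical mapping from a deep teacher to a shallow student
# (module paths are illustrative; use the names from your actual architectures).
config = FeatureDistillConfig(
    layer_mapping={
        "encoder.layer.5.output": "encoder.layer.1.output",
        "encoder.layer.11.output": "encoder.layer.3.output",
        "encoder.layer.23.output": "encoder.layer.5.output",
    },
    loss_type="mse",
    projection=True,
)
distiller = FeatureDistiller(teacher, student, config)

# The projectors have parameters of their own and must be trained with the student.
params = list(student.parameters()) + list(distiller.projectors.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-4)
```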
Relation-Based Distillation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RelationDistiller:
    """Distillation preserving relational knowledge."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        relation_type: str = "distance"  # "distance", "angle", "similarity"
    ):
        self.teacher = teacher
        self.student = student
        self.relation_type = relation_type

    def compute_relations(
        self,
        features: torch.Tensor
    ) -> torch.Tensor:
        """Compute pairwise relations between samples in the batch."""
        batch_size = features.shape[0]
        features_flat = features.view(batch_size, -1)

        if self.relation_type == "distance":
            # Pairwise Euclidean distance
            diff = features_flat.unsqueeze(0) - features_flat.unsqueeze(1)
            relations = torch.norm(diff, dim=-1)
        elif self.relation_type == "angle":
            # Pairwise angles (cosine of the angle between normalized features)
            normalized = F.normalize(features_flat, dim=-1)
            relations = torch.mm(normalized, normalized.t())
        elif self.relation_type == "similarity":
            # Cosine similarity
            relations = F.cosine_similarity(
                features_flat.unsqueeze(0),
                features_flat.unsqueeze(1),
                dim=-1
            )
        else:
            raise ValueError(f"Unknown relation type: {self.relation_type}")

        return relations

    def relation_loss(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """Calculate relation distillation loss."""
        student_relations = self.compute_relations(student_features)
        teacher_relations = self.compute_relations(teacher_features)

        # Normalize relations so scale differences between models do not dominate
        student_relations = F.normalize(student_relations.view(-1), dim=0)
        teacher_relations = F.normalize(teacher_relations.view(-1), dim=0)

        return F.mse_loss(student_relations, teacher_relations)

    def train_step(self, inputs: torch.Tensor) -> dict:
        """Training step with relation distillation."""
        with torch.no_grad():
            teacher_features = self.teacher.get_features(inputs)

        student_features = self.student.get_features(inputs)

        loss = self.relation_loss(student_features, teacher_features)
        return {"loss": loss}


class ContrastiveDistiller:
    """Contrastive knowledge distillation."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        temperature: float = 0.07
    ):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature

    def contrastive_loss(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """Calculate contrastive (InfoNCE) distillation loss."""
        batch_size = student_features.shape[0]

        # Normalize features
        student_norm = F.normalize(student_features, dim=-1)
        teacher_norm = F.normalize(teacher_features, dim=-1)

        # Similarity logits: row i holds student i against every teacher embedding.
        # The diagonal entries are the positive student-teacher pairs; the
        # off-diagonal entries act as negatives.
        logits = torch.mm(student_norm, teacher_norm.t()) / self.temperature

        # InfoNCE loss
        labels = torch.arange(batch_size, device=student_features.device)
        loss = F.cross_entropy(logits, labels)

        return loss

    def train_step(self, inputs: torch.Tensor) -> dict:
        """Training step with contrastive distillation."""
        with torch.no_grad():
            teacher_features = self.teacher.get_features(inputs)

        student_features = self.student.get_features(inputs)

        loss = self.contrastive_loss(student_features, teacher_features)
        return {"loss": loss}


class GraphDistiller:
    """Distillation preserving graph structure."""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        k_neighbors: int = 5
    ):
        self.teacher = teacher
        self.student = student
        self.k_neighbors = k_neighbors

    def build_graph(self, features: torch.Tensor) -> torch.Tensor:
        """Build a hard k-NN adjacency matrix from features (used as the target)."""
        # Compute pairwise distances
        distances = torch.cdist(features, features)

        # Get k nearest neighbors (k+1 because each point is its own nearest neighbor)
        _, indices = torch.topk(distances, self.k_neighbors + 1, largest=False)

        # Build adjacency matrix
        batch_size = features.shape[0]
        adj = torch.zeros(batch_size, batch_size, device=features.device)
        for i in range(batch_size):
            adj[i, indices[i, 1:]] = 1  # Exclude self

        return adj

    def soft_graph(self, features: torch.Tensor) -> torch.Tensor:
        """Differentiable affinity matrix: softmax over negative pairwise distances."""
        distances = torch.cdist(features, features)
        # Mask the diagonal so a sample is never its own neighbor
        self_mask = torch.eye(features.shape[0], device=features.device, dtype=torch.bool)
        distances = distances.masked_fill(self_mask, float("inf"))
        return F.softmax(-distances, dim=-1)

    def graph_loss(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """Calculate graph distillation loss.

        A hard k-NN graph is not differentiable with respect to the student's
        features, so the student side uses a soft affinity matrix while the
        teacher's (row-normalized) k-NN graph serves as the fixed target.
        """
        student_graph = self.soft_graph(student_features)
        with torch.no_grad():
            teacher_graph = self.build_graph(teacher_features) / self.k_neighbors

        # Binary cross entropy for graph matching
        loss = F.binary_cross_entropy(
            student_graph.clamp(1e-6, 1 - 1e-6),
            teacher_graph
        )
        return loss
```
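In practice, relation-based losses are usually combined with a response-based term rather than used on their own. A minimal sketch under stated assumptions: both models expose a `get_features` method as the classes above expect, they return logits when called directly (as `SoftTargetDistiller` expects), and `lambda_rel` is a tuning knob you would pick on a validation set:

```python
# Sketch: response-based distillation plus a relation-preserving regularizer.
relation_distiller = RelationDistiller(teacher, student, relation_type="distance")
soft_distiller = SoftTargetDistiller(teacher, student, DistillationConfig(temperature=2.0))

def combined_loss(inputs, labels, lambda_rel=0.1):
    response_loss = soft_distiller.train_step(inputs, labels)["loss"]
    relation_loss = relation_distiller.train_step(inputs)["loss"]
    return response_loss + lambda_rel * relation_loss
```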
Self-Distillation
```python
from typing import Any, Callable
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfDistiller:
    """Self-distillation for iterative improvement."""

    def __init__(
        self,
        model: nn.Module,
        temperature: float = 3.0,
        ema_decay: float = 0.999
    ):
        self.student = model
        self.teacher = copy.deepcopy(model)
        self.temperature = temperature
        self.ema_decay = ema_decay

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def update_teacher(self):
        """Update teacher with an EMA of the student weights."""
        with torch.no_grad():
            for t_param, s_param in zip(
                self.teacher.parameters(),
                self.student.parameters()
            ):
                t_param.data = (
                    self.ema_decay * t_param.data +
                    (1 - self.ema_decay) * s_param.data
                )

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Calculate self-distillation loss."""
        T = self.temperature

        # Soft target loss
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1),
            reduction='batchmean'
        ) * (T * T)

        # Hard label loss
        hard_loss = F.cross_entropy(student_logits, labels)

        return 0.5 * soft_loss + 0.5 * hard_loss

    def train_step(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor
    ) -> dict:
        """Training step with self-distillation."""
        # Get teacher predictions
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        # Get student predictions
        student_logits = self.student(inputs)

        # Calculate loss
        loss = self.distillation_loss(student_logits, teacher_logits, labels)

        return {"loss": loss}


class BornAgainDistiller:
    """Born-Again Networks: multi-generation distillation."""

    def __init__(
        self,
        model_fn: Callable[[], nn.Module],
        num_generations: int = 3,
        temperature: float = 2.0
    ):
        self.model_fn = model_fn
        self.num_generations = num_generations
        self.temperature = temperature
        self.generations = []

    def train_generation(
        self,
        train_loader: Any,
        num_epochs: int = 10
    ) -> nn.Module:
        """Train one generation."""
        # Create new student
        student = self.model_fn()

        # Get teacher (previous generation, or None for the first generation)
        teacher = self.generations[-1] if self.generations else None
        if teacher is not None:
            teacher.eval()

        optimizer = torch.optim.Adam(student.parameters())

        for epoch in range(num_epochs):
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                student_logits = student(inputs)

                if teacher is not None:
                    with torch.no_grad():
                        teacher_logits = teacher(inputs)

                    # Distillation loss
                    T = self.temperature
                    loss = F.kl_div(
                        F.log_softmax(student_logits / T, dim=-1),
                        F.softmax(teacher_logits / T, dim=-1),
                        reduction='batchmean'
                    ) * (T * T)
                else:
                    # Standard cross-entropy for the first generation
                    loss = F.cross_entropy(student_logits, labels)

                loss.backward()
                optimizer.step()

        self.generations.append(student)
        return student

    def train_all_generations(
        self,
        train_loader: Any,
        num_epochs: int = 10
    ) -> nn.Module:
        """Train all generations."""
        for gen in range(self.num_generations):
            print(f"Training generation {gen + 1}/{self.num_generations}")
            self.train_generation(train_loader, num_epochs)

        return self.generations[-1]


class DeepMutualLearning:
    """Deep Mutual Learning: students teach each other."""

    def __init__(
        self,
        models: list[nn.Module],
        temperature: float = 2.0
    ):
        self.models = models
        self.temperature = temperature

    def mutual_loss(
        self,
        logits_list: list[torch.Tensor],
        labels: torch.Tensor
    ) -> list[torch.Tensor]:
        """Calculate mutual learning losses, one per model."""
        T = self.temperature
        losses = []

        for i, logits_i in enumerate(logits_list):
            # Hard label loss
            ce_loss = F.cross_entropy(logits_i, labels)

            # KL divergence with the other models (detached, so each model is
            # pulled toward its peers' current predictions)
            kl_loss = 0
            for j, logits_j in enumerate(logits_list):
                if i != j:
                    kl_loss += F.kl_div(
                        F.log_softmax(logits_i / T, dim=-1),
                        F.softmax(logits_j.detach() / T, dim=-1),
                        reduction='batchmean'
                    ) * (T * T)
            kl_loss /= (len(logits_list) - 1)

            losses.append(ce_loss + kl_loss)

        return losses

    def train_step(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor
    ) -> dict:
        """Training step for mutual learning."""
        # Get predictions from all models
        logits_list = [model(inputs) for model in self.models]

        # Calculate losses
        losses = self.mutual_loss(logits_list, labels)

        return {
            "losses": losses,
            "total_loss": sum(losses)
        }
```
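A short sketch of the training loop implied by `SelfDistiller`: the important detail is that the EMA teacher is refreshed after every optimizer step. The `model` and `train_loader` here are placeholders for illustration:

```python
import torch

self_distiller = SelfDistiller(model, temperature=3.0, ema_decay=0.999)
optimizer = torch.optim.AdamW(self_distiller.student.parameters(), lr=1e-4)

for inputs, labels in train_loader:
    optimizer.zero_grad()
    out = self_distiller.train_step(inputs, labels)
    out["loss"].backward()
    optimizer.step()
    self_distiller.update_teacher()  # keep the EMA teacher trailing just behind the student
```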
LLM-Specific Distillation
```python
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F


@dataclass
class LLMDistillConfig:
    """Configuration for LLM distillation."""
    temperature: float = 2.0
    alpha: float = 0.5
    use_teacher_forcing: bool = True
    max_seq_len: int = 2048


class LLMDistiller:
    """Distillation for large language models."""

    def __init__(
        self,
        teacher: Any,
        student: Any,
        config: LLMDistillConfig
    ):
        self.teacher = teacher
        self.student = student
        self.config = config

    async def generate_training_data(
        self,
        prompts: list[str],
        num_samples: int = 1
    ) -> list[dict]:
        """Generate training data from the teacher."""
        training_data = []

        for prompt in prompts:
            for _ in range(num_samples):
                # Generate from teacher
                response = await self.teacher.generate(
                    prompt,
                    temperature=0.7,
                    max_tokens=512
                )
                training_data.append({
                    "prompt": prompt,
                    "response": response,
                    "source": "teacher"
                })

        return training_data

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Calculate LLM distillation loss."""
        T = self.config.temperature

        # Flatten batch and sequence dimensions
        vocab_size = student_logits.shape[-1]
        student_flat = student_logits.view(-1, vocab_size)
        teacher_flat = teacher_logits.view(-1, vocab_size)
        labels_flat = labels.view(-1)
        mask_flat = attention_mask.view(-1)

        # Soft target loss
        soft_targets = F.softmax(teacher_flat / T, dim=-1)
        soft_predictions = F.log_softmax(student_flat / T, dim=-1)
        kl_loss = F.kl_div(
            soft_predictions,
            soft_targets,
            reduction='none'
        ).sum(dim=-1)
        kl_loss = (kl_loss * mask_flat).sum() / mask_flat.sum()
        kl_loss = kl_loss * (T * T)

        # Hard label loss
        ce_loss = F.cross_entropy(
            student_flat,
            labels_flat,
            reduction='none'
        )
        ce_loss = (ce_loss * mask_flat).sum() / mask_flat.sum()

        return self.config.alpha * kl_loss + (1 - self.config.alpha) * ce_loss


class OnlineDistiller:
    """Online distillation from an API-based teacher."""

    def __init__(
        self,
        teacher_api: Any,
        student: Any,
        cache_size: int = 10000
    ):
        self.teacher_api = teacher_api
        self.student = student
        self.cache: dict[str, str] = {}
        self.cache_size = cache_size

    async def get_teacher_response(self, prompt: str) -> str:
        """Get teacher response with caching."""
        if prompt in self.cache:
            return self.cache[prompt]

        response = await self.teacher_api.generate(prompt)

        # Cache management
        if len(self.cache) >= self.cache_size:
            # Remove the oldest entry (dicts preserve insertion order)
            oldest = next(iter(self.cache))
            del self.cache[oldest]

        self.cache[prompt] = response
        return response

    async def train_step(
        self,
        prompts: list[str]
    ) -> dict:
        """Training step with online distillation."""
        # Get teacher responses
        teacher_responses = []
        for prompt in prompts:
            response = await self.get_teacher_response(prompt)
            teacher_responses.append(response)

        # Train student on teacher responses
        loss = self.student.train_on_examples(
            prompts,
            teacher_responses
        )

        return {"loss": loss}


class TaskSpecificDistiller:
    """Distillation for specific tasks."""

    def __init__(
        self,
        teacher: Any,
        student: Any,
        task_type: str  # "classification", "generation", "qa"
    ):
        self.teacher = teacher
        self.student = student
        self.task_type = task_type

    async def distill_classification(
        self,
        texts: list[str],
        num_classes: int
    ) -> dict:
        """Distill a classification task."""
        # Get teacher soft labels
        teacher_probs = []
        for text in texts:
            probs = await self.teacher.classify(text)
            teacher_probs.append(probs)

        # Train student
        loss = self.student.train_classifier(
            texts,
            torch.tensor(teacher_probs)
        )

        return {"loss": loss}

    async def distill_generation(
        self,
        prompts: list[str]
    ) -> dict:
        """Distill a generation task."""
        # Generate from teacher
        teacher_outputs = []
        for prompt in prompts:
            output = await self.teacher.generate(prompt)
            teacher_outputs.append(output)

        # Train student on teacher outputs
        loss = self.student.train_generator(
            prompts,
            teacher_outputs
        )

        return {"loss": loss}
```
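A hedged sketch of how `OnlineDistiller` might be driven end to end. The student's `train_on_examples` method and the teacher API's `generate` method are the interfaces assumed by the class above, not a specific library's API:

```python
import asyncio

async def distill_from_prompts(distiller: OnlineDistiller, prompts: list[str], batch_size: int = 8):
    """Stream prompts through the cached teacher and train the student batch by batch."""
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        result = await distiller.train_step(batch)
        print(f"batch {start // batch_size}: loss={result['loss']}")

# asyncio.run(distill_from_prompts(distiller, prompts))
```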
Production Distillation Service
```python
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
import asyncio
import uuid

app = FastAPI()


class DistillationJob:
    """A distillation job."""

    def __init__(self, job_id: str, config: dict):
        self.id = job_id
        self.config = config
        self.status = "pending"
        self.progress = 0.0
        self.metrics = {}


jobs: dict[str, DistillationJob] = {}


class CreateJobRequest(BaseModel):
    teacher_model: str
    student_model: str
    dataset_path: str
    temperature: float = 2.0
    num_epochs: int = 3


class JobResponse(BaseModel):
    job_id: str
    status: str
    progress: float
    metrics: dict


async def run_distillation(job: DistillationJob):
    """Run distillation job."""
    job.status = "running"
    try:
        # Simulate distillation progress
        for epoch in range(job.config["num_epochs"]):
            job.progress = (epoch + 1) / job.config["num_epochs"]
            job.metrics["epoch"] = epoch + 1
            job.metrics["loss"] = 1.0 / (epoch + 1)
            await asyncio.sleep(1)

        job.status = "completed"
        job.progress = 1.0
    except Exception as e:
        job.status = "failed"
        job.metrics["error"] = str(e)


@app.post("/v1/distillation/jobs")
async def create_job(
    request: CreateJobRequest,
    background_tasks: BackgroundTasks
) -> JobResponse:
    """Create distillation job."""
    job_id = str(uuid.uuid4())
    job = DistillationJob(
        job_id=job_id,
        config=request.model_dump()
    )
    jobs[job_id] = job

    background_tasks.add_task(run_distillation, job)

    return JobResponse(
        job_id=job_id,
        status=job.status,
        progress=job.progress,
        metrics=job.metrics
    )


@app.get("/v1/distillation/jobs/{job_id}")
async def get_job(job_id: str) -> JobResponse:
    """Get job status."""
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")

    job = jobs[job_id]
    return JobResponse(
        job_id=job.id,
        status=job.status,
        progress=job.progress,
        metrics=job.metrics
    )


@app.get("/health")
async def health():
    return {"status": "healthy"}
```
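For completeness, a small client sketch for the service above, using httpx against a locally running instance; the base URL and the model/dataset names in the payload are assumptions for illustration:

```python
import asyncio
import httpx

async def submit_and_poll(base_url: str = "http://localhost:8000") -> dict:
    """Submit a distillation job and poll until it finishes."""
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post("/v1/distillation/jobs", json={
            "teacher_model": "teacher-70b",        # illustrative names
            "student_model": "student-7b",
            "dataset_path": "/data/distill.jsonl",
            "temperature": 2.0,
            "num_epochs": 3,
        })
        job_id = resp.json()["job_id"]
        while True:
            status = (await client.get(f"/v1/distillation/jobs/{job_id}")).json()
            if status["status"] in ("completed", "failed"):
                return status
            await asyncio.sleep(1)

# asyncio.run(submit_and_poll())
```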
References
- Hinton’s Distillation Paper: https://arxiv.org/abs/1503.02531
- DistilBERT: https://arxiv.org/abs/1910.01108
- TinyLlama: https://github.com/jzhang38/TinyLlama
- Born-Again Networks: https://arxiv.org/abs/1805.04770
Conclusion
Knowledge distillation is the most practical path to deploying capable models efficiently. Response-based distillation using soft targets is the simplest and often most effective approach—the temperature parameter controls how much “dark knowledge” transfers from teacher to student. Feature-based distillation adds depth by matching intermediate representations, useful when the student architecture differs significantly from the teacher's. Relation-based distillation preserves structural knowledge about how samples relate to each other, valuable for tasks where relationships matter more than absolute predictions.

Self-distillation and born-again networks show that even without a larger teacher, models can improve by distilling from themselves across generations. For LLMs, online distillation from API-based teachers enables creating specialized models without access to teacher weights.

The key insight is that distillation is not just about compression—it’s about transferring the right knowledge in the right form. Start with soft target distillation, add feature matching if needed, and always evaluate on your target task rather than just matching teacher outputs. The goal is a student that performs well on your use case, not one that perfectly mimics the teacher.