
Knowledge Distillation: Transferring Intelligence from Large to Small Models

Introduction: Knowledge distillation transfers the capabilities of large, expensive models into smaller, faster ones that can run efficiently in production. Instead of training a small model from scratch, distillation leverages the “dark knowledge” encoded in a teacher model’s soft probability distributions: how much probability the teacher assigns to each incorrect answer, information that hard labels alone cannot capture. This guide covers the main techniques: response-based distillation using soft targets, feature-based distillation that matches intermediate representations, relation-based distillation that preserves structural relationships between samples, and self-distillation for iterative improvement. Whether you’re compressing a 70B model into a 7B deployment target or creating specialized models from general-purpose teachers, these patterns help you find a good quality-efficiency tradeoff.
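To make “dark knowledge” concrete: dividing logits by a temperature T > 1 before the softmax keeps the teacher’s top prediction but spreads probability onto the other classes, exposing which wrong answers the teacher considers plausible. The short sketch below shows this with made-up logits for a hypothetical 4-class problem; `soft_targets` and `entropy` are illustrative helper names.

```python
import torch
import torch.nn.functional as F


def soft_targets(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Temperature-scaled softmax, as used to build distillation soft targets."""
    return F.softmax(logits / temperature, dim=-1)


def entropy(probs: torch.Tensor) -> float:
    """Shannon entropy in nats; higher means a flatter distribution."""
    return float(-(probs * probs.log()).sum())


# Made-up teacher logits for a hypothetical 4-class problem
teacher_logits = torch.tensor([6.0, 3.0, 2.0, -1.0])

for T in (1.0, 2.0, 4.0):
    p = soft_targets(teacher_logits, T)
    print(f"T={T}: {[round(x, 3) for x in p.tolist()]}  entropy={entropy(p):.3f}")
```

At T=1 over 90% of the mass sits on class 0; at higher temperatures the ranking of the remaining classes becomes visible, and that is the signal the KL terms in the code below train the student to reproduce.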

Figure: Knowledge Distillation: Teacher Model to Student Model via Soft Labels

Response-Based Distillation

from dataclasses import dataclass, field
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class DistillationConfig:
    """Configuration for knowledge distillation."""
    
    temperature: float = 2.0
    alpha: float = 0.5  # Weight for distillation loss
    hard_label_weight: float = 0.5  # Weight for hard label loss

class SoftTargetDistiller:
    """Distillation using soft targets."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: DistillationConfig
    ):
        self.teacher = teacher
        self.student = student
        self.config = config
        
        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        
        self.teacher.eval()
    
    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor = None
    ) -> torch.Tensor:
        """Calculate distillation loss."""
        
        T = self.config.temperature
        
        # Soft target loss (KL divergence)
        soft_targets = F.softmax(teacher_logits / T, dim=-1)
        soft_predictions = F.log_softmax(student_logits / T, dim=-1)
        
        distill_loss = F.kl_div(
            soft_predictions,
            soft_targets,
            reduction='batchmean'
        ) * (T * T)
        
        # Hard label loss (if labels provided)
        if labels is not None:
            hard_loss = F.cross_entropy(student_logits, labels)
            
            total_loss = (
                self.config.alpha * distill_loss +
                self.config.hard_label_weight * hard_loss
            )
        else:
            total_loss = distill_loss
        
        return total_loss
    
    def train_step(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor = None
    ) -> dict:
        """Single training step."""
        
        # Get teacher predictions
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
        
        # Get student predictions
        student_logits = self.student(inputs)
        
        # Calculate loss
        loss = self.distillation_loss(student_logits, teacher_logits, labels)
        
        return {
            "loss": loss,
            "student_logits": student_logits,
            "teacher_logits": teacher_logits
        }

class SequenceDistiller:
    """Distillation for sequence models (LLMs)."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: DistillationConfig
    ):
        self.teacher = teacher
        self.student = student
        self.config = config
        
        # Freeze teacher and disable dropout
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()
    
    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Calculate sequence distillation loss."""
        
        T = self.config.temperature
        
        # Flatten (batch, seq_len, vocab) -> (batch * seq_len, vocab).
        # reshape() instead of view(): model logits are not always contiguous.
        vocab_size = student_logits.size(-1)
        
        student_flat = student_logits.reshape(-1, vocab_size)
        teacher_flat = teacher_logits.reshape(-1, vocab_size)
        
        # Per-token KL divergence, summed over the vocabulary
        soft_targets = F.softmax(teacher_flat / T, dim=-1)
        soft_predictions = F.log_softmax(student_flat / T, dim=-1)
        
        kl_loss = F.kl_div(
            soft_predictions,
            soft_targets,
            reduction='none'
        ).sum(dim=-1)
        
        # Average only over real (non-padding) tokens
        if attention_mask is not None:
            mask_flat = attention_mask.reshape(-1).float()
            kl_loss = (kl_loss * mask_flat).sum() / mask_flat.sum().clamp(min=1)
        else:
            kl_loss = kl_loss.mean()
        
        return kl_loss * (T * T)
    
    def train_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> dict:
        """Training step for sequence distillation."""
        
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_ids,
                attention_mask=attention_mask
            )
            teacher_logits = teacher_outputs.logits
        
        student_outputs = self.student(
            input_ids,
            attention_mask=attention_mask
        )
        student_logits = student_outputs.logits
        
        loss = self.distillation_loss(
            student_logits,
            teacher_logits,
            attention_mask
        )
        
        return {"loss": loss}

class TopKDistiller:
    """Distillation using only top-k logits."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        k: int = 100,
        temperature: float = 2.0
    ):
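        self.teacher = teacher
        self.student = student
        self.k = k
        self.temperature = temperature
        
        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()
    
    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor
    ) -> torch.Tensor:
        """Calculate top-k distillation loss.
        
        Keeps only the teacher's k highest logits and renormalizes both
        distributions over those positions. This approximates the full KL
        and lets teacher outputs be cached as k values and indices per position.
        """
        
        T = self.temperature
        
        # Top-k teacher logits and their vocabulary indices, shape (..., k)
        top_k_values, top_k_indices = torch.topk(teacher_logits, self.k, dim=-1)
        
        # Student logits at the same vocabulary positions
        student_top_k = torch.gather(student_logits, -1, top_k_indices)
        
        # Softmax over top-k only; flatten leading dims so 'batchmean'
        # averages per position (per token for sequence logits)
        teacher_probs = F.softmax(top_k_values / T, dim=-1).reshape(-1, self.k)
        student_log_probs = F.log_softmax(student_top_k / T, dim=-1).reshape(-1, self.k)
        
        # KL divergence
        loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean'
        ) * (T * T)
        
        return loss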
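The `train_step` methods above return the loss but leave the backward pass and optimizer update to the caller. A minimal, self-contained sketch of that outer loop follows, assuming toy MLPs and synthetic data (the teacher is untrained, so this only demonstrates the mechanics). `kd_loss` is a hypothetical helper; its `alpha` / `1 - alpha` split matches the default `alpha=0.5`, `hard_label_weight=0.5` in `DistillationConfig`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def kd_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Soft-target KL (scaled by T^2) mixed with hard-label cross-entropy."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard


# Toy models: sizes are arbitrary; the teacher is untrained here
teacher = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 5))
student = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 5))
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)
inputs = torch.randn(64, 20)           # synthetic batch
labels = torch.randint(0, 5, (64,))    # synthetic labels

losses = []
for step in range(50):
    with torch.no_grad():
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)
    loss = kd_loss(student_logits, teacher_logits, labels)

    optimizer.zero_grad()
    loss.backward()                    # gradients reach the student only
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

The T² factor follows Hinton et al. (2015): softening by T shrinks the soft-target gradients by roughly 1/T², so rescaling keeps the KL and cross-entropy terms on a comparable scale as T changes.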

Feature-Based Distillation

from dataclasses import dataclass
from typing import Any, Optional, List
import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class FeatureDistillConfig:
    """Configuration for feature distillation."""
    
    layer_mapping: dict  # teacher_layer -> student_layer
    loss_type: str = "mse"  # "mse", "cosine", "attention"
    projection: bool = True  # Project student features to teacher dim

class FeatureDistiller:
    """Distillation using intermediate features."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        config: FeatureDistillConfig
    ):
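        self.teacher = teacher
        self.student = student
        self.config = config
        
        # Freeze teacher; only the student (and projectors) are trained
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()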
        
        # Feature hooks
        self.teacher_features = {}
        self.student_features = {}
        
        self._register_hooks()
        
        # Projection layers
        if config.projection:
            self.projectors = self._create_projectors()
    
    def _register_hooks(self):
        """Register forward hooks to capture features."""
        
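        def make_hook(storage, name):
            def hook(module, input, output):
                # Transformer blocks often return tuples; keep the hidden states
                storage[name] = output[0] if isinstance(output, tuple) else output
            return hook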
        
        for teacher_layer, student_layer in self.config.layer_mapping.items():
            # Register teacher hook
            teacher_module = self._get_module(self.teacher, teacher_layer)
            teacher_module.register_forward_hook(
                make_hook(self.teacher_features, teacher_layer)
            )
            
            # Register student hook
            student_module = self._get_module(self.student, student_layer)
            student_module.register_forward_hook(
                make_hook(self.student_features, student_layer)
            )
    
    def _get_module(self, model: nn.Module, name: str) -> nn.Module:
        """Get module by name."""
        
        parts = name.split('.')
        module = model
        
        for part in parts:
            module = getattr(module, part)
        
        return module
    
    def _create_projectors(self) -> nn.ModuleDict:
        """Create projection layers for dimension matching."""
        
        projectors = {}
        
        for teacher_layer, student_layer in self.config.layer_mapping.items():
            # Simplified: 768 -> 1024 are example hidden sizes; in practice,
            # read the real sizes from each model's config.
            # ModuleDict keys cannot contain '.', so sanitize dotted layer names.
            projectors[student_layer.replace('.', '_')] = nn.Linear(768, 1024)
        
        # The projectors are trainable: pass their parameters to the optimizer
        # together with the student's.
        return nn.ModuleDict(projectors)
    
    def feature_loss(
        self,
        teacher_feat: torch.Tensor,
        student_feat: torch.Tensor,
        layer_name: str
    ) -> torch.Tensor:
        """Calculate feature matching loss."""
        
        # Project student features to the teacher's dimension if needed
        key = layer_name.replace('.', '_')
        if self.config.projection and key in self.projectors:
            student_feat = self.projectors[key](student_feat)
        
        if self.config.loss_type == "mse":
            return F.mse_loss(student_feat, teacher_feat)
        
        elif self.config.loss_type == "cosine":
            # Cosine similarity loss, computed per position
            cos_sim = F.cosine_similarity(
                student_feat.reshape(-1, student_feat.size(-1)),
                teacher_feat.reshape(-1, teacher_feat.size(-1)),
                dim=-1
            )
            return (1 - cos_sim).mean()
        
        elif self.config.loss_type == "attention":
            # Attention transfer
            student_att = self._attention_map(student_feat)
            teacher_att = self._attention_map(teacher_feat)
            return F.mse_loss(student_att, teacher_att)
        
        raise ValueError(f"Unknown loss type: {self.config.loss_type}")
    
    def _attention_map(self, features: torch.Tensor) -> torch.Tensor:
        """Convert features to a normalized attention map."""
        
        # Sum of squared activations across channels (last dim), flattened
        # per sample and L2-normalized as in attention transfer
        att = features.pow(2).sum(dim=-1).reshape(features.size(0), -1)
        return F.normalize(att, dim=-1)
    
    def train_step(self, inputs: torch.Tensor) -> dict:
        """Training step with feature distillation."""
        
        # Forward pass through both models
        with torch.no_grad():
            _ = self.teacher(inputs)
        
        _ = self.student(inputs)
        
        # Calculate feature losses
        total_loss = 0
        feature_losses = {}
        
        for teacher_layer, student_layer in self.config.layer_mapping.items():
            teacher_feat = self.teacher_features[teacher_layer]
            student_feat = self.student_features[student_layer]
            
            loss = self.feature_loss(teacher_feat, student_feat, student_layer)
            feature_losses[student_layer] = loss.item()
            total_loss += loss
        
        return {
            "loss": total_loss,
            "feature_losses": feature_losses
        }

class AttentionDistiller:
    """Distillation of attention patterns."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        layer_mapping: dict  # teacher layer index -> student layer index (ints)
    ):
        self.teacher = teacher
        self.student = student
        self.layer_mapping = layer_mapping
        
        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()
    
    def attention_loss(
        self,
        student_att: torch.Tensor,
        teacher_att: torch.Tensor
    ) -> torch.Tensor:
        """Calculate attention distillation loss."""
        
        # Handle different number of heads
        if student_att.shape[1] != teacher_att.shape[1]:
            # Average teacher heads to match student
            teacher_att = teacher_att.mean(dim=1, keepdim=True)
            teacher_att = teacher_att.expand_as(student_att)
        
        # MSE on attention weights
        return F.mse_loss(student_att, teacher_att)
    
    def train_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> dict:
        """Training step for attention distillation."""
        
        # Get teacher attentions
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=True
            )
        
        # Get student attentions
        student_out = self.student(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=True
        )
        
        # Calculate attention losses
        total_loss = 0
        
        for t_layer, s_layer in self.layer_mapping.items():
            t_att = teacher_out.attentions[t_layer]
            s_att = student_out.attentions[s_layer]
            
            total_loss += self.attention_loss(s_att, t_att)
        
        return {"loss": total_loss}
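`_create_projectors` above hard-codes example dimensions (768 to 1024). One way to get the real sizes, sketched here under the assumption that each mapped layer returns a tensor (or a tuple whose first element is the hidden state), is a single dry-run forward pass with temporary forward hooks. `infer_feature_dims` and `build_projectors` are hypothetical helper names, and the toy models stand in for real teacher and student networks.

```python
import torch
import torch.nn as nn


def infer_feature_dims(model: nn.Module, layer_names: list[str], sample: torch.Tensor) -> dict[str, int]:
    """Run one forward pass and record the last dimension of each named layer's output."""
    modules = dict(model.named_modules())
    dims: dict[str, int] = {}
    handles = []
    for name in layer_names:
        def hook(module, args, output, name=name):
            # Transformer blocks often return tuples; the hidden state comes first
            out = output[0] if isinstance(output, tuple) else output
            dims[name] = out.shape[-1]
        handles.append(modules[name].register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(sample)
    finally:
        for handle in handles:
            handle.remove()  # don't leave probe hooks attached
    return dims


def build_projectors(layer_mapping: dict[str, str], teacher: nn.Module,
                     student: nn.Module, sample: torch.Tensor) -> nn.ModuleDict:
    """One Linear(student_dim -> teacher_dim) per mapped layer pair."""
    t_dims = infer_feature_dims(teacher, list(layer_mapping.keys()), sample)
    s_dims = infer_feature_dims(student, list(layer_mapping.values()), sample)
    return nn.ModuleDict({
        # ModuleDict keys may not contain '.', so dotted names are sanitized
        s_layer.replace(".", "_"): nn.Linear(s_dims[s_layer], t_dims[t_layer])
        for t_layer, s_layer in layer_mapping.items()
    })


# Usage with toy models whose mapped layers are nested (dotted names)
teacher = nn.Sequential(nn.Sequential(nn.Linear(8, 64), nn.ReLU()), nn.Linear(64, 4))
student = nn.Sequential(nn.Sequential(nn.Linear(8, 16), nn.ReLU()), nn.Linear(16, 4))
projectors = build_projectors({"0.0": "0.0"}, teacher, student, torch.randn(2, 8))
print(projectors)
```

Removing the probe hooks matters: otherwise they would keep firing during training alongside the feature-capture hooks.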

Relation-Based Distillation

from dataclasses import dataclass
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F

class RelationDistiller:
    """Distillation preserving relational knowledge."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        relation_type: str = "distance"  # "distance", "angle", "similarity"
    ):
        self.teacher = teacher
        self.student = student
        self.relation_type = relation_type
    
    def compute_relations(
        self,
        features: torch.Tensor
    ) -> torch.Tensor:
        """Compute pairwise relations between samples."""
        
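        batch_size = features.shape[0]
        features_flat = features.reshape(batch_size, -1)
        
        if self.relation_type == "distance":
            # Pairwise Euclidean distance, shape (B, B)
            diff = features_flat.unsqueeze(0) - features_flat.unsqueeze(1)
            relations = torch.norm(diff, dim=-1)
            
        elif self.relation_type == "angle":
            # Triplet angles as in RKD: for each anchor i, the cosine of the
            # angle between the directions i->j and i->k, shape (B, B, B)
            diff = features_flat.unsqueeze(0) - features_flat.unsqueeze(1)
            directions = F.normalize(diff, p=2, dim=-1)
            relations = torch.bmm(directions, directions.transpose(1, 2))
            
        elif self.relation_type == "similarity":
            # Pairwise cosine similarity, shape (B, B)
            normalized = F.normalize(features_flat, dim=-1)
            relations = torch.mm(normalized, normalized.t())
        
        else:
            raise ValueError(f"Unknown relation type: {self.relation_type}")
        
        return relations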
    
    def relation_loss(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """Calculate relation distillation loss."""
        
        student_relations = self.compute_relations(student_features)
        teacher_relations = self.compute_relations(teacher_features)
        
        # Normalize relations
        student_relations = F.normalize(student_relations.view(-1), dim=0)
        teacher_relations = F.normalize(teacher_relations.view(-1), dim=0)
        
        return F.mse_loss(student_relations, teacher_relations)
    
    def train_step(self, inputs: torch.Tensor) -> dict:
        """Training step with relation distillation."""
        
        with torch.no_grad():
            teacher_features = self.teacher.get_features(inputs)
        
        student_features = self.student.get_features(inputs)
        
        loss = self.relation_loss(student_features, teacher_features)
        
        return {"loss": loss}

class ContrastiveDistiller:
    """Contrastive knowledge distillation."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        temperature: float = 0.07
    ):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
    
    def contrastive_loss(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """Calculate contrastive distillation loss."""
        
        batch_size = student_features.shape[0]
        
        # Normalize features (student and teacher must share a dimension;
        # add a projection head if they differ)
        student_norm = F.normalize(student_features, dim=-1)
        teacher_norm = F.normalize(teacher_features, dim=-1)
        
        # Similarity of every student embedding to every teacher embedding.
        # Diagonal entries are positives (same sample); off-diagonal are negatives.
        logits = torch.mm(student_norm, teacher_norm.t()) / self.temperature
        
        # InfoNCE: each student embedding should pick out its own teacher embedding
        labels = torch.arange(batch_size, device=student_features.device)
        loss = F.cross_entropy(logits, labels)
        
        return loss
    
    def train_step(self, inputs: torch.Tensor) -> dict:
        """Training step with contrastive distillation."""
        
        with torch.no_grad():
            teacher_features = self.teacher.get_features(inputs)
        
        student_features = self.student.get_features(inputs)
        
        loss = self.contrastive_loss(student_features, teacher_features)
        
        return {"loss": loss}

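class GraphDistiller:
    """Distillation preserving k-NN graph structure."""
    
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        k_neighbors: int = 5,
        edge_temperature: float = 0.1
    ):
        self.teacher = teacher
        self.student = student
        self.k_neighbors = k_neighbors
        # Scales student cosine similarities into edge logits (illustrative default)
        self.edge_temperature = edge_temperature
    
    @torch.no_grad()
    def build_graph(self, features: torch.Tensor) -> torch.Tensor:
        """Build a hard 0/1 k-NN adjacency matrix from (batch, dim) features."""
        
        # Pairwise Euclidean distances; exclude self-edges explicitly
        distances = torch.cdist(features, features)
        distances.fill_diagonal_(float('inf'))
        
        # Indices of the k nearest neighbors of each sample
        _, indices = torch.topk(distances, self.k_neighbors, dim=-1, largest=False)
        
        # Vectorized adjacency construction (no per-row Python loop)
        adj = torch.zeros_like(distances)
        adj.scatter_(1, indices, 1.0)
        
        return adj
    
    def graph_loss(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor
    ) -> torch.Tensor:
        """Calculate graph distillation loss.
        
        A hard k-NN graph carries no gradient, so the teacher's graph is the
        target and the student predicts edges from a differentiable
        similarity rather than from its own hard graph.
        """
        
        teacher_graph = self.build_graph(teacher_features)
        
        # Student edge logits from scaled cosine similarity
        student_norm = F.normalize(student_features, dim=-1)
        edge_logits = student_norm @ student_norm.t() / self.edge_temperature
        
        # Binary cross entropy over all off-diagonal pairs
        batch_size = student_features.shape[0]
        off_diag = ~torch.eye(batch_size, dtype=torch.bool, device=student_features.device)
        
        loss = F.binary_cross_entropy_with_logits(
            edge_logits[off_diag],
            teacher_graph[off_diag]
        )
        
        return loss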
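Because relation losses compare B×B matrices rather than raw features, the student and teacher embedding sizes do not need to match, and no projector is required. The sketch below illustrates this with the distance-wise loss from Relational Knowledge Distillation (RKD; Park et al., 2019), which divides pairwise distances by their mean and compares them with a smooth L1 (Huber) loss. This is a swapped-in variant of the `F.normalize` plus MSE used in `relation_loss` above; the helper names and feature sizes are illustrative.

```python
import torch
import torch.nn.functional as F


def pairwise_distances(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Euclidean distances; clamping before sqrt keeps the diagonal's gradient finite."""
    sq = (x.unsqueeze(0) - x.unsqueeze(1)).pow(2).sum(dim=-1)
    return sq.clamp(min=eps).sqrt()


def rkd_distance_loss(student_feats: torch.Tensor, teacher_feats: torch.Tensor) -> torch.Tensor:
    """Distance-wise relational loss: match mean-normalized pairwise distances."""
    b = student_feats.shape[0]
    off_diag = ~torch.eye(b, dtype=torch.bool, device=student_feats.device)
    with torch.no_grad():
        t_d = pairwise_distances(teacher_feats)
        t_d = t_d / t_d[off_diag].mean()
    s_d = pairwise_distances(student_feats)
    s_d = s_d / s_d[off_diag].mean()
    return F.smooth_l1_loss(s_d[off_diag], t_d[off_diag])


torch.manual_seed(0)
teacher_feats = torch.randn(16, 1024)                     # e.g. a wide teacher embedding
student_feats = torch.randn(16, 128, requires_grad=True)  # narrower student, no projector
loss = rkd_distance_loss(student_feats, teacher_feats)
loss.backward()
print(loss.item())
```

Normalizing by the mean distance makes the loss invariant to the overall scale of each embedding space, so the student only has to reproduce the teacher's relative geometry.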

Self-Distillation

from dataclasses import dataclass
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class SelfDistiller:
    """Self-distillation for iterative improvement."""
    
    def __init__(
        self,
        model: nn.Module,
        temperature: float = 3.0,
        ema_decay: float = 0.999
    ):
        self.student = model
        self.teacher = copy.deepcopy(model)
        self.temperature = temperature
        self.ema_decay = ema_decay
        
        # Freeze teacher; it changes only through the EMA update
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()
    
    @torch.no_grad()
    def update_teacher(self):
        """Update teacher with an EMA of student weights.
        
        Call once after each optimizer step on the student.
        """
        
        for t_param, s_param in zip(
            self.teacher.parameters(),
            self.student.parameters()
        ):
            # t = decay * t + (1 - decay) * s, computed in place
            t_param.mul_(self.ema_decay).add_(s_param, alpha=1 - self.ema_decay)
    
    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Calculate self-distillation loss."""
        
        T = self.temperature
        
        # Soft target loss
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1),
            reduction='batchmean'
        ) * (T * T)
        
        # Hard label loss
        hard_loss = F.cross_entropy(student_logits, labels)
        
        return 0.5 * soft_loss + 0.5 * hard_loss
    
    def train_step(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor
    ) -> dict:
        """Training step with self-distillation."""
        
        # Get teacher predictions
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
        
        # Get student predictions
        student_logits = self.student(inputs)
        
        # Calculate loss
        loss = self.distillation_loss(student_logits, teacher_logits, labels)
        
        return {"loss": loss}

class BornAgainDistiller:
    """Born-Again Networks: multi-generation distillation."""
    
    def __init__(
        self,
        model_fn: callable,
        num_generations: int = 3,
        temperature: float = 2.0
    ):
        self.model_fn = model_fn
        self.num_generations = num_generations
        self.temperature = temperature
        
        self.generations = []
    
    def train_generation(
        self,
        train_loader: Any,
        num_epochs: int = 10
    ) -> nn.Module:
        """Train one generation."""
        
        # Create new student
        student = self.model_fn()
        
        # Get teacher (previous generation or None) in inference mode
        teacher = self.generations[-1] if self.generations else None
        if teacher is not None:
            teacher.eval()
        
        student.train()
        optimizer = torch.optim.Adam(student.parameters())
        
        for epoch in range(num_epochs):
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                
                student_logits = student(inputs)
                
                if teacher is not None:
                    with torch.no_grad():
                        teacher_logits = teacher(inputs)
                    
                    # Distillation loss
                    T = self.temperature
                    loss = F.kl_div(
                        F.log_softmax(student_logits / T, dim=-1),
                        F.softmax(teacher_logits / T, dim=-1),
                        reduction='batchmean'
                    ) * (T * T)
                else:
                    # Standard cross-entropy for first generation
                    loss = F.cross_entropy(student_logits, labels)
                
                loss.backward()
                optimizer.step()
        
        self.generations.append(student)
        return student
    
    def train_all_generations(
        self,
        train_loader: Any,
        num_epochs: int = 10
    ) -> nn.Module:
        """Train all generations."""
        
        for gen in range(self.num_generations):
            print(f"Training generation {gen + 1}/{self.num_generations}")
            self.train_generation(train_loader, num_epochs)
        
        return self.generations[-1]

class DeepMutualLearning:
    """Deep Mutual Learning: students teach each other."""
    
    def __init__(
        self,
        models: list[nn.Module],
        temperature: float = 2.0
    ):
        self.models = models
        self.temperature = temperature
    
    def mutual_loss(
        self,
        logits_list: list[torch.Tensor],
        labels: torch.Tensor
    ) -> list[torch.Tensor]:
        """Calculate mutual learning losses."""
        
        T = self.temperature
        losses = []
        
        for i, logits_i in enumerate(logits_list):
            # Hard label loss
            ce_loss = F.cross_entropy(logits_i, labels)
            
            # KL divergence with other models
            kl_loss = 0
            for j, logits_j in enumerate(logits_list):
                if i != j:
                    kl_loss += F.kl_div(
                        F.log_softmax(logits_i / T, dim=-1),
                        F.softmax(logits_j.detach() / T, dim=-1),
                        reduction='batchmean'
                    ) * (T * T)
            
            kl_loss /= (len(logits_list) - 1)
            
            losses.append(ce_loss + kl_loss)
        
        return losses
    
    def train_step(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor
    ) -> dict:
        """Training step for mutual learning."""
        
        # Get predictions from all models
        logits_list = [model(inputs) for model in self.models]
        
        # Calculate losses
        losses = self.mutual_loss(logits_list, labels)
        
        return {
            "losses": losses,
            "total_loss": sum(losses)
        }
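`update_teacher` above averages parameters only, so buffers such as BatchNorm running statistics keep the values they had when the teacher was deep-copied. A minimal sketch of an EMA update that also handles buffers follows; copying buffers straight from the student is an assumption here (some implementations average floating-point buffers instead), and `ema_update` plus the toy module sizes are illustrative.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, decay: float = 0.999) -> None:
    """teacher <- decay * teacher + (1 - decay) * student, in place."""
    for t_p, s_p in zip(teacher.parameters(), student.parameters()):
        t_p.mul_(decay).add_(s_p, alpha=1.0 - decay)
    # Buffers (e.g. BatchNorm running stats) are copied rather than averaged here
    for t_b, s_b in zip(teacher.buffers(), student.buffers()):
        t_b.copy_(s_b)


torch.manual_seed(0)
student = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad_(False)

# Simulate one training step on the student: change weights and BN statistics
with torch.no_grad():
    student[0].weight.add_(1.0)
    student(torch.randn(8, 4))  # train mode, so BN running stats update

before = teacher[0].weight.clone()
ema_update(teacher, student, decay=0.9)
print(teacher[0].weight - before)  # each entry moved 10% of the way toward the student
```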

LLM-Specific Distillation

from dataclasses import dataclass
from typing import Any, Optional, List
import torch
import torch.nn.functional as F

@dataclass
class LLMDistillConfig:
    """Configuration for LLM distillation."""
    
    temperature: float = 2.0
    alpha: float = 0.5
    use_teacher_forcing: bool = True
    max_seq_len: int = 2048

class LLMDistiller:
    """Distillation for large language models."""
    
    def __init__(
        self,
        teacher: Any,
        student: Any,
        config: LLMDistillConfig
    ):
        self.teacher = teacher
        self.student = student
        self.config = config
    
    async def generate_training_data(
        self,
        prompts: list[str],
        num_samples: int = 1
    ) -> list[dict]:
        """Generate training data from teacher."""
        
        training_data = []
        
        for prompt in prompts:
            for _ in range(num_samples):
                # Generate from teacher
                response = await self.teacher.generate(
                    prompt,
                    temperature=0.7,
                    max_tokens=512
                )
                
                training_data.append({
                    "prompt": prompt,
                    "response": response,
                    "source": "teacher"
                })
        
        return training_data
    
    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: torch.Tensor
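    ) -> torch.Tensor:
        """Calculate LLM distillation loss.
        
        Assumes `labels` holds the (unshifted) input token ids, with -100 at
        positions to ignore, and that teacher and student share a vocabulary.
        """
        
        T = self.config.temperature
        
        # Causal LM alignment: logits at position t predict token t + 1
        student_logits = student_logits[:, :-1, :]
        teacher_logits = teacher_logits[:, :-1, :]
        labels = labels[:, 1:]
        attention_mask = attention_mask[:, 1:]
        
        # Flatten; reshape() handles the non-contiguous slices above
        vocab_size = student_logits.shape[-1]
        student_flat = student_logits.reshape(-1, vocab_size)
        teacher_flat = teacher_logits.reshape(-1, vocab_size)
        labels_flat = labels.reshape(-1)
        mask_flat = attention_mask.reshape(-1).float()
        num_tokens = mask_flat.sum().clamp(min=1)
        
        # Soft target loss: per-token KL, averaged over non-padding tokens
        soft_targets = F.softmax(teacher_flat / T, dim=-1)
        soft_predictions = F.log_softmax(student_flat / T, dim=-1)
        
        kl_loss = F.kl_div(
            soft_predictions,
            soft_targets,
            reduction='none'
        ).sum(dim=-1)
        
        kl_loss = (kl_loss * mask_flat).sum() / num_tokens
        kl_loss = kl_loss * (T * T)
        
        # Hard label loss; positions labeled -100 contribute zero
        ce_loss = F.cross_entropy(
            student_flat,
            labels_flat,
            ignore_index=-100,
            reduction='none'
        )
        ce_loss = (ce_loss * mask_flat).sum() / num_tokens
        
        return self.config.alpha * kl_loss + (1 - self.config.alpha) * ce_loss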

class OnlineDistiller:
    """Online distillation from API-based teacher."""
    
    def __init__(
        self,
        teacher_api: Any,
        student: Any,
        cache_size: int = 10000
    ):
        self.teacher_api = teacher_api
        self.student = student
        self.cache: dict[str, str] = {}
        self.cache_size = cache_size
    
    async def get_teacher_response(self, prompt: str) -> str:
        """Get teacher response with caching."""
        
        if prompt in self.cache:
            return self.cache[prompt]
        
        response = await self.teacher_api.generate(prompt)
        
        # Cache management
        if len(self.cache) >= self.cache_size:
            # Evict the oldest inserted entry (FIFO: dicts keep insertion order)
            oldest = next(iter(self.cache))
            del self.cache[oldest]
        
        self.cache[prompt] = response
        return response
    
    async def train_step(
        self,
        prompts: list[str]
    ) -> dict:
        """Training step with online distillation."""
        
        # Get teacher responses
        teacher_responses = []
        for prompt in prompts:
            response = await self.get_teacher_response(prompt)
            teacher_responses.append(response)
        
        # Train student on teacher responses
        loss = self.student.train_on_examples(
            prompts,
            teacher_responses
        )
        
        return {"loss": loss}

class TaskSpecificDistiller:
    """Distillation for specific tasks."""
    
    def __init__(
        self,
        teacher: Any,
        student: Any,
        task_type: str  # "classification", "generation", "qa"
    ):
        self.teacher = teacher
        self.student = student
        self.task_type = task_type
    
    async def distill_classification(
        self,
        texts: list[str],
        num_classes: int
    ) -> dict:
        """Distill classification task."""
        
        # Get teacher soft labels
        teacher_probs = []
        for text in texts:
            probs = await self.teacher.classify(text)
            teacher_probs.append(probs)
        
        # Train student
        loss = self.student.train_classifier(
            texts,
            torch.tensor(teacher_probs)
        )
        
        return {"loss": loss}
    
    async def distill_generation(
        self,
        prompts: list[str]
    ) -> dict:
        """Distill generation task."""
        
        # Generate from teacher
        teacher_outputs = []
        for prompt in prompts:
            output = await self.teacher.generate(prompt)
            teacher_outputs.append(output)
        
        # Train student on teacher outputs
        loss = self.student.train_generator(
            prompts,
            teacher_outputs
        )
        
        return {"loss": loss}
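`OnlineDistiller` above awaits the teacher one prompt at a time and evicts cache entries in insertion order. A sketch of a bounded LRU cache plus concurrent teacher calls with `asyncio.gather` and a semaphore follows. `teacher_generate` stands for any async callable returning the teacher's text; `LRUResponseCache` and `fetch_teacher_responses` are hypothetical names, and the concurrency limit of 8 is arbitrary (set it to your API's rate limit).

```python
import asyncio
from collections import OrderedDict
from typing import Awaitable, Callable, Optional


class LRUResponseCache:
    """Bounded prompt -> response cache with least-recently-used eviction."""

    def __init__(self, max_size: int = 10_000):
        self.max_size = max_size
        self._data: "OrderedDict[str, str]" = OrderedDict()

    def get(self, prompt: str) -> Optional[str]:
        if prompt not in self._data:
            return None
        self._data.move_to_end(prompt)  # mark as most recently used
        return self._data[prompt]

    def put(self, prompt: str, response: str) -> None:
        self._data[prompt] = response
        self._data.move_to_end(prompt)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used


async def fetch_teacher_responses(
    teacher_generate: Callable[[str], Awaitable[str]],
    prompts: list[str],
    cache: LRUResponseCache,
    max_concurrency: int = 8,
) -> list[str]:
    """Query the teacher concurrently (bounded by a semaphore), reusing cached answers."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def fetch(prompt: str) -> str:
        cached = cache.get(prompt)
        if cached is not None:
            return cached
        async with semaphore:
            response = await teacher_generate(prompt)
        cache.put(prompt, response)
        return response

    # gather returns results in the same order as the prompts
    return list(await asyncio.gather(*(fetch(p) for p in prompts)))


async def fake_teacher(prompt: str) -> str:
    await asyncio.sleep(0)  # stands in for network latency
    return prompt.upper()


cache = LRUResponseCache(max_size=2)
print(asyncio.run(fetch_teacher_responses(fake_teacher, ["a", "b"], cache)))
```

Duplicate prompts within a single batch can still both miss the cache and reach the teacher; deduplicate the batch first if API calls are expensive.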

Production Distillation Service

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import Optional, Any
import asyncio
import uuid

app = FastAPI()

class DistillationJob:
    """A distillation job."""
    
    def __init__(self, job_id: str, config: dict):
        self.id = job_id
        self.config = config
        self.status = "pending"
        self.progress = 0.0
        self.metrics = {}

# In-memory job store: lost on restart and not shared across worker processes
jobs: dict[str, DistillationJob] = {}

class CreateJobRequest(BaseModel):
    teacher_model: str
    student_model: str
    dataset_path: str
    temperature: float = 2.0
    num_epochs: int = 3

class JobResponse(BaseModel):
    job_id: str
    status: str
    progress: float
    metrics: dict

async def run_distillation(job: DistillationJob):
    """Run distillation job."""
    
    job.status = "running"
    
    try:
        # Simulate distillation progress
        for epoch in range(job.config["num_epochs"]):
            job.progress = (epoch + 1) / job.config["num_epochs"]
            job.metrics["epoch"] = epoch + 1
            job.metrics["loss"] = 1.0 / (epoch + 1)
            await asyncio.sleep(1)
        
        job.status = "completed"
        job.progress = 1.0
        
    except Exception as e:
        job.status = "failed"
        job.metrics["error"] = str(e)

@app.post("/v1/distillation/jobs")
async def create_job(
    request: CreateJobRequest,
    background_tasks: BackgroundTasks
) -> JobResponse:
    """Create distillation job."""
    
    job_id = str(uuid.uuid4())
    
    job = DistillationJob(
        job_id=job_id,
        config=request.model_dump()
    )
    jobs[job_id] = job
    
    background_tasks.add_task(run_distillation, job)
    
    return JobResponse(
        job_id=job_id,
        status=job.status,
        progress=job.progress,
        metrics=job.metrics
    )

@app.get("/v1/distillation/jobs/{job_id}")
async def get_job(job_id: str) -> JobResponse:
    """Get job status."""
    
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    
    job = jobs[job_id]
    
    return JobResponse(
        job_id=job.id,
        status=job.status,
        progress=job.progress,
        metrics=job.metrics
    )

@app.get("/health")
async def health():
    return {"status": "healthy"}
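`run_distillation` above only simulates progress with `asyncio.sleep`. A real training loop is blocking, and running it directly inside an async background task would stall the event loop, including the status endpoint. The sketch below offloads the work with `asyncio.to_thread` (Python 3.9+); `Job` and `blocking_train` are simplified, hypothetical stand-ins rather than part of the service above.

```python
import asyncio
import time


class Job:
    """Simplified stand-in for DistillationJob (illustrative only)."""

    def __init__(self, num_epochs: int):
        self.num_epochs = num_epochs
        self.status = "pending"
        self.progress = 0.0
        self.metrics: dict = {}


def blocking_train(job: Job) -> None:
    """Placeholder for a real, blocking training loop; runs in a worker thread."""
    for epoch in range(job.num_epochs):
        time.sleep(0.01)  # stands in for an epoch of GPU work
        job.progress = (epoch + 1) / job.num_epochs
        job.metrics["epoch"] = epoch + 1


async def run_distillation(job: Job) -> None:
    job.status = "running"
    try:
        # The event loop stays free to answer status requests meanwhile
        await asyncio.to_thread(blocking_train, job)
        job.status = "completed"
    except Exception as exc:
        job.status = "failed"
        job.metrics["error"] = str(exc)


async def demo() -> tuple[Job, int]:
    job = Job(num_epochs=3)
    task = asyncio.create_task(run_distillation(job))
    ticks = 0
    while not task.done():
        ticks += 1  # the loop keeps running while training proceeds in a thread
        await asyncio.sleep(0.001)
    return job, ticks


job, ticks = asyncio.run(demo())
print(job.status, job.progress, ticks)
```

Declaring the task function as a plain `def` has a similar effect, since Starlette's background tasks run synchronous functions in a threadpool. For multi-hour GPU jobs, a separate worker process or job queue is usually a better fit than threads inside the API server.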


Conclusion

Knowledge distillation is one of the most practical paths to deploying capable models efficiently. Response-based distillation with soft targets is the simplest approach and often the most effective: the temperature controls how much of the teacher’s “dark knowledge” (the relative probabilities it assigns to wrong answers) reaches the student. Feature-based distillation adds supervision by matching intermediate representations; when the student’s hidden sizes differ from the teacher’s, it needs projection layers to bridge them. Relation-based distillation preserves structural knowledge about how samples relate to each other, which is valuable when relationships matter more than absolute predictions, and because it compares relations rather than raw features it works even when the two models’ embedding sizes differ. Self-distillation and born-again networks show that models can improve without a larger teacher, by distilling from an averaged copy of themselves or from a previous generation of the same architecture. For LLMs, distilling from the generated outputs of an API-based teacher makes it possible to build specialized models without access to the teacher’s weights. The key insight is that distillation is not just about compression; it is about transferring the right knowledge in the right form. Start with soft-target distillation, add feature matching if needed, and always evaluate on your target task rather than only measuring how closely the student matches the teacher. The goal is a student that performs well on your use case, not one that perfectly mimics the teacher.



About the Author

I am a Cloud Architect and Developer passionate about solving complex problems with modern technology. My blog explores the intersection of Cloud Architecture, Artificial Intelligence, and Software Engineering. I share tutorials, deep dives, and insights into building scalable, intelligent systems.

Areas of Expertise

Cloud Architecture (Azure, AWS)
Artificial Intelligence & LLMs
DevOps & Kubernetes
Backend Dev (C#, .NET, Python, Node.js)
© 2025 Code, Cloud & Context | Built by Nithin Mohan TK | Powered by Passion