Knowledge distillation

Training a small student model to mimic a large teacher.

Definition

Knowledge distillation is a model compression technique in which a smaller student model is trained to reproduce the behavior of a larger, more capable teacher model. Rather than training the student on hard labels alone (the ground-truth class or token), distillation exposes the student to the teacher's soft outputs: probability distributions over all classes or tokens, which carry richer information about the teacher's learned representations and the relative similarities between concepts. This additional signal lets the student reach accuracy levels that would otherwise require significantly more data or capacity to achieve from scratch.
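
To make the extra signal concrete, here is a toy example (hypothetical logits for three classes) showing how temperature-scaled soft targets preserve similarity structure that a hard label discards:

import torch
import torch.nn.functional as F

# Hypothetical teacher logits for three classes ("cat", "dog", "car")
teacher_logits = torch.tensor([2.0, 4.0, -1.0])

hard_label = torch.tensor([0.0, 1.0, 0.0])           # one-hot: all mass on "dog"
soft_T1 = F.softmax(teacher_logits, dim=-1)          # ~[0.12, 0.88, 0.01]
soft_T4 = F.softmax(teacher_logits / 4.0, dim=-1)    # ~[0.32, 0.53, 0.15]

# The soft targets show that "cat" is far more plausible than "car", a relation
# the hard label erases; raising the temperature amplifies this structure.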

The concept was formalized by Hinton et al. in 2015 and has since been applied broadly: BERT (110M parameters) was distilled into DistilBERT (66M, retaining ~97% of BERT's performance), GPT-style models have been distilled into smaller chat-capable variants, and ensemble models have been distilled into single networks. Beyond classification, distillation applies to sequence generation (matching output distributions token by token), intermediate feature matching (aligning hidden states between teacher and student), and attention transfer (matching attention maps in transformer models).

Knowledge distillation is complementary to quantization and pruning in the compression pipeline. A typical production workflow distills a large model into a smaller one, then quantizes the student for deployment. Unlike pruning (which modifies an existing model) and quantization (which changes numeric representation), distillation creates a fundamentally different, purpose-trained model whose architecture can be freely designed.
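
As a sketch of that distill-then-quantize workflow, assuming a trained student built from torch.nn.Linear layers, PyTorch's post-training dynamic quantization can be applied directly:

import torch

# Sketch only: assumes `student` is the distilled, trained nn.Module.
quantized_student = torch.ao.quantization.quantize_dynamic(
    student,
    {torch.nn.Linear},    # layer types to quantize
    dtype=torch.qint8,    # 8-bit integer weights
)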

How it works

Training pipeline
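
In the standard response-based setup (shown in full under Code examples below), the teacher is trained first and then frozen. Each training batch is run through both models: the teacher's output distribution serves as a soft target, the ground-truth labels as a hard target, and only the student's weights are updated against the combined loss.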

Loss function decomposition
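
The training objective blends two terms, weighted by a coefficient alpha (this restates the distillation_loss function in the code example below):

    L_total = alpha * T^2 * KL( softmax(z_teacher / T) || softmax(z_student / T) )
            + (1 - alpha) * CE(z_student, y)

Here z_teacher and z_student are logits, T is the softmax temperature (T > 1 softens both distributions), CE is cross-entropy against the hard labels y, and the T^2 factor keeps the gradient magnitude of the KL term comparable to the task loss as T grows.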

Distillation variants

| Variant | What is matched | Use case |
| --- | --- | --- |
| Response-based (Hinton) | Output logits (soft labels) | Classification, generation |
| Feature-based | Intermediate hidden states (sketch below) | Structural compression |
| Attention transfer | Attention weight maps | Transformer head compression |
| Data-free distillation | Synthetic data generated by teacher | No access to original training data |
| Online distillation | Mutual learning between peers | No strong teacher required |
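
As a minimal sketch of the feature-based variant, with hypothetical hidden sizes, a learned linear projection bridges the width gap between student and teacher states:

import torch
import torch.nn.functional as F

# Hypothetical widths: student hidden size 384, teacher hidden size 768.
# The projection is trained jointly with the student.
proj = torch.nn.Linear(384, 768)

def feature_loss(student_hidden: torch.Tensor,   # [batch, seq, 384]
                 teacher_hidden: torch.Tensor,   # [batch, seq, 768]
                 ) -> torch.Tensor:
    # MSE between projected student states and frozen teacher states;
    # hidden states would come from forward hooks or returned activations.
    return F.mse_loss(proj(student_hidden), teacher_hidden.detach())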

When to use / When NOT to use

| Scenario | Use knowledge distillation | Do NOT use knowledge distillation |
| --- | --- | --- |
| Need a small model with near-teacher accuracy | Yes; often the most accurate compression approach | |
| Deploying a student fine-tuned for a specific task | Yes; task-specific distillation is very effective | |
| Compressing ensemble models into a single network | Yes; the canonical use case | |
| Need fast compression with no retraining | | Use post-training quantization (PTQ) instead |
| No access to the original training data | | Data-free distillation is complex; quantization is simpler |
| Shrinking an existing model without changing its architecture | | Pruning is more appropriate |

Pros and cons

| Pros | Cons |
| --- | --- |
| Student architecture is unconstrained and can be freely designed | Requires significant training compute (a full training run) |
| Often achieves better accuracy than pruning at the same compression ratio | Requires access to the teacher at training time |
| Soft labels provide a richer signal than hard labels alone | A large teacher-student capacity gap can limit transfer effectiveness |
| Complementary to quantization and pruning | Hyperparameter tuning (temperature, loss weight) adds complexity |

Code examples

# Knowledge distillation training loop in PyTorch
import torch
import torch.nn.functional as F

def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    hard_labels: torch.Tensor,
    temperature: float = 4.0,
    alpha: float = 0.7,
) -> torch.Tensor:
    """Combine KL-divergence distillation loss with cross-entropy task loss."""
    # Soft targets from teacher (scaled by temperature)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)

    # Distillation loss: KL divergence between soft distributions
    # Multiply by T^2 to maintain gradient magnitude relative to task loss
    loss_kl = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temperature ** 2)

    # Task loss: standard cross-entropy with hard labels
    loss_ce = F.cross_entropy(student_logits, hard_labels)

    return alpha * loss_kl + (1 - alpha) * loss_ce


# Training setup: teacher is frozen, student is updated.
# Assumes `teacher`, `student`, `optimizer`, and `train_loader` are already defined.
teacher.eval()     # inference mode; gradients are also blocked by no_grad below
student.train()

for x_batch, y_batch in train_loader:
    with torch.no_grad():
        teacher_logits = teacher(x_batch)   # get soft labels from frozen teacher

    student_logits = student(x_batch)       # student forward pass

    loss = distillation_loss(
        student_logits, teacher_logits, y_batch,
        temperature=4.0,
        alpha=0.7,
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
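
The same loss extends to token-level sequence distillation (mentioned under Definition) by treating every position as a classification over the vocabulary. A sketch, assuming hypothetical tensors student_seq_logits and teacher_seq_logits shaped [batch, seq_len, vocab] and target_tokens holding ground-truth token ids; the temperature and alpha values are illustrative:

# Token-by-token distillation for sequence models (sketch).
# Flatten [batch, seq_len, vocab] logits into [batch * seq_len, vocab]
# so the same distillation_loss applies at every position.
B, S, V = student_seq_logits.shape
loss = distillation_loss(
    student_seq_logits.reshape(B * S, V),
    teacher_seq_logits.reshape(B * S, V),
    target_tokens.reshape(B * S),        # hard labels: token ids per position
    temperature=2.0,
    alpha=0.5,
)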
