
What does it mean to distill a machine learning model or LLM?

Distillation is a technique for creating smaller, faster, and more efficient versions of neural networks while retaining most of their performance.


DeepSeek has recently captured attention for its advanced reasoning and problem-solving capabilities, but one of the team’s most interesting moves may be distillation: alongside DeepSeek-R1, they released much smaller Qwen- and Llama-based models trained on R1’s outputs.

Introduction to model distillation

Model distillation usually involves two key players: a large teacher model and a smaller student model. The teacher is first trained on a dataset until it achieves high accuracy (or strong results on some other performance metric).

Once trained, it generates outputs for various inputs in the form of “soft labels” (i.e., probability distributions), which are richer than basic one-hot labels. The student then learns to replicate these distributions and gains much of the teacher’s knowledge in the process.

What are one-hot labels?

Many classification tasks use one-hot labels to represent the correct category of an input. For instance, if you have three classes—cat, dog, and mouse—then a one-hot label for “dog” might look like [0, 1, 0] because the bit representing dog is “hot” or switched on.

Such labels indicate a single correct answer without providing information about how confident the model should be about it. In contrast, soft labels from the teacher could look like [0.1, 0.8, 0.1], providing richer insight into the teacher’s confidence distribution.

Simple example of one-hot vs. soft labels

One-hot (hard label):

[0, 1, 0]

Soft label (teacher’s distribution):

[0.1, 0.8, 0.1]

When the student trains with soft labels, it not only learns that “dog” is the correct label, but also that the teacher considered “cat” and “mouse” plausible to a lesser degree. This extra information can help the student model generalize better.
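The effect is easy to see numerically. The minimal sketch below assumes PyTorch is installed; the student probabilities are made-up values for illustration. It scores the same overconfident student prediction against both the one-hot label and the teacher’s soft label using cross-entropy.

```python
import torch

# Targets for the classes [cat, dog, mouse]
hard_label = torch.tensor([0.0, 1.0, 0.0])  # one-hot: only "dog" counts
soft_label = torch.tensor([0.1, 0.8, 0.1])  # teacher's distribution

# Hypothetical student output (probabilities) that is very confident in "dog"
# and gives almost nothing to the other classes
student_probs = torch.tensor([0.01, 0.98, 0.01])

def cross_entropy(target: torch.Tensor, predicted: torch.Tensor) -> float:
    """Cross-entropy H(target, predicted) = -sum(target * log(predicted))."""
    return -(target * predicted.log()).sum().item()

print(f"Loss vs. hard label: {cross_entropy(hard_label, student_probs):.4f}")
print(f"Loss vs. soft label: {cross_entropy(soft_label, student_probs):.4f}")
```

Against the one-hot label, the student already looks nearly perfect (a loss of about 0.02), but against the soft label it is penalized (about 0.94) for dismissing “cat” and “mouse” entirely. That gap is the extra signal soft labels provide.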

Understanding the distillation process

Distillation goes beyond merely copying outputs from a larger model. It often involves an interplay of hyperparameters (such as temperature, which softens or sharpens the teacher’s probability distribution, and alpha, which balances soft labels from the teacher against hard labels from the dataset) and requires careful experimentation.
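A small sketch of both knobs, assuming PyTorch and using hypothetical logits and loss values: dividing logits by a temperature T before the softmax flattens the distribution as T grows, and alpha is simply a weight between the soft-label and hard-label loss terms.

```python
import torch

# Hypothetical teacher logits for the classes [cat, dog, mouse]
teacher_logits = torch.tensor([1.0, 3.0, 0.5])

for temperature in (1.0, 2.0, 5.0):
    # Dividing by T > 1 flattens the distribution; T < 1 would sharpen it
    probs = torch.softmax(teacher_logits / temperature, dim=-1)
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")

# alpha weights the teacher (soft-label) term against the dataset (hard-label) term
alpha = 0.5
soft_loss = 0.40  # hypothetical KL divergence to the teacher's distribution
hard_loss = 0.90  # hypothetical cross-entropy against the true labels
total_loss = alpha * soft_loss + (1 - alpha) * hard_loss
print(f"Blended loss with alpha={alpha}: {total_loss:.2f}")
```

At T=1 most of the probability sits on the top class; at higher temperatures the smaller classes receive noticeably more mass, exposing more of the teacher’s view of how similar the classes are. The full example later in this article also multiplies the soft-label term by T², a common convention that keeps its gradients on a similar scale as the temperature changes.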

Consider DistilBERT, a popular distilled version of BERT. During pre-training, DistilBERT learns from BERT’s output probability distributions (soft targets) and also aligns its hidden states with BERT’s. On the GLUE language-understanding benchmark, it retains roughly 97% of BERT’s performance while being 40% smaller and 60% faster.

What is a logit?

A logit is a neural network's raw, unnormalized output before applying an activation function like softmax or sigmoid. Logits serve as a rich source of information in the context of model distillation, capturing both the predicted class and the relative confidence across all possible classes.

Distillation leverages these logits from a larger teacher model to train a smaller student model, transferring knowledge more effectively than using hard labels alone.
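To illustrate, here is a minimal PyTorch sketch with hypothetical logits for the three-class example above: softmax turns the logits into a probability distribution, argmax recovers the predicted class, and a one-hot label built from that prediction discards everything else.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw logits for the classes [cat, dog, mouse]
logits = torch.tensor([1.2, 3.4, 0.3])

# Softmax (multi-class): non-negative probabilities that sum to 1
probs = torch.softmax(logits, dim=-1)
# argmax gives the same predicted class on logits or probabilities
predicted = int(torch.argmax(logits))
# A hard label keeps only the winning index and discards relative confidence
hard_label = F.one_hot(torch.tensor(predicted), num_classes=3)

print("soft label:", [round(p, 3) for p in probs.tolist()])
print("hard label:", hard_label.tolist())

# Sigmoid (binary or multi-label): squashes each logit into (0, 1) independently
print("sigmoid:", [round(p, 3) for p in torch.sigmoid(logits).tolist()])
```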

Code example

Below is an end-to-end example using Hugging Face Transformers, intentionally simplified for clarity.

In a production implementation, you’d have a much larger dataset and would also want to perform validation checks and tune hyperparameters (alpha, temperature, learning rates).

import torch
from transformers import BertForSequenceClassification, BertTokenizer, \
    DistilBertForSequenceClassification, DistilBertConfig

# -------------------
# 1. Load the teacher model and tokenizer
# -------------------
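# NOTE: "bert-base-uncased" ships without a trained classification head, so its
# logits are essentially arbitrary here. In practice, load a teacher that has
# already been fine-tuned on your task.
teacher_model_name = "bert-base-uncased"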
teacher_tokenizer = BertTokenizer.from_pretrained(teacher_model_name)
teacher_model = BertForSequenceClassification.from_pretrained(teacher_model_name)
teacher_model.eval()  # Put the teacher in inference mode

# -------------------
# 2. Initialize the student model
#    (based on DistilBERT but could be any smaller architecture)
# -------------------
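# Use DistilBERT's default configuration (6 Transformer layers vs. 12 in BERT-base)
# rather than BERT's config, which describes a different, larger architecture.
# Its default vocabulary size (30522) matches bert-base-uncased, so the
# teacher's tokenizer can be reused for the student.
distil_config = DistilBertConfig(num_labels=2)
student_model = DistilBertForSequenceClassification(distil_config)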

# -------------------
# 3. Prepare a small dataset (illustrative purposes only)
#    In real scenarios, you'd have a much larger dataset.
# -------------------
texts = ["Hello world!", "This is a test sentence."]
labels = [0, 1]  # Class indices (what CrossEntropyLoss expects); one-hot equivalents are [1, 0] and [0, 1]
encodings = teacher_tokenizer(texts, truncation=True, padding=True, return_tensors='pt')

# -------------------
# 4. Generate soft labels from the teacher
#    (Teacher's logits serve as the basis for distillation)
# -------------------
with torch.no_grad():
    teacher_outputs = teacher_model(**encodings)
teacher_logits = teacher_outputs.logits

# -------------------
# 5. Define a custom distillation loss function
#    Combines teacher-student alignment (KLDivLoss) with standard cross-entropy.
# -------------------
def distillation_loss(student_logits, teacher_logits, true_labels, alpha=0.5, temperature=2.0):
    # KL divergence (teacher-student alignment)
    kl_div = torch.nn.KLDivLoss(reduction='batchmean')(
        torch.log_softmax(student_logits / temperature, dim=-1),
        torch.softmax(teacher_logits / temperature, dim=-1)
    )
    # Cross-entropy (aligning with hard labels)
    ce_loss = torch.nn.CrossEntropyLoss()(student_logits, true_labels)
    
    # The alpha parameter balances teacher knowledge vs. true labels
    # Temperature controls how "soft" or "sharp" the teacher's distributions are
    return alpha * kl_div * (temperature**2) + (1 - alpha) * ce_loss

# -------------------
# 6. Set up a simple training loop
# -------------------
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-5)
true_labels = torch.tensor(labels)

for epoch in range(3):
    student_model.train()
    optimizer.zero_grad()
    
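    # DistilBERT does not accept token_type_ids (which BertTokenizer returns),
    # so pass only the inputs the student expects
    outputs = student_model(input_ids=encodings["input_ids"],
                            attention_mask=encodings["attention_mask"])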
    loss = distillation_loss(outputs.logits, teacher_logits, true_labels)
    
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch} | Distillation loss: {loss.item():.4f}")

Monitoring production and handling pitfalls

Even after distillation, a model requires ongoing monitoring and maintenance. If the production data distribution shifts over time, distilled models may diverge from the teacher’s performance, a phenomenon known as model drift. Biases baked into the teacher model can also propagate to the student, sometimes amplifying inaccuracies.

Careful evaluation of error metrics, throughput, and resource utilization in real-world deployments is essential to catch these issues early.
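One way to catch divergence early, sketched below under the assumption that you periodically log both models’ logits on a sample of production inputs, is to track how often the student agrees with the teacher and how far their output distributions have moved apart. The function names, the sample logits, and the 0.9 alert threshold are all hypothetical.

```python
import torch
import torch.nn.functional as F

def agreement_rate(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> float:
    """Fraction of inputs on which student and teacher predict the same class."""
    student_pred = student_logits.argmax(dim=-1)
    teacher_pred = teacher_logits.argmax(dim=-1)
    return (student_pred == teacher_pred).float().mean().item()

def mean_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> float:
    """Average KL(teacher || student) per input; larger values mean the student's
    distributions have moved further from the teacher's."""
    return F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.log_softmax(teacher_logits, dim=-1),
        reduction="batchmean",
        log_target=True,
    ).item()

# Hypothetical logits logged for 4 production inputs (3 classes each)
teacher_logits = torch.tensor([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1],
                               [0.1, 0.1, 2.0], [2.0, 0.1, 0.1]])
student_logits = torch.tensor([[1.8, 0.2, 0.1], [0.1, 1.5, 0.3],
                               [0.2, 1.9, 0.4], [1.7, 0.1, 0.2]])

AGREEMENT_THRESHOLD = 0.9  # hypothetical alerting threshold

rate = agreement_rate(student_logits, teacher_logits)
kl = mean_kl(student_logits, teacher_logits)
print(f"Agreement: {rate:.2f} | Mean KL: {kl:.4f}")
if rate < AGREEMENT_THRESHOLD:
    print("Alert: student predictions are diverging from the teacher")
```

Agreement catches outright prediction flips, while the KL term can reveal a slow loss of calibration before the predicted classes change.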

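Distillation can occasionally fail to deliver the desired gains: if the student’s architecture is not much smaller than the teacher’s, the speed and size improvements may be marginal, and if the teacher’s outputs are too complex for a small student to mimic, accuracy can suffer.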
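Before training, it is cheap to check how much smaller a candidate student actually is. The sketch below uses toy PyTorch MLPs as hypothetical stand-ins for real teacher and student architectures and compares parameter counts, which are a rough proxy for size; actual latency should still be measured on the target hardware.

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    """Total number of parameters (weights and biases) in a model."""
    return sum(p.numel() for p in model.parameters())

# Hypothetical toy architectures standing in for a real teacher and students
teacher = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(),
                        nn.Linear(1024, 1024), nn.ReLU(),
                        nn.Linear(1024, 10))
# Same depth, slightly narrower: likely to save little
student_wide = nn.Sequential(nn.Linear(512, 896), nn.ReLU(),
                             nn.Linear(896, 896), nn.ReLU(),
                             nn.Linear(896, 10))
# Shallower and much narrower: large savings, but less capacity to match the teacher
student_small = nn.Sequential(nn.Linear(512, 256), nn.ReLU(),
                              nn.Linear(256, 10))

teacher_size = count_params(teacher)
for name, student in [("student_wide", student_wide), ("student_small", student_small)]:
    ratio = count_params(student) / teacher_size
    print(f"{name}: {count_params(student):,} params ({ratio:.0%} of teacher)")
```

Here the narrower student keeps about 80% of the teacher’s parameters, so distilling into it would buy little, while the much smaller one keeps under 10% but has far less capacity to absorb the teacher’s behavior.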

Tasks with extremely tight accuracy requirements may also be poor candidates for distillation, as any drop in performance could be unacceptable.

Model distillation preserves much of a model’s power while reducing its size

A smaller student model can approach the teacher’s performance with significantly reduced computational cost by absorbing the teacher's expertise through soft labels rather than basic one-hot labels alone.

This is an excellent trade-off for many real-world applications, including those running in constrained environments or on edge networks closer to users. Nevertheless, distillation is not a universal remedy: practitioners must consider hyperparameter tuning, possible biases, and rigorous monitoring throughout the model’s lifecycle.

With careful application, distillation can unlock powerful capabilities, making advanced AI accessible to a broader range of devices and budgets.
