“Model ensembles are a pretty much guaranteed way to gain 2% of accuracy on anything.” (Andrej Karpathy)
I absolutely agree! However, deploying an ensemble of heavyweight models is not always feasible. Sometimes, even a single model can be so large (GPT-3, for example) that deploying it in resource-constrained environments is not possible. This is why we have been going over model optimization recipes such as quantization and pruning. This report, the last in the series, discusses another compelling model optimization technique: knowledge distillation. I have structured the report into the following sections:
When working with a classification problem, it is very typical to use softmax as the last activation unit in your neural network. Why is that? Because a softmax function takes a set of logits and produces a probability distribution over the discrete classes your network is being trained on. Figure 1 presents an example of this.
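As a quick illustration (a toy example with made-up logits, not taken from the report), softmax turns an arbitrary set of logits into probabilities that sum to 1:

```python
import numpy as np

def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating.
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()

# Hypothetical logits for the classes "one" and "seven".
probs = softmax(np.array([4.0, 1.0]))
print(probs)        # approximately [0.95, 0.05]
print(probs.sum())  # 1.0
```

The larger logit dominates, but the smaller class still keeps a non-zero share of the probability mass, which is exactly the kind of "relative" information discussed below.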
> Figure 1: Predictions of a neural network on an input image. <
In Figure 1, our imaginary neural network is highly confident that the given image is a $1$. However, it also thinks there is a slight chance the image could be a $7$. That is quite reasonable, isn’t it? The given image does have subtle seven-ish characteristics. This information would not have been available if we were only dealing with hard one-hot-encoded labels like [1, 0] (where 1 and 0 are the probabilities of the image being a one and a seven, respectively).
Humans are well equipped to exploit this sort of relativeness. More examples include a cat-ish dog, a brownish red, a cat-ish tiger, and so on. These are still valid comparisons, as Hinton et al. opine in [1]:
An image of a BMW, for example, may only have a minimal chance of being mistaken for a garbage truck, but that mistake is still many times more probable than mistaking it for a carrot.
This very knowledge helps us generalize remarkably well out in the wild.
This thought process helps us dig deeper into what our models might be thinking about the input data, which should be somewhat consistent with the way we would think about it. Figure 1 establishes this again: to our eyes, that image looks like a one, but it has some traits of a seven.
So, what now? An immediate question that may come to mind: what is the best way for us to use this knowledge in neural networks? Let us find out in the next section.
The softmax information is far more useful than plain hard one-hot-encoded labels. So, at this stage, we have access to a trained network, and we are now interested in using the output probabilities it produces.
Consider teaching someone the English digits with the MNIST dataset. It is highly likely that a student would ask: doesn’t that one look like a seven? If so, that is definitely good news, because your student surely knows what a one and a seven look like. As a teacher, you have been able to transfer your knowledge of English digits to your student. It is possible to extend this idea to neural networks as well.
So, here is the deal at a high level:
This workflow briefly formulates the idea of knowledge distillation.
Why smaller? Isn’t this what we want? To deploy a lightweight model to production that is performant enough?
Disclaimer: For the sake of brevity and simplicity, I am going to demonstrate the following sections on a computer vision-based example. Note that these ideas are independent of domain.
For an image classification example, we can extend the earlier high-level idea:
> Figure 2: A highlevel overview of knowledge distillation. <
Why are we training the student model on soft labels?
Remember that our student model has a smaller capacity than the teacher model. So, if your dataset is complex enough, the smaller student model may not be well suited to capture the hidden representations required for the training objective. We train the student model on soft labels to compensate for this, since they provide more meaningful information than the one-hot-encoded labels. In a sense, we are training the student model to imitate the teacher model’s outputs while giving it a little bit of exposure to the training dataset.
Hopefully, this provided you with an intuitive understanding of knowledge distillation. In the next section, we will be taking a more detailed look at the student model's training mechanics.
To train the student model, we can still use the regular cross-entropy loss between the soft labels from the teacher and the predicted labels from the student. It is highly likely that the student model would be confident about many of the input data points and would predict probability distributions like the following:
> Figure 3: Highly confident predictions. <
The problem with these weak probabilities (marked in red in Figure 3) is that they do not capture desirable information for the student model to learn effectively. For example, it is almost impossible to transfer the knowledge that the image has seven-ish traits if the probability distribution is like [0.99, 0.01].
Hinton et al. address this problem by scaling the raw logits of the teacher model by some temperature ($\tau$) before they get passed to softmax [1] (known as extended softmax or temperature-scaled softmax). That way, the distribution gets spread more evenly across the available class labels. The same temperature is used to train the student model. I have presented this idea in Figure 4.
> Figure 4: Softened predictions. <
We can write the student model’s modified loss function in the form of this equation:
> $\mathcal{L}_{CE}^{KD} = -\sum_{i} p_{i} \log s_{i},$ <
where $p_i$ is the softened probability distribution of the teacher model and $s_i$ is expressed as $s_{i} = \frac{\exp \left(z_{i} / \tau\right)}{\sum_{j} \exp \left(z_{j} / \tau\right)}$.
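To see the softening numerically, here is a small NumPy sketch (the logit values are made up): the same pair of logits becomes much less peaked at a higher temperature.

```python
import numpy as np

def softmax_with_temperature(logits, tau=1.0):
    # Scale the raw logits by the temperature before exponentiating.
    scaled = logits / tau
    exps = np.exp(scaled - np.max(scaled))
    return exps / exps.sum()

logits = np.array([5.0, 0.5])  # hypothetical logits for "one" and "seven"
print(softmax_with_temperature(logits, tau=1.0))  # sharp: roughly [0.99, 0.01]
print(softmax_with_temperature(logits, tau=5.0))  # softer: roughly [0.71, 0.29]
```

At $\tau = 1$ this is the regular softmax; as $\tau$ grows, the secondary class recovers enough probability mass to carry useful information for the student.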
def get_kd_loss(student_logits, teacher_logits, temperature):
    # Soften the teacher's raw logits with the temperature.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    # Cross-entropy between the softened teacher distribution and the
    # temperature-scaled student logits.
    kd_loss = tf.keras.losses.categorical_crossentropy(
        teacher_probs, student_logits / temperature,
        from_logits=True)
    return kd_loss
In [1], Hinton et al. also explore the idea of using the conventional cross-entropy loss between the true target labels (typically one-hot encoded) and the student model’s predictions. This especially helps when the training dataset is small and there isn’t enough signal in the soft labels for the student model to pick up.
This approach works significantly better when combined with the extended softmax, and the overall loss function becomes a weighted average of the two:
> $\mathcal{L} = \frac{\alpha \cdot \mathcal{L}_{CE}^{KD} + \beta \cdot \mathcal{L}_{CE}}{\alpha + \beta}$ <
def get_kd_loss(student_logits, teacher_logits,
                true_labels, temperature,
                alpha, beta):
    # Distillation loss against the temperature-softened teacher outputs.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    kd_loss = tf.keras.losses.categorical_crossentropy(
        teacher_probs, student_logits / temperature,
        from_logits=True)
    # Conventional cross-entropy against the true (hard) labels.
    ce_loss = tf.keras.losses.sparse_categorical_crossentropy(
        true_labels, student_logits, from_logits=True)
    total_loss = (alpha * kd_loss) + (beta * ce_loss)
    return total_loss / (alpha + beta)
It’s recommended to weigh $\beta$ considerably smaller than $\alpha$.
Caruana et al. operate on the raw logits instead of the softmax values [2]. The workflow is as follows:
This part remains the same:
Train a teacher model that performs well on your image dataset. Here, the cross-entropy loss is calculated with respect to the true labels from your dataset.
Now, to train the student model, the training objective becomes minimizing the mean squared error between the raw logits from the teacher and the student models:
> $\mathcal{L}_{MSE}^{KD} = \sum_{i}\left(z_{i}^{\theta_{\text{student}}} - z_{i}^{(\text{teacher, fixed})}\right)^{2}$ <
mse = tf.keras.losses.MeanSquaredError()

def mse_kd_loss(teacher_logits, student_logits):
    # Match the student's raw logits to the (fixed) teacher's raw logits.
    return mse(teacher_logits, student_logits)
One potential disadvantage of using this loss function is its unconstrained nature. The raw logits can capture noise that a small model may not be able to fit properly. This is why, for this loss function to fit well in the distillation regime, the student model needs to be a bit bigger.
Tang et al. explore the idea of interpolating between the two losses, the extended softmax and the MSE loss [3]. Mathematically, it would look like the following:
> $\mathcal{L}=(1-\alpha) \cdot \mathcal{L}_{MSE}^{KD}+\alpha \cdot \mathcal{L}_{CE}^{KD}$ <
Empirically, they found that the best performance (on NLP tasks) is achieved when $\alpha$ is equal to 0.
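Here is a minimal NumPy sketch of this interpolation (the logit values and the helper name interpolated_kd_loss are mine, not from [3]):

```python
import numpy as np

def softmax_t(logits, tau=1.0):
    # Temperature-scaled softmax.
    scaled = logits / tau
    exps = np.exp(scaled - np.max(scaled))
    return exps / exps.sum()

def interpolated_kd_loss(student_logits, teacher_logits, alpha, tau=5.0):
    # MSE between the raw logits (Caruana et al. style).
    mse = np.mean((student_logits - teacher_logits) ** 2)
    # Cross-entropy between the temperature-softened distributions.
    p = softmax_t(teacher_logits, tau)
    s = softmax_t(student_logits, tau)
    ce = -np.sum(p * np.log(s))
    # alpha = 0 recovers the pure MSE objective, which worked best in [3].
    return (1 - alpha) * mse + alpha * ce

loss = interpolated_kd_loss(np.array([2.0, 0.5]),
                            np.array([3.0, -0.5]), alpha=0.0)
print(loss)  # 1.0, the mean of the squared logit differences
```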
If you’re feeling a bit overwhelmed at this point, don’t sweat it. Hopefully, things will become clearer with the code.
In this section, I will provide you with a few training recipes that you can consider while working with knowledge distillation.
This idea is explored in [3] by Tang et al. They demonstrate it on NLP datasets, but it is applicable to other domains as well. Data augmentation can help better guide the student model’s training, especially when you are dealing with less data. Since we typically keep the student model much smaller than the teacher model, the hope is that with more diverse data, the student model gets to capture the domain better.
In works like Noisy Student Training [4] and SimCLRv2 [5], the authors use additional unlabeled data when training the student model. You use your teacher model to generate the ground-truth distribution on the unlabeled dataset. This helps increase the generalizability of the model to a great extent. This approach is only feasible when unlabeled data is available in the domain of the dataset you’re dealing with. Sometimes, that may not be the case (healthcare, for example). In [4], Xie et al. explore techniques like data balancing and data filtering to mitigate the issues that may arise when incorporating unlabeled data into student model training.
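The core of this recipe can be sketched as follows; teacher_predict is a hypothetical stand-in for a forward pass of your trained teacher, and it returns random logits here just to keep the sketch runnable:

```python
import numpy as np

def teacher_predict(batch):
    # Stand-in for the trained teacher's forward pass over a batch;
    # returns random logits for 5 classes so the sketch is runnable.
    rng = np.random.default_rng(0)
    return rng.normal(size=(len(batch), 5))

def pseudo_label(unlabeled_batch, tau=5.0):
    # Generate temperature-softened soft labels on unlabeled examples.
    logits = teacher_predict(unlabeled_batch)
    scaled = logits / tau
    exps = np.exp(scaled - scaled.max(axis=1, keepdims=True))
    return exps / exps.sum(axis=1, keepdims=True)

# Eight unlabeled "examples"; each row is a soft label over 5 classes.
soft_labels = pseudo_label(list(range(8)))
```

These teacher-generated distributions then serve as the training targets for the student on the unlabeled portion of the data.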
Label smoothing is a technique used to relax the high-confidence predictions produced by models. It helps reduce overfitting, but it is not recommended when training the teacher model, since its logits are scaled by some temperature anyway. Hence, label smoothing is not typically recommended in a knowledge distillation setup. You can check out this article to learn more about label smoothing.
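For reference, label smoothing simply redistributes a small fraction of the probability mass uniformly across the classes; a minimal sketch (the epsilon of 0.1 is a made-up value):

```python
import numpy as np

def smooth_labels(one_hot, epsilon=0.1):
    # Move epsilon of the probability mass away from the hot class and
    # spread it uniformly over all classes.
    num_classes = one_hot.shape[-1]
    return one_hot * (1.0 - epsilon) + epsilon / num_classes

smoothed = smooth_labels(np.array([1.0, 0.0]), epsilon=0.1)
print(smoothed)  # [0.95, 0.05]
```

Since the temperature in extended softmax already softens the teacher's outputs, applying this on top of it is redundant, which is why the recipe advises against it.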
Hinton et al. recommend using higher temperature values to soften the distributions predicted by the teacher model, so that the soft labels contain even more information for the student model. This is especially useful when dealing with small datasets. For larger datasets, the information becomes available by means of the number of training examples. Refer to the extended softmax section again if it is not clear why a higher temperature softens the predicted distributions.
We will explore these recipes in the next section.
Let’s first review the experimental setup. I used the Flowers dataset for my experiments. Unless otherwise specified, I used the following configurations:
I used MobileNetV2 as the base model for fine-tuning, with the learning rate set to 1e-5 and Adam as the optimizer.
I set the temperature $\tau$ to 5 when using $\mathcal{L}_{CE}^{KD}$ and the weighted average of $\mathcal{L}_{CE}^{KD}$ and the traditional cross-entropy loss.
$\alpha$ = 0.9 and $\beta$ = 0.1 when using a weighted average of $\mathcal{L}_{CE}^{KD}$ and the traditional cross-entropy loss.
For the student model, I followed this shallow architecture:
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 222, 222, 64) 1792
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 55, 55, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 53, 53, 128) 73856
_________________________________________________________________
global_average_pooling2d_3 ( (None, 128) 0
_________________________________________________________________
dense_3 (Dense) (None, 512) 66048
_________________________________________________________________
dense_4 (Dense) (None, 5) 2565
=================================================================
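For the curious, the architecture above can be reconstructed in Keras along these lines (the input shape of 224×224×3, the pool size of 4, and the ReLU activations are inferred from the output shapes and parameter counts, not stated explicitly in the report):

```python
import tensorflow as tf

def get_student_model():
    # Shallow CNN matching the summary above:
    # Conv2D(64) -> MaxPool(4) -> Conv2D(128) -> GAP -> Dense(512) -> Dense(5)
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu",
                               input_shape=(224, 224, 3)),
        tf.keras.layers.MaxPooling2D((4, 4)),
        tf.keras.layers.Conv2D(128, (3, 3), activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(5),  # raw logits for the 5 flower classes
    ])
```

The final Dense layer outputs raw logits, which is what the distillation losses above expect.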
While training the student model, I used Adam as the optimizer with a learning rate of 1e-2.
While training the student model with data augmentation, I used the weighted average loss with the same default hyperparameters mentioned above.
To make the performance comparisons fair, let's also train the shallow CNN from scratch and observe its performance. Note that in this case, I used Adam as the optimizer with a learning rate of 1e-3.
Before we see the results, I want to shed some light on the training loop and how I was able to wrap it inside the classic model.fit() call. This is how the training loop looks:
def train_step(self, data):
    images, labels = data
    # Raw logits from the (frozen) trained teacher.
    teacher_logits = self.trained_teacher(images)

    with tf.GradientTape() as tape:
        student_logits = self.student(images)
        loss = get_kd_loss(student_logits, teacher_logits,
                           labels, temperature=5.,
                           alpha=0.9, beta=0.1)
    gradients = tape.gradient(loss, self.student.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

    train_loss.update_state(loss)
    train_acc.update_state(labels, tf.nn.softmax(student_logits))
    t_loss, t_acc = train_loss.result(), train_acc.result()
    train_loss.reset_states(), train_acc.reset_states()
    return {"loss": t_loss, "accuracy": t_acc}
The train_step() function should be an easy read if you are already familiar with customizing a training loop in TensorFlow 2. Notice the get_kd_loss() function; this can be any of the loss functions we have discussed so far. We are using a trained teacher model here, the model we fine-tuned earlier. With this training loop, we can create an entire model that can be trained with a .fit() call.
First, create a class extending tf.keras.Model:
class Student(tf.keras.Model):
    def __init__(self, trained_teacher, student):
        super(Student, self).__init__()
        self.trained_teacher = trained_teacher
        self.student = student
When you extend the tf.keras.Model class, you can put your custom training logic inside the train_step() function (provided by the class). So, in its entirety, the Student class would look like this:
class Student(tf.keras.Model):
    def __init__(self, trained_teacher, student):
        super(Student, self).__init__()
        self.trained_teacher = trained_teacher
        self.student = student

    def train_step(self, data):
        images, labels = data
        teacher_logits = self.trained_teacher(images)

        with tf.GradientTape() as tape:
            student_logits = self.student(images)
            loss = get_kd_loss(student_logits, teacher_logits,
                               labels, temperature=5.,
                               alpha=0.9, beta=0.1)
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

        train_loss.update_state(loss)
        train_acc.update_state(labels, tf.nn.softmax(student_logits))
        t_loss, t_acc = train_loss.result(), train_acc.result()
        train_loss.reset_states(), train_acc.reset_states()
        return {"train_loss": t_loss, "train_accuracy": t_acc}
You can even write a test_step to customize the evaluation behavior of the model. If you are interested in that, along with the train_step() utility, check out this Colab Notebook. Our model can now be trained in the following manner:
student = Student(teacher_model, get_student_model())
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
student.compile(optimizer)
student.fit(train_ds,
            validation_data=validation_ds,
            epochs=10)
One potential advantage of this method is that one can easily incorporate other capabilities like distributed training, custom callbacks, mixed precision, and so on.
Upon training our shallow student model with this loss function, we get ~74% validation accuracy. We see that the losses start to increase after epoch 8, which suggests stronger regularization might have helped. Also, note that the hyperparameter tuning process has a significant impact here. In my experiments, I did not do rigorous hyperparameter tuning, and I kept the training schedules short to allow faster experimentation.
Let's now see if incorporating the ground-truth labels in the distillation training objective helps. With $\beta$ = 0.1 and $\alpha$ = 0.1, we get around ~71% validation accuracy. The training dynamics again suggest that stronger regularization with a longer training schedule would have helped.
With the MSE loss, the validation accuracy drops sharply to ~56%. The same kind of loss behavior is present in this setting as well, suggesting the need for regularization.
Note that this loss function is completely unconstrained, and our shallow student model may not be capable of handling the noise that comes with it. Let's try a deeper student model.
As mentioned earlier, student models have a smaller capacity than the teacher model. When dealing with less data, data augmentation can be helpful for training the student model. Let's verify that.
In this experiment, let's study the effect of temperature on the student model. In this setting, I used the same shallow CNN.
Next, I wanted to study whether the choice of base model for fine-tuning had a significant effect on the student model.
Finally, if you are wondering what kind of improvement one can get out of knowledge distillation for production purposes, the table below shows that for us. Without any hyperparameter tuning, we got a decent model that is significantly more lightweight than the other models shown in the table.
The first row corresponds to the default student model trained with the weighted average loss, while the other rows correspond to EfficientNet B0 and MobileNetV2, respectively. Note that I did not include the results from training the student model with data augmentation.
This concludes the report, and also the series I have been developing on model optimization. Knowledge distillation is a very promising technique, specifically suited for deployment purposes. A very good thing about it is that it can be combined with quantization and pruning quite seamlessly to further reduce the size of your production models without having to compromise on accuracy.
We studied the idea of knowledge distillation in an image classification setting. If you are wondering whether it’s extensible to other areas like NLP or even GANs, I recommend going over the following resources:
You can also check out some of the CVPR 2020 papers on knowledge distillation here.
Another trend you will see is using a larger or equally sized student model. This has been studied very systematically in [4].
I am grateful to Aakash Kumar Nain for providing valuable feedback on the code.