7: Knowledge Distillation
Recap
Neural network efficiency techniques:
● Quantization
● Pruning
● Knowledge Distillation
● AutoML
[Figure: a large teacher network transfers knowledge extracted from the training data to a small student network.]
● Small DNNs are easier to deploy
● What is the “knowledge” being transferred?
○ Classification: e.g. softmax class probabilities
● Proposed by Caruana et al. (2006)
● Generalized by Hinton et al. (2015)
Learning Outcomes
Knowledge Distillation
[Figure: a neural network produces a 5x1 vector of logits z, which a softmax turns into a 5x1 vector of class probabilities q.]
[Figure: the student (a small neural network) takes a 224 x 224 x 3 image and produces softmax outputs; a loss function compares them to the (soft) targets produced by the trained teacher (a large neural network) on the transfer set, a subset of the training data, and the loss is backpropagated through the student.]
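As a rough sketch (not the lecture's code), the soft targets for the transfer set could be precomputed like this; teacher and transfer_loader are assumed placeholder names for the trained teacher and a DataLoader over the transfer set:

import torch
import torch.nn.functional as F

@torch.no_grad()
def compute_soft_targets(teacher, transfer_loader):
    # run the trained teacher over the transfer set and keep its softmax outputs
    teacher.eval()
    targets = []
    for images, _ in transfer_loader:
        logits = teacher(images)                  # logits z (batch x num_classes)
        targets.append(F.softmax(logits, dim=1))  # class probabilities q
    return torch.cat(targets)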
The same logits z, softened with different temperatures:
z      q [T=1]   q [T=10]
-1.1   0.007     0.171
 1.4   0.087     0.219
 3.7   0.880     0.276
 0.1   0.024     0.193
-3.0   0.001     0.141
[Plot: probability vs. class for T=1, T=5, and T=10; higher temperatures flatten the distribution.]
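For reference, the temperature softmax used here is q_i = exp(z_i / T) / Σ_j exp(z_j / T); the tiny snippet below (a sketch, not lecture code) reproduces the behaviour shown in the table:

import torch
import torch.nn.functional as F

z = torch.tensor([-1.1, 1.4, 3.7, 0.1, -3.0])   # example logits from the table
for T in (1, 5, 10):
    q = F.softmax(z / T, dim=0)                 # softened class probabilities
    print(f"T={T}:", [round(p, 3) for p in q.tolist()])
# higher T flattens the distribution: T=1 is sharply peaked at the 3.7 logit,
# while T=10 is close to uniform (compare with the q columns above)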
[Figure: the same student/teacher setup, now with temperature T=5 applied to both the teacher's and the student's softmax when producing the (soft) targets and the student's predictions.]
Which temperature should be used during inference?
Distillation Loss + Student Loss
[Figure: the student network produces two softmax outputs from the same image: a hard prediction at T=1, compared against the one-hot hard labels (e.g. 0 0 1 0 0) to give the student loss, and soft predictions compared against the soft targets produced by the teacher (a large neural network) with its softmax at T=5 to give the distillation loss.]
[Figure: the total training objective is a weighted combination of the distillation loss and the student loss.]
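Written out (a standard formulation following Hinton et al. (2015); the symbols z_s, z_t, y, α, and T are introduced here and correspond to the student logits, teacher logits, hard labels, and the ALPHA and T hyper-parameters in the distillation code below):

\mathcal{L} \;=\; \alpha \, T^{2}\, \mathrm{KL}\!\big(\mathrm{softmax}(z_t / T)\,\big\|\,\mathrm{softmax}(z_s / T)\big) \;+\; (1-\alpha)\,\mathrm{CE}\big(\mathrm{softmax}(z_s),\, y\big)

The T^2 factor keeps the gradient magnitude of the soft term comparable to that of the hard term.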
import torch
import torch.nn.functional as F
import torchvision
# forward
outputs = net(inputs)
# loss
loss = F.cross_entropy(outputs, labels)
# backward + optimize
loss.backward()
optimizer.step()
# distillation version: load student + teacher, set hyper-params
student = torchvision.models.mobilenet_v2()
teacher = torchvision.models.resnet101()  # load (already trained) teacher model
teacher.eval()
optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
T, ALPHA = 5, 0.3  # distillation hyper-params: temperature and loss-mixing weight

# forward through student and (frozen) teacher
outputs = student(inputs)
with torch.no_grad():
    teacher_outputs = teacher(inputs)
# hard-label loss + one common choice of distillation loss (KL between softened outputs)
hard_loss = F.cross_entropy(outputs, labels)
dist_loss = F.kl_div(F.log_softmax(outputs / T, dim=1),
                     F.softmax(teacher_outputs / T, dim=1),
                     reduction='batchmean') * T * T  # T^2 keeps gradient scale comparable
loss = ALPHA * dist_loss + (1. - ALPHA) * hard_loss  # combined hard + distillation loss
# backward + optimize
loss.backward()
optimizer.step()
Learning Outcomes
[Figure: a teacher network distilled into a small student network via the distillation loss.]
The teacher can be more complicated to boost accuracy:
● Ensembles (see the sketch below):
○ Different initializations
○ Different model architectures
[Figure: an ensemble of large teacher networks distilled into one small student through the distillation loss.]
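One simple way to distill from an ensemble (a sketch with my own naming, assuming a list teachers of trained models) is to average the temperature-softened teacher distributions and use the mean as the soft target:

import torch
import torch.nn.functional as F

@torch.no_grad()
def ensemble_soft_targets(teachers, images, T=5):
    # average the softened class distributions of all ensemble members
    probs = [F.softmax(t(images) / T, dim=1) for t in teachers]
    return torch.stack(probs).mean(dim=0)

# the averaged distribution then replaces the single-teacher target in the
# distillation loss from the earlier code, e.g.
# dist_loss = F.kl_div(F.log_softmax(student(images) / T, dim=1),
#                      ensemble_soft_targets(teachers, images, T),
#                      reduction='batchmean') * T * T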
● Specialists [1]:
○ Divide the classes across different models
■ The Google JFT dataset has 15000 classes
○ One generalist NN trained on all the data
○ Top-k classes from the generalist are further refined by specialists
○ How to choose specialist classes?
[Figure: a generalist network and several specialist networks (all large neural networks); the relevant specialists are selected and their outputs combined with the generalist's by minimizing KL divergence.]
● Feature-based (see the sketch below):
○ Take the outputs/weights of one or more “hint layers” and minimize e.g. an MSE loss between student and teacher
○ More advanced: minimize the difference in attention maps between student and teacher
● Relation-based:
○ Correlations between feature maps, e.g. the Gramian (Gram matrix) between two features
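A minimal sketch of the hint-layer idea (the HintLoss name, the 1x1 adapter, and the forward-hook capture are my assumptions, not the lecture's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class HintLoss(nn.Module):
    """MSE between a student hint-layer feature map and the teacher's."""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 conv adapter in case student and teacher feature widths differ
        self.adapt = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        return F.mse_loss(self.adapt(student_feat), teacher_feat.detach())

# usage sketch: student_feat / teacher_feat would be captured with forward hooks
# on the chosen hint layers and the hint term added to the total loss, e.g.
# loss = ALPHA * dist_loss + (1 - ALPHA) * hard_loss + BETA * hint(s_feat, t_feat)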