
admin

● Second guest lecture scheduled: Jonathan Frankle (MIT/MosaicML/Databricks)
● Start A2 early
  ○ Get your TinyML kits if you haven’t already
    ■ At the end of class or during office hours

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 1


ECE 5545: Machine Learning Hardware and Systems

7: Knowledge Distillation
Recap

2. Neural Network
● Quantization
● Pruning
● Knowledge Distillation
● AutoML

Methods to reduce the number of computations and/or memory footprint of DNNs

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 3


What is Knowledge Distillation?

[Diagram: Data → Teacher → Knowledge → Student]

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 4


What is Knowledge Distillation?
● Distill “knowledge” from a large neural network to a small one.
  ○ E.g. ResNet101 → MobileNet
● Larger DNNs are easier to train
● Small DNNs are easier to deploy
● Knowledge?
  ○ Classification: softmax class probabilities
● Proposed by Caruana et al. (2006)
● Generalized by Hinton et al. (2015)

[Diagram: Training Data → Teacher (large neural network) → Knowledge (e.g. softmax class probabilities) → Student (small neural network)]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 5
Learning Outcomes

Understand what knowledge distillation is and why it is needed.

Write code to perform knowledge distillation

Understand advanced knowledge distillation techniques and open problems

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 6


Training a neural network
[Diagram: Image (224 x 224 x 3) → Neural Network (input conv layer, hidden conv + pool layers, FC output layer) → Logits (z, 5x1) → Softmax → Probabilities (q, 5x1) → Loss Function, compared against the (hard) targets]

        Logits (z)   Probabilities (q)   (Hard) Targets
car        0.3            0.11                 0
dog        0.9            0.21                 0
cat        1.9            0.54                 1
bus        0.1            0.09                 0
boat       0.2            0.10                 0

● Key observation: Hard targets have no information about wrong classes
● Soft targets have information about wrong classes
  ○ Where do we get them from?
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 7


Training a neural network
[Same diagram as before, now with (Soft) Targets alongside the (Hard) Targets feeding the Loss Function]

        Logits (z)   Probabilities (q)   (Soft) Targets   (Hard) Targets
car        0.3            0.11               0.03               0
dog        0.9            0.21               0.25               0
cat        1.9            0.54               0.70               1
bus        0.1            0.09               0.02               0
boat       0.2            0.10               0.01               0

● Key observation: Hard targets have no information about wrong classes
● Soft targets have information about wrong classes
  ○ Where do we get them from?
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 8
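Since the later slides use PyTorch, here is a minimal sketch of the contrast above, using the logits and targets from the table (the soft-target vector is assumed to come from a teacher; this is an illustration, not the lecture's code):

import torch
import torch.nn.functional as F

# Student logits for one image over [car, dog, cat, bus, boat] (values from the slide)
z = torch.tensor([[0.3, 0.9, 1.9, 0.1, 0.2]])

# Hard target: "cat" only -- carries no information about the wrong classes
hard_target = torch.tensor([2])                        # class index of "cat"
hard_loss = F.cross_entropy(z, hard_target)

# Soft target (e.g. from a teacher): mostly "cat", but also says "dog" is far more likely than "boat"
soft_target = torch.tensor([[0.03, 0.25, 0.70, 0.02, 0.01]])
soft_loss = F.cross_entropy(z, soft_target)            # probability targets require PyTorch >= 1.10

print(hard_loss.item(), soft_loss.item())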


QUESTION

Where do we get soft targets?

1. From labeled data by clustering similar classes
2. Using an equation
3. From a trained neural network   ← the “Teacher” network in distillation
4. Using expert annotation
Knowledge Distillation
[Diagram: a transfer set (a subset of the training data) is fed to both networks.
 Student (small neural network): Image (224 x 224 x 3) → logits (z, 5x1): 0.3, 0.9, 1.9, 0.1, 0.2 → Softmax → probabilities (q, 5x1): 0.11, 0.21, 0.54, 0.09, 0.10
 Teacher (large neural network, already trained): logits 0.1, 0.7, 2.9, 0.2, 0.1 → Softmax → (soft) targets: 0.03, 0.25, 0.70, 0.02, 0.01
 The loss function compares the student's probabilities with the teacher's soft targets, and the gradient is backpropagated through the student.]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 11


Softmax Temperature
Recall: Softmax function:

    q_i = exp(z_i) / Σ_j exp(z_j)

Softmax with temperature T:

    q_i = exp(z_i / T) / Σ_j exp(z_j / T)

      z      q [T=1]   q [T=10]
    -1.1      0.007     0.171
     1.4      0.087     0.219
     3.7      0.880     0.276
     0.1      0.024     0.193
    -3.0      0.001     0.141

[Plot: per-class probability for T=1, T=5, T=10 — higher temperature flattens the distribution across classes]

Additional hyper-parameter: T → exposes more “dark knowledge”

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 12
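A minimal PyTorch sketch that reproduces the table above from the slide's logits (illustrative only):

import torch
import torch.nn.functional as F

# Logits from the slide's example
z = torch.tensor([-1.1, 1.4, 3.7, 0.1, -3.0])

for T in (1, 5, 10):
    q = F.softmax(z / T, dim=0)        # softmax with temperature T
    print(f"T={T:2d}:", [round(v, 3) for v in q.tolist()])

# T= 1: [0.007, 0.087, 0.880, 0.024, 0.001]  -- sharp, close to a hard label
# T=10: [0.171, 0.219, 0.276, 0.193, 0.141]  -- soft, exposes more "dark knowledge"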


Knowledge Distillation

Callout questions: Can we use logits? Which loss function? Which temperature to use during inference?

[Diagram: same setup as before, but both softmaxes now use temperature T=5.
 Student (small neural network): Image (224 x 224 x 3) → logits 0.3, 0.9, 1.9, 0.1, 0.2 → Softmax (T=5) → 0.11, 0.21, 0.54, 0.09, 0.10
 Teacher (large neural network): logits 0.1, 0.7, 2.9, 0.2, 0.1 → Softmax (T=5) → (soft) targets 0.03, 0.25, 0.70, 0.02, 0.01
 → Loss Function]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 13


QUESTION

What temperature to use during inference?

1. T=5 (same as training)
2. T=1 (standard softmax)
3. T=0
4. T=10

No reason to change T (temperature scaling does not change the arg max); the standard T=1 softmax is sometimes used for setting a confidence threshold.
Distillation Loss + Student Loss

[Diagram: the student (small neural network) processes the image (224 x 224 x 3) and its logits feed two softmax branches:
 ● Softmax with T=1 → hard prediction 0.02, 0.11, 0.84, 0.01, 0.02, compared against the hard labels 0 0 1 0 0 → student loss
 ● Softmax with T=5 → soft predictions 0.11, 0.21, 0.54, 0.09, 0.10, compared against the teacher's (soft) targets 0.03, 0.25, 0.70, 0.02, 0.01 (teacher: large neural network, softmax with T=5) → distillation loss]

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 16
Distillation Loss + Student Loss

    L_total = α · L_distillation + (1 − α) · L_student

where L_distillation compares the student's and teacher's temperature-softened (T > 1) softmax outputs (e.g. with a KL-divergence loss), and L_student is the standard cross-entropy between the student's T=1 softmax and the ground-truth hard labels.

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 17
import torch
import torch.nn.functional as F
import torchvision

# load model + init optimizer
model = torchvision.models.mobilenet_v2()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# NUM_EPOCHS and trainloader are assumed to be defined elsewhere
for epoch in range(NUM_EPOCHS):
    for inputs, labels in trainloader:
        # forward
        outputs = model(inputs)

        # loss
        loss = F.cross_entropy(outputs, labels)

        # backward + optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# load student model + init optimizer
student = torchvision.models.mobilenet_v2()
optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

# load teacher model (in practice, a trained teacher checkpoint) + set hyper-params
teacher = torchvision.models.resnet101()
teacher.eval()                             # teacher is frozen during offline distillation
T, ALPHA = 5, 0.3                          # distillation hyper-params

for epoch in range(NUM_EPOCHS):
    for inputs, labels in trainloader:
        # forward
        outputs_s = student(inputs)
        with torch.no_grad():              # no gradients through the teacher
            outputs_t = teacher(inputs)    # teacher forward pass

        # hard loss with ground-truth labels
        hard_loss = F.cross_entropy(outputs_s, labels)

        # distillation loss: KL divergence between temperature-softened distributions
        # (F.kl_div expects log-probabilities as input and probabilities as target)
        log_p = F.log_softmax(outputs_s / T, dim=1)   # student
        q = F.softmax(outputs_t / T, dim=1)           # teacher soft targets
        dist_loss = F.kl_div(log_p, q, reduction="batchmean") * (T * T)  # T^2 scaling per Hinton et al. [1]

        # combined hard + distillation loss
        loss = ALPHA * dist_loss + (1. - ALPHA) * hard_loss

        # backward + optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Learning Outcomes

Understand what knowledge distillation is and why it is needed.

Write code to perform knowledge distillation

Understand advanced knowledge distillation techniques and open problems

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 20


Ensembles and Specialists
● Teacher architecture can be more complicated – boost accuracy

[Diagram: Teacher (large neural network) → distillation loss → Student (small neural network)]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 21


Ensembles and Specialists
● Teacher architecture can be more complicated – boost accuracy
● Ensembles (averaging sketch below):
  ○ Different initializations
  ○ Different model architectures

[Diagram: several Teachers (large neural networks) → Average → distillation loss → Student (small neural network)]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 22
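A minimal sketch of how the averaging in the diagram could look in PyTorch; the ensemble list, temperature, and input batch below are illustrative assumptions, not the lecture's code:

import torch
import torch.nn.functional as F
import torchvision

T = 5
teachers = [torchvision.models.resnet101(), torchvision.models.resnet152()]   # placeholder ensemble
inputs = torch.randn(4, 3, 224, 224)                                          # placeholder batch

with torch.no_grad():
    # soft targets = average of the temperature-softened probabilities of all ensemble members
    soft_targets = torch.stack(
        [F.softmax(t(inputs) / T, dim=1) for t in teachers]
    ).mean(dim=0)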


Ensembles and Specialists
● Teacher architecture can be more complicated – boost accuracy
● Ensembles:
  ○ Different initializations
  ○ Different model architectures
● Specialists [1]:
  ○ Divide classes across different models
    ■ Google JFT dataset has 15000 classes
  ○ One generalist NN trained on all data
  ○ Top-k classes from the generalist are further refined by specialists
  ○ How to choose specialist classes?

[Diagram: a Generalist (large neural network) plus Specialists 1–3 (large neural networks); select the relevant specialists and minimize KL divergence]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 23


Distillation Types
● Offline: Pretrained teacher used to add distillation loss during student training
● Online: Both teacher and student are trained simultaneously
  ○ Collaborative/mutual learning

[Diagram: offline — pretrained Teacher → Student; online — untrained Teacher ↔ Student]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 24


Distillation Types
● Offline: Pretrained teacher used to add distillation loss during student training
● Online: Both teacher and student are trained simultaneously
  ○ Collaborative/mutual learning (sketch below)
● Self distillation:
  ○ E.g. Progressive hierarchical inference

[Diagram: a single Teacher/Student network with Early Exit 1, Early Exit 2, and a Final Exit]
© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 25
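A minimal sketch of one training step of online (collaborative/mutual) learning, where two untrained networks teach each other; the model choices, class count, and batch below are illustrative assumptions, not the lecture's code:

import torch
import torch.nn.functional as F
import torchvision

net_a = torchvision.models.mobilenet_v2(num_classes=10)    # two peers trained together
net_b = torchvision.models.resnet18(num_classes=10)
opt = torch.optim.SGD(list(net_a.parameters()) + list(net_b.parameters()), lr=0.01)

inputs = torch.randn(4, 3, 224, 224)                       # placeholder batch
labels = torch.randint(0, 10, (4,))

logits_a, logits_b = net_a(inputs), net_b(inputs)

# each network gets a hard-label loss plus a KL term toward its (detached) peer
loss_a = F.cross_entropy(logits_a, labels) + F.kl_div(
    F.log_softmax(logits_a, dim=1), F.softmax(logits_b.detach(), dim=1), reduction="batchmean")
loss_b = F.cross_entropy(logits_b, labels) + F.kl_div(
    F.log_softmax(logits_b, dim=1), F.softmax(logits_a.detach(), dim=1), reduction="batchmean")

opt.zero_grad()
(loss_a + loss_b).backward()
opt.step()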


Knowledge Types
● Response-based:
  ○ Output probabilities as soft targets (as we have already seen)

● Feature-based:
  ○ Match the outputs/weights of one or more “hint layers”, minimizing e.g. an MSE loss (sketch below)
  ○ More advanced: minimize the difference in attention maps between student and teacher

● Relation-based:
  ○ Correlations between feature maps: e.g. the Gramian between two features

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 26
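A minimal sketch of a feature-based “hint” loss; the feature shapes and the 1x1-conv regressor that matches the student's channel count to the teacher's are illustrative assumptions, not the lecture's code:

import torch
import torch.nn.functional as F

# Feature maps taken from a chosen "hint" layer of the student and teacher
f_s = torch.randn(8, 64, 28, 28)      # student hint features (B, C_s, H, W)
f_t = torch.randn(8, 256, 28, 28)     # teacher hint features (B, C_t, H, W)

# Small regressor maps the student's channels onto the teacher's channel count
regressor = torch.nn.Conv2d(in_channels=64, out_channels=256, kernel_size=1)

# Feature-based distillation term: make the projected student features match the teacher's
hint_loss = F.mse_loss(regressor(f_s), f_t)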


Distillation Algorithms
● Survey paper [2] has a thorough overview of different distillation variations
● Adversarial: Teacher also acts as discriminator in a GAN to supplement training data to “teach” the true data distribution
● Cross-modal: Teacher trained on RGB distills information to a student learning on heat maps. Unlabeled image pairs needed.
● Quantized distillation: Use a full-precision network to transfer knowledge to a quantized network.
● … more in the paper!

[Image from [2]]

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 27


Further Reading
[1] G. Hinton et al., Distilling the Knowledge in a Neural Network, Neural Information Processing Systems (NeurIPS), 2015.
[2] J. Gou et al., Knowledge Distillation: A Survey, arXiv preprint, 2021.

© 2022 Mohamed S. Abdelfattah ECE 5545: ML HW & SYS Lecture 7 - 28


