知识蒸馏
知识蒸馏(Knowledge Distillation,KD)是想将复杂模型(teacher network)中的暗知识(dark knowledge)迁移到简单模型(student network)中。一般来说,老师网络具有强大的能力和表现,而学生网络则更为紧凑。通过知识蒸馏,希望学生网络能尽可能逼近亦或是超过老师网络,从而用复杂度更小的模型来获得类似的预测效果。Hinton在Distilling the Knowledge in a Neural Network一文中首次提出了知识蒸馏的概念,通过引入老师网络的软目标(soft targets)以诱导学生网络的训练。
在具体了解知识蒸馏的具体流程前,我们首先回顾一下四个常见的损失函数:Softmax、log_softmax、NLLLoss和CrossEntropy。
- Softmax:Softmax广泛的应用于分类问题中,它输入一个实数向量并返回一个表示类别可能性的概率分布,其中每个元素都是非负的,且所有元素总和为1。
Softmax
(
x
)
=
exp
(
x
i
)
∑
j
exp
(
x
j
)
\text{Softmax}(x) = \frac{\exp(x_{i})}{\sum_{j}\exp(x_{j})}
- log_softmax:即对softmax处理后的结果做一次对数运算
- NNLLoss(negtive log likelihood losss):若
x
i
=
[
q
1
,
q
2
,
.
.
.
,
q
N
]
x_{i}=[q_{1},q_{2},…,q_{N}]
为网络的第
i
i
个输出,
y
i
y_{i}
为真实标签,那么有:
f
(
x
i
,
y
i
)
=
−
q
y
i
f(x_{i},y_{i})= -q_{y_{i}}
- CrossEntropy:对于
N
N
分类问题一个特定的样本,已知其真实标签,CrossEntropy的计算公式为:
c
r
o
s
s
_
e
n
t
r
o
p
y
=
−
∑
k
=
1
N
(
p
k
log
q
k
)
cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} \log q_{k}\right)
其中
p
p
表示真实值,在这个公式中是one-hot形式;q是经过softmax计算后的结果,
q
k
q_k
为神经网络认为该样本为第
k
k
类的概率。
若该样本的真实标签为
y
y
,则交叉熵的公式可变形为:
c
r
o
s
s
_
e
n
t
r
o
p
y
=
−
∑
k
=
1
N
(
p
k
log
q
k
)
=
−
l
o
g
q
y
cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} \log q_{k}\right)=-log \, q_{y}
Softmax函数在多分类问题中预测样本的类别时通常会有分布尖锐的问题,即输出的概率分布只在某一类别上分配很大的概率值,这就可能导致模型把注意力集中于较大的概率上,而忽略了值较小的概率,从而使得模型的泛化性能下降。而为了使softmax输出的概率分布更加的平缓,改造后的softmax函数中加入了温度参数T,即
q
i
=
exp
(
z
i
/
T
)
∑
j
exp
(
z
j
/
T
)
q_{i}=\frac{\exp(z_{i}/T)}{\sum_{j} \exp(z_{j} /T)}
其中
T
T
的值越大,softmax输出的概率分布就越平缓,它所包含的分类信息就越多。
下面通过一个简单的例子来说一下知识蒸馏是如何进行的。如下图所示,此时我们需要对图片进行分类,即判断它是Dog、Cat和Car中的哪一类,对应的类标签采用one-got向量表示。模型此时存在老师网络和学习网络两个模型,其中学生网络更为简单,因此如果直接使用学生网络进行分类时,它会给出判断图片具体属于哪一类,是狗是猫还是汽车。学生网络若只是单纯的进行训练,模型希望预测的类标签尽可能和真实标签一致,通常称为硬目标(hard target),对应的损失函数记为
L
o
s
s
h
a
r
d
=
(
p
,
q
)
Loss_{hard} = (p,q)
。但由于学生网络本身较为简单,因此分类的效果通常来说并不好。因此,此时就需要老师网络的输出信息来提供指导。
老师网络同样采用相同的数据进行训练,但它的softmax输出记为
q
′
q’
,
q
′
q’
的概率分布并不平缓。因此,
q
′
q’
需要经过一个“蒸馏”的过程使得它的分布更加的平缓,这样的方式得到的结果称为软目标(soft target)。因此,如果有老师网络的帮助,学生网络的输出分布也应和
q
′
′
q”
接近,即希望它也可以学习到老师的知识,那么对应的损失函数记为
L
o
s
s
s
o
f
t
=
(
q
′
,
q
′
′
)
Loss_{soft} = (q’,q”)
。
所以实际上学生网络对应的损失包含两部分:
L
O
S
S
=
L
o
s
s
h
a
r
d
+
L
o
s
s
+
s
o
f
t
LOSS = Loss_{hard} + Loss+{soft}
以上内容学习自B站不相识不打同学的《神经网络知识蒸馏》,本人做的只是总结和转述,十分感谢,有兴趣的可以自行观看~
知识蒸馏广泛的应用了深度学习的各个领域,这里只是做简单的介绍,有兴趣的可以阅读更多相关的文章。
Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
Training Shallow and Thin Networks for Acceleration via Knowledge Distillation with Conditional Adversarial Networks
Deep Mutual Learning
Born Again Neural Networks
知识蒸馏是什么?一份入门随笔
BERT等预训练模型两个重要的发展方向便是两个极端:不断地加数据量、算力和模型的容量
但是这样的方式只能是大公司的游戏,而且就算有预训练的模型放出来,fine-tune后推断所需的时间的也是很长的。另一个方向便是在不损失模型性能的同时尽可能的减少模型训练所需的参数量,这时模型压缩就发挥了很大的作用,例如量化、权重剪枝和知识蒸馏等。而知识蒸馏作为模型压缩中一种重要的手段,研究人员便尝试将BERT和知识蒸馏进行结合,最后得到的模型容量大大的减少了,同时模型在下游任务的效果又不受太多的影响。
DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
DistilBERT是huggingface公司发表于NIPS 2019 上的成果,它就是使用知识蒸馏技术来训练一个小型化的BERT,模型的大小减少了40%,虽然性能降低了约3%,但是推断的速度提升了60%。
它具体的做法如下:
- 利用Google原始的BERT作为老师网络
- 学生网络是在BERT-base的基础上只使用一半的层数
- 同时利用老师网络的输出做软目标和老师网络本身 hidden layer 的参数来同时提供指导信息训练学生网络
另外蒸馏过程涉及的损失项这里是三个,如下所示:
L
c
f
L_{cf}
:传统知识蒸馏中老师网络和学生网络的输出经过softmax后分布的交叉熵
L
m
l
m
L_{mlm}
:学生网络本身的softmax输出和真实标签的交叉熵
L
c
o
s
L_{cos}
:这个是新加的,计算的是老师网络和学生网络每个 hidden layer 输出之间的余弦相似度
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
DistilBERT是将小型的多层Transformer来做学生网络,这样蒸馏过程所需的算力仍然是不晓得,因此,本文学生网络选择了更为简单的BiLSTM。其中BiLSTM针对于单句分类和句子对分类分别使用了两种类别的模型。
针对于单句分类问题,作者是将句子所有的词作为BiLSTM的输入,然后将前向和后向的状态进行拼接后通过全连接层进行分类。
针对于句子对的分类,首先分别得到两个句子的表示向量,然后做拼接后再进行分类,拼接的公式为:
f
(
h
1
,
h
2
)
=
[
h
1
,
h
2
,
h
1
⊙
h
2
,
∣
h
1
−
h
2
∣
]
f\left(h_{1}, h_{2}\right)=\left[h_{1}, h_{2}, h_{1} \odot h_{2},\left|h_{1}-h_{2}\right|\right]
。
除了两种类型的学生网络外,文中还提出了三种数据增广的方法:
- MASK:使用
[
M
A
S
K
]
[MASK]
随机替换某个词 - POS-guided word replacement:使用命名实体标签一样的词进行替换
- n-gram:从
{
1
,
2
,
.
.
.
,
5
}
\{1,2,…,5\}
随机选择数做为 n 进行替换 n-gram 的替换
实验效果如下所示:
还有比较有代表性的就是Google的AlBERT和华为的TinyBERT,有兴趣的可以浏览以下的博文。
NLP中的预训练语言模型(四)—— 小型化bert(DistillBert, ALBERT, TINYBERT)
加速BERT模型有多少种方法?从架构优化、模型压缩到模型蒸馏,最新进展详解!
BERT 瘦身之路:Distillation,Quantization,Pruning