A PyTorch Implementation and Analysis of Knowledge Distillation


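Knowledge distillation was introduced by Hinton et al. in the 2015 paper "Distilling the Knowledge in a Neural Network" (https://arxiv.org/abs/1503.02531). The main idea is to use a teacher network to guide a student network and so improve the student's performance. The teacher has a complex structure and a high computational cost but performs well. The student has a simpler structure and a lower computational cost.

The paper puts forward the notion of "dark knowledge". Suppose a well-performing network is classifying a picture of a cat. After the softmax, the entry of the output vector corresponding to "cat" usually receives a very high value, say 0.9. The values for the other classes have traditionally received little attention. Hinton argues that these other values carry extra information. While "cat" gets 0.9, "lion" might get 0.09 because the two look alike, whereas "car" might get only 0.0001. As I understand it, this similarity between classes is the essence of dark knowledge. To amplify the information that dark knowledge carries, Hinton et al. add a temperature parameter T to the standard softmax, which becomes: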

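$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Here z_i is the logit (the input to the softmax) for class i, and q_i is the resulting probability. T = 1 gives the ordinary softmax. A larger T produces a softer, flatter distribution, in which the small probabilities (and hence the dark knowledge) become more visible.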
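
The tempered softmax can be sketched in a few lines of plain Python. This is a minimal sketch, not the code from the repository. The logits are hypothetical values, chosen so that T = 1 roughly reproduces the cat / lion / car example above. Printing the result for T = 1, 2 and 10 shows the "car" probability growing as T increases.

```python
import math

def softmax_t(logits, T=1.0):
    """Softmax with temperature T (T=1 is the ordinary softmax)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits: at T=1 they give roughly cat 0.9, lion 0.09, car 0.0001.
logits = [math.log(0.9), math.log(0.09), math.log(0.0001)]
for T in (1, 2, 10):
    probs = softmax_t(logits, T)
    print(f"T={T:>2}: " + ", ".join(f"{p:.4f}" for p in probs))
```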

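The knowledge distillation procedure is as follows:

1. Train a teacher network in the usual way.

2. Build the student network. Its output goes through the ordinary softmax and is fitted to the one-hot training labels. The distance between the two is called loss1.

3. Add the temperature parameter to the softmax of the trained teacher. Its output becomes the target for the student's softmax at the same temperature, and the distance between the two is called loss2.

4. Introduce a weight alpha and train the student with loss1 × (1 - alpha) + loss2 × alpha as the training loss.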
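
The combined loss from steps 2 to 4 can be sketched for a single example in plain Python. This is a minimal sketch under assumptions. The post only says "distance", so here loss1 is assumed to be the cross-entropy against the one-hot label. loss2 is assumed to be the KL divergence between the teacher's and the student's tempered softmax outputs.

Hinton et al. also point out that the soft-target gradients scale as 1/T², and recommend multiplying the soft term by T². That is exposed as the hypothetical `scale_t2` flag. It is off by default, so the function matches the formula in step 4. In PyTorch the same loss would typically be assembled from `F.cross_entropy` and `F.kl_div` applied to `log_softmax` outputs.

```python
import math

def log_softmax_t(logits, T=1.0):
    """Log of the tempered softmax, computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_sum for s in scaled]

def distillation_loss(student_logits, teacher_logits, label, T=2.0, alpha=0.95,
                      scale_t2=False):
    """(1 - alpha) * loss1 + alpha * loss2 for one example.

    loss1: cross-entropy between the student's ordinary softmax (T=1) and the
           one-hot label (assumed choice of "distance").
    loss2: KL(teacher_T || student_T) between the tempered softmax outputs
           (assumed choice of "distance").
    """
    loss1 = -log_softmax_t(student_logits, 1.0)[label]
    log_p_teacher = log_softmax_t(teacher_logits, T)
    log_p_student = log_softmax_t(student_logits, T)
    loss2 = sum(math.exp(lt) * (lt - ls)
                for lt, ls in zip(log_p_teacher, log_p_student))
    if scale_t2:  # optional T^2 scaling recommended by Hinton et al.
        loss2 *= T * T
    return (1 - alpha) * loss1 + alpha * loss2

teacher = [4.0, 1.5, -2.0]  # made-up logits for three classes
student = [2.0, 1.0, 0.0]
print(distillation_loss(student, teacher, label=0, T=2.0, alpha=0.95))
```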

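The key point is that, once the dark knowledge has been amplified, the student's dark knowledge is trained to fit the teacher's. The teacher also carries a certain bias: after training, its accuracy on the training set is higher than on the test set. loss1 is added to counteract this tendency. In practice, one could start with alpha close to 1, say 0.95, so that the student quickly fits the teacher. Alpha could then be lowered gradually to make sure the student still classifies correctly. This is only plausible in theory, though; I have not actually tried it.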
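
The alpha schedule suggested above was not tested in the post. Purely as an illustration, a linear decay could look like the sketch below. The end value 0.5 and the linear shape are hypothetical choices, not taken from the post.

```python
def alpha_schedule(epoch, num_epochs, start=0.95, end=0.5):
    """Hypothetical linear decay of alpha from `start` to `end`.

    Early epochs weight loss2 (fitting the teacher) heavily; later epochs
    shift weight toward loss1 (the hard labels).
    """
    if num_epochs <= 1:
        return end
    frac = min(max(epoch / (num_epochs - 1), 0.0), 1.0)
    return start + (end - start) * frac

for epoch in (0, 25, 50, 75, 99):
    print(epoch, round(alpha_schedule(epoch, 100), 3))
```

A step schedule, like the learning-rate schedule used for the teacher, would be an equally reasonable way to try the idea.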

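Let's get going. First, the teacher network. I chose ResNet18 and halved the number of convolution kernels in every layer, because my computer trains slowly (it's a weak machine...). The training set is CIFAR-10. During training, images are padded and then randomly cropped, and randomly flipped horizontally, to add noise. The optimizer is SGD with momentum (lr=0.1, momentum=0.9, weight_decay=5e-4), run for 200 epochs. The learning rate is divided by 10 at epochs 100 and 150. The full code is in the GitHub repository linked at the end of this post. After training, the network's results on the test set are: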
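
The step learning-rate schedule and the halved widths can be written out as a short plain-Python sketch. It assumes epochs are counted from 0 and that each drop takes effect from epoch 100 (and 150) onward. In PyTorch the same behaviour is usually obtained with `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)`. The widths assume the standard ResNet18 stage widths (64, 128, 256, 512), halved. That is my reading of "halved the number of kernels"; the repository has the actual network.

```python
BASE_LR = 0.1
MILESTONES = (100, 150)  # epochs at which the LR is divided by 10
GAMMA = 0.1

def lr_at(epoch):
    """Learning rate for a given 0-based epoch under the step schedule."""
    drops = sum(1 for m in MILESTONES if epoch >= m)
    return BASE_LR * GAMMA ** drops

# Standard ResNet18 stage widths, halved (assumed reading of the post).
STANDARD_WIDTHS = (64, 128, 256, 512)
half_widths = tuple(w // 2 for w in STANDARD_WIDTHS)

print([lr_at(e) for e in (0, 99, 100, 149, 150, 199)])
print(half_widths)
```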

Accuracy of the network on the 10000 test images: 93.970000 %
Accuracy of plane : 97.727273 %
Accuracy of   car : 100.000000 %
Accuracy of  bird : 84.210526 %
Accuracy of   cat : 86.046512 %
Accuracy of  deer : 93.877551 %
Accuracy of   dog : 96.875000 %
Accuracy of  frog : 98.113208 %
Accuracy of horse : 93.750000 %
Accuracy of  ship : 95.833333 %
Accuracy of truck : 100.000000 %
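
A note on reading these numbers. The CIFAR-10 test set has 1000 images per class, so a per-class accuracy computed on the whole set would be a multiple of 0.1%. The per-class values above are not. Recovering small fractions from them with Python's `fractions` module gives, for example, 43/44 for planes. This suggests the per-class figures were computed on only a few hundred test images. That is my inference from the printed values, not something the post states. If it holds, the per-class numbers throughout this post are much noisier than the overall accuracy on 10000 images.

```python
from fractions import Fraction

def as_fraction(percent, max_den=100):
    """Closest fraction (denominator <= max_den) to a printed percentage."""
    return Fraction(percent / 100).limit_denominator(max_den)

def is_multiple_of_tenth_percent(percent):
    """True if the value could be k/1000, i.e. computed on 1000 images."""
    return abs(percent * 10 - round(percent * 10)) < 1e-4

# Some of the teacher's per-class accuracies reported above.
for name, pct in [("plane", 97.727273), ("cat", 86.046512),
                  ("deer", 93.877551), ("frog", 98.113208)]:
    print(name, pct, as_fraction(pct), is_multiple_of_tenth_percent(pct))
```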

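This result is not especially good, of course. So the student should be a network that performs noticeably worse, so that the teacher's value can show (ha). Here I simply build a three-layer convolutional neural network as the student; see the GitHub repository for its exact structure. I used CIFAR-10 with the same image transforms and trained the student with Adam (lr=0.001) for 100 epochs. I repeated this four times under exactly the same conditions. The test results are below. Averaged over the four runs, accuracy is about 84% (84.3%, ranging from 83.87% to 84.76%).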

Run 1:
Accuracy of the network on the 10000 test images: 84.350000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 78.947368 %
Accuracy of   cat : 72.093023 %
Accuracy of  deer : 83.673469 %
Accuracy of   dog : 81.250000 %
Accuracy of  frog : 94.339623 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 83.333333 %
Accuracy of truck : 93.103448 %
Run 2:
Accuracy of the network on the 10000 test images: 83.870000 %
Accuracy of plane : 97.727273 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 63.157895 %
Accuracy of   cat : 76.744186 %
Accuracy of  deer : 91.836735 %
Accuracy of   dog : 81.250000 %
Accuracy of  frog : 84.905660 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 85.416667 %
Accuracy of truck : 96.551724 %
Run 3:
Accuracy of the network on the 10000 test images: 84.760000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 96.875000 %
Accuracy of  bird : 68.421053 %
Accuracy of   cat : 72.093023 %
Accuracy of  deer : 83.673469 %
Accuracy of   dog : 84.375000 %
Accuracy of  frog : 90.566038 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 86.206897 %
Run 4:
Accuracy of the network on the 10000 test images: 84.240000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 81.578947 %
Accuracy of   cat : 74.418605 %
Accuracy of  deer : 77.551020 %
Accuracy of   dog : 81.250000 %
Accuracy of  frog : 83.018868 %
Accuracy of horse : 93.750000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 82.758621 %

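Next, I had seen people online argue that the teacher itself was trained on noisy (transformed) data. The dark knowledge in its outputs may therefore, in theory, contain information about that noise. So I first trained the student without transforming the dataset. I used alpha = 0.95 and trained twice each with T = 2 and T = 10; the results are below. They are worse than those of the original method (76.55% to 79.11%, versus about 84%). A possible reason is that the student simply overfit the teacher's outputs, which led to lower test accuracy.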

T=2, run 1 (no augmentation):
Accuracy of the network on the 10000 test images: 79.110000 %
Accuracy of plane : 90.909091 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 68.421053 %
Accuracy of   cat : 67.441860 %
Accuracy of  deer : 69.387755 %
Accuracy of   dog : 68.750000 %
Accuracy of  frog : 81.132075 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 85.416667 %
Accuracy of truck : 86.206897 %
T=2, run 2 (no augmentation):
Accuracy of the network on the 10000 test images: 76.720000 %
Accuracy of plane : 90.909091 %
Accuracy of   car : 96.875000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 62.790698 %
Accuracy of  deer : 73.469388 %
Accuracy of   dog : 59.375000 %
Accuracy of  frog : 77.358491 %
Accuracy of horse : 81.250000 %
Accuracy of  ship : 83.333333 %
Accuracy of truck : 79.310345 %
T=10, run 1 (no augmentation):
Accuracy of the network on the 10000 test images: 78.600000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 63.157895 %
Accuracy of   cat : 62.790698 %
Accuracy of  deer : 75.510204 %
Accuracy of   dog : 62.500000 %
Accuracy of  frog : 83.018868 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 86.206897 %
T=10, run 2 (no augmentation):
Accuracy of the network on the 10000 test images: 76.550000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 73.684211 %
Accuracy of   cat : 67.441860 %
Accuracy of  deer : 75.510204 %
Accuracy of   dog : 62.500000 %
Accuracy of  frog : 86.792453 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 75.862069 %

Finally, I trained the student with the same image transforms applied to add noise. The results are:
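
For reference, the padding, random crop and random horizontal flip can be sketched on a single-channel image stored as a list of rows. The padding width of 4 pixels is an assumption: it is a common choice for 32×32 CIFAR-10 images, but the post does not state it. In torchvision this corresponds to `transforms.RandomCrop(32, padding=4)` followed by `transforms.RandomHorizontalFlip()`.

```python
import random

def pad_crop_flip(img, pad=4, rng=random):
    """Zero-pad, randomly crop back to the original size, randomly flip.

    `img` is a list of rows (one channel). `pad=4` is an assumed value.
    """
    h, w = len(img), len(img[0])
    padded = ([[0] * (w + 2 * pad) for _ in range(pad)]
              + [[0] * pad + list(row) + [0] * pad for row in img]
              + [[0] * (w + 2 * pad) for _ in range(pad)])
    top = rng.randint(0, 2 * pad)   # randint is inclusive on both ends
    left = rng.randint(0, 2 * pad)
    out = [row[left:left + w] for row in padded[top:top + h]]
    if rng.random() < 0.5:          # horizontal flip with probability 0.5
        out = [row[::-1] for row in out]
    return out

image = [[r * 8 + c + 1 for c in range(8)] for r in range(8)]  # toy 8x8 "image"
augmented = pad_crop_flip(image, pad=2, rng=random.Random(0))
for row in augmented:
    print(row)
```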

T=2, run 1 (with augmentation):
Accuracy of the network on the 10000 test images: 85.190000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 96.875000 %
Accuracy of  bird : 78.947368 %
Accuracy of   cat : 83.720930 %
Accuracy of  deer : 81.632653 %
Accuracy of   dog : 84.375000 %
Accuracy of  frog : 92.452830 %
Accuracy of horse : 75.000000 %
Accuracy of  ship : 87.500000 %
Accuracy of truck : 93.103448 %
T=2, run 2 (with augmentation):
Accuracy of the network on the 10000 test images: 84.490000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 73.684211 %
Accuracy of   cat : 76.744186 %
Accuracy of  deer : 85.714286 %
Accuracy of   dog : 78.125000 %
Accuracy of  frog : 81.132075 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 87.500000 %
Accuracy of truck : 89.655172 %
T=10, run 1 (with augmentation):
Accuracy of the network on the 10000 test images: 85.310000 %
Accuracy of plane : 100.000000 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 83.720930 %
Accuracy of  deer : 87.755102 %
Accuracy of   dog : 75.000000 %
Accuracy of  frog : 92.452830 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 93.750000 %
Accuracy of truck : 89.655172 %
T=10, run 2 (with augmentation):
Accuracy of the network on the 10000 test images: 85.370000 %
Accuracy of plane : 95.454545 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 76.315789 %
Accuracy of   cat : 74.418605 %
Accuracy of  deer : 85.714286 %
Accuracy of   dog : 78.125000 %
Accuracy of  frog : 88.679245 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 87.500000 %
Accuracy of truck : 89.655172 %
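
To make the comparison concrete, the overall test accuracies reported in this post can be summarised per configuration. There are only two to four runs per setting, so the standard deviations are rough. This is a sanity check on the numbers above, not a significance test.

```python
from statistics import mean, stdev

# Overall test accuracies (%) reported above.
results = {
    "baseline (no distillation)": [84.35, 83.87, 84.76, 84.24],
    "KD, no augmentation, T=2": [79.11, 76.72],
    "KD, no augmentation, T=10": [78.60, 76.55],
    "KD, augmentation, T=2": [85.19, 84.49],
    "KD, augmentation, T=10": [85.31, 85.37],
}

base = mean(results["baseline (no distillation)"])
for name, accs in results.items():
    spread = f" sd {stdev(accs):.2f}" if len(accs) > 1 else ""
    print(f"{name:28s} mean {mean(accs):6.2f}{spread}  vs baseline {mean(accs) - base:+.2f}")
```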

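Test accuracy varies somewhat from run to run, but it still improves over the original training method. The four distilled runs with augmentation reach 84.49% to 85.37% (about 85.1% on average). The baseline reaches 83.87% to 84.76% (about 84.3% on average). The gain is modest, but it gives a rough indication that the method works. Of course, many variants of this training scheme now exist, such as born-again networks.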

Finally, the code and the trained network weights are on GitHub: https://github.com/PolarisShi/distillation