
Computer Engineering ›› 2023, Vol. 49 ›› Issue (5): 97-104. doi: 10.19678/j.issn.1000-3428.0064139

• Artificial Intelligence and Pattern Recognition •

  • About the authors: MA Jiaxiang (b. 1996), male, master's student; his main research interest is network model acceleration. SONG Xiaoning (corresponding author), professor, Ph.D., doctoral supervisor.
  • Funding: National Natural Science Foundation of China (61876072); National Social Science Foundation of China (21&ZD166); Natural Science Foundation of Jiangsu Province (BK20221535).

Soft Pruning Algorithm Based on Lottery Ticket Hypothesis

MA Jiaxiang, SONG Xiaoning   

  1. School of Artificial Intelligence and Computer Science, Jiangnan University, Wuxi 214122, Jiangsu, China
  • Received: 2022-03-09 Revised: 2022-05-01 Published: 2022-05-25


Abstract: The increasing number of neural network layers causes network complexity to grow exponentially, limiting the scenarios in which such networks can be deployed. To solve this problem, this study proposes a soft pruning algorithm based on the lottery ticket hypothesis to accelerate networks. Wrongly pruned parameters are recovered through compensation, in which the pruned network from the previous stage serves as the teacher for knowledge distillation, and a sparsity constraint is added to the distillation loss function to preserve sparsity. On this basis, the pruned network obtained at the current stage is fused with the student network obtained through knowledge distillation. During network fusion, the similarity between the pruned network and the student network is computed, and a dedicated fusion formula is designed to highlight similar network parameters and suppress dissimilar ones, so that the network continues to perform well as the pruning rate increases. Experimental results for the VGG16, ResNet-18, and ResNet-56 models on the CIFAR-10/100 datasets indicate the following: at a pruning rate of 80%, the classification accuracy of VGG16 on the CIFAR-10 dataset decreases by 0.07 percentage points; at a pruning rate of 60%, the classification accuracy of ResNet-56 on the CIFAR-10 dataset improves by 0.06 percentage points; and at pruning rates of 85%, 95%, and 99%, the accuracy of ResNet-18 on the CIFAR-100 dataset decreases by only 1.03, 1.51, and 2.04 percentage points, respectively. These results show that the proposed algorithm maintains high accuracy while increasing the pruning rate, verifying its effectiveness.

Key words: network acceleration, lottery ticket hypothesis, global pruning, sparse distillation, model fusion
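The pipeline summarized in the abstract, namely global magnitude pruning in the lottery ticket style, distillation with a sparsity penalty, and similarity-weighted fusion of the pruned and student networks, can be sketched in NumPy as follows. All function names here are illustrative, and the cosine-similarity fusion rule is an assumption standing in for the paper's specific fusion formula, which is not reproduced on this page.

```python
import numpy as np

def global_magnitude_mask(weights, prune_rate):
    """Lottery-ticket-style global pruning: compute one threshold over all
    layers jointly and mask out the smallest-magnitude fraction of weights."""
    flat = np.concatenate([w.ravel() for w in weights])
    k = int(prune_rate * flat.size)
    threshold = np.partition(np.abs(flat), k)[k] if k > 0 else 0.0
    return [np.abs(w) >= threshold for w in weights]

def sparse_distill_loss(student_logits, teacher_logits, student_weights, lam=1e-4):
    """Distillation loss (MSE to the teacher's logits, as a simple stand-in
    for a KL term) plus an L1 penalty that keeps the student sparse while
    wrongly pruned parameters are being recovered."""
    distill = np.mean((student_logits - teacher_logits) ** 2)
    sparsity = sum(np.abs(w).sum() for w in student_weights)
    return distill + lam * sparsity

def fuse(pruned_w, student_w, eps=1e-8):
    """Similarity-weighted fusion: layers where the pruned network and the
    distilled student agree (high cosine similarity) are averaged with full
    weight, while dissimilar layers fall back toward the pruned network."""
    fused = []
    for p, s in zip(pruned_w, student_w):
        cos = float(np.dot(p.ravel(), s.ravel()) /
                    (np.linalg.norm(p) * np.linalg.norm(s) + eps))
        alpha = 0.5 * (1.0 + cos)  # maps cosine from [-1, 1] into [0, 1]
        fused.append(alpha * 0.5 * (p + s) + (1 - alpha) * p)
    return fused
```

In an iterative schedule, each stage would apply `global_magnitude_mask` at a higher `prune_rate`, distill the masked network against the previous stage's network under `sparse_distill_loss`, and then call `fuse` before the next stage begins.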
