所有文章 > AI驱动 > 深度学习知识蒸馏的研究综述

深度学习知识蒸馏的研究综述

简介

常用的模型压缩方法有4种:知识蒸馏(Knowledge Distillation,KD)、 轻量化模型架构、 剪枝(Pruning)、 量化(Quantization)。

知识蒸馏是一种在深度学习中用于模型压缩和知识传递的技术。它通过将大型复杂模型(教师模型)的知识转移给小型简单模型(学生模型),从而优化学生模型的性能。

这种方法被广泛应用于降低模型复杂性和减少计算资源需求。知识蒸馏是通过将教师模型的输出概率分布传递给学生模型,并使用软标签作为学生模型的训练目标来实现的。蒸馏可以通过最小化软标签和学生模型输出之间的交叉熵损失来优化。这种方法已经在各种任务和数据集上取得了显著的成功,包括图像分类、自然语言处理和语音识别。

图1为知识蒸馏的发展历程和各个时期较为代表性的工作。

图1知识蒸馏发展历程

知识蒸馏面临的挑战主要可以分为:模型问题包括教师模型和学生模型不匹配、深层模型和浅层模型之间的容量差距等;成本问题包括模型训练过程对超参数较为敏感以及对损失函数依赖较大等;可解释性不足则是指知识蒸馏的原理解释主要是基于黑盒模型,难以解释模型的决策过程和内部机制。

该综述主要贡献包括:

1)对知识的分类进行了细化,增加了中间层知识、参数知识、图表示知识,完整地涵盖了目前知识的全部形式;

2)以表格的方式对不同方法之间的优缺点、适用场景等进行详细的总结对比,便于掌握;

3)遵循了主题式分类原则,不仅分析了单篇文献,还分析相关领域中知识蒸馏的重要研究,并对知识蒸馏的学习目的、原理和解释、发展趋势等方面做了较为全面的阐释。

图2是知识蒸馏模型的整体结构,由一个多层的教师模型和学生模型组成,教师模型主要负责向学生模型传递知识。

图2 知识蒸馏教师学生模型结构流程图

此外,本文对知识蒸馏相关研究进行了总结,主要从知识传递形式、学习的方式、学习的目的、交叉领域、主要应用等方面对其进行分类,其分类框架如图3所示,具体内容将在后续的文章中展开。

图3 知识蒸馏整体分类框架

知识传递形式

根据知识在教师-学生模型之间传递的形式可以将知识蒸馏方法归类为标签知识、中间层知识、参数知识、结构化知识和图表示知识。

不同知识传递形式下的蒸馏方法的形式化表示及其相关解释整理为表1所示:

表1 不同知识传递形式下的蒸馏方法形式化表示对比表

相关的优缺点和实验对比,见表2~表3所示:

表2 不同知识形式的代表性蒸馏方法在CIFAR100数据集上实验结果

表3 不同“知识”表达形式的优缺点

标签知识

标签知识是指通过训练好的模型对数据集进行预测得到的标签信息,通常也被称为“暗知识”。标签知识方法简单通用,易于实现,适合分类、识别、分割等几乎所有任务。但是,标签知识也存在知识单一、依赖于损失函数的设计且对参数敏感等缺点。此外,标签知识中也包含了很多不确定信息,这些信息通常反映了样本间的相似度或干扰性、样本预测的难度。因此,标签知识通常提供的信息十分有限且有相对的不确定性,但它仍然是基础蒸馏方法研究的重点和热点之一,因为其与传统的伪标签学习或者自训练方法有着密切的联系,这实际上为半监督学习开辟了新的道路。标签知识是各种任务中知识蒸馏的基础之一,适用于安全隐私要求相对不高的场景。

中间层知识

中间层知识是指教师模型中间层的特征作为学生模型的目标,相比标签知识蒸馏更加丰富,大大提高了传输知识的表征能力和信息量,有效提升了蒸馏训练效果。中间层知识所表达的是深度神经网络的中间层部件所提取出的高维特征,具有更具表征能力的特征知识。中间层知识可以提高传输知识的表征能力和信息量,有效提升蒸馏训练效果。但是不同架构的教师学生模型的中间层知识表征空间通常难以直接匹配基于中间层知识的蒸馏方法在实践中通常需要考虑教师和学生模型的网络结构,可以将其分为同构蒸馏和异构蒸馏两种情况,如图4所示,同构知识蒸馏(a)中教师和学习模型具有相同的架构,层与层,块与块之间对应,可直接蒸馏;异构知识蒸馏(b)中教师模型和学生模型各个层或块不能完全对应需要通过桥接模块来实现蒸馏。

图4 同构-异构蒸馏知识迁移结构图

参数知识

参数知识是指直接利用教师模型的部分训练好的参数或网络模块参与蒸馏训练,它通常无法作为一个独立的方法,而是与其他蒸馏方法结合使用。目前存在两种形式的参数知识蒸馏方法:教师平均法作为一种稳定训练过程,可以通过对教师模型的多次训练得到多个教师模型,然后将这些教师模型的参数进行平均得到一个更加稳定的教师模型;模块注入法则是将教师模型的某些模块直接注入到学生模型中,以提高学生模型的性能。

结构化知识

结构化知识的传递可以通过两种方式实现:一是直接将教师模型的结构信息复制到学生模型中,二是通过一些规则或算法将教师模型的结构信息转化为学生模型的结构信息。结构化知识的传递可以提高学生模型的泛化能力和可解释性,但也存在一些挑战,如教师模型和学生模型的结构不匹配、结构信息的复杂性等。结构化知识在深度学习中的应用非常广泛,可以用于图像分类、目标检测、自然语言处理等领域。例如,在图像分类任务中,教师模型可以学习到不同类别之间的关系,将这些关系传递给学生模型可以帮助学生模型更好地理解不同类别之间的区别和联系。

传统的知识蒸馏(a)主要是在特征上直接蒸馏;结构化知识蒸馏(b)在特征之上构建特征之间的结构关系(如距离和角度)),两者的对比如图6所示:

图5 传统知识特征与结构化知识特征对比

图表示知识

图表示知识是指将特征向量映射至图结构来表示其中的关系,以满足非结构化数据表示的学习需求。图表示知识的传递可以通过两种方式实现:一是直接将教师模型中的图表示知识复制到学生模型中,二是通过一些规则或算法将教师模型中的特征向量转化为学生模型中的图表示知识。图表示知识的传递可以提高学生模型的泛化能力和可解释性,但也存在一些挑战,如图表示知识的复杂性、图结构的匹配问题等。

使用图表示知识的蒸馏方法主要集中于两类场景:一是从经典深度神经网络中提取特征的图结构化关系表示知识,二是图神经网络(Graph Neural Networks,GNN)上的知识蒸馏。

图6展示了图表示知识与图表示知识蒸馏示意图,其中,图表示知识(a)通常构建成节点和边的连接形式,而图表示知识蒸馏(b)需要建立在边表示的节点关系或局部图结构上。

图6 图表示知识与图表示知识蒸馏示意图

学习方式

类似于人类教师和学生间的学习模式,神经网络的知识蒸馏学习方式也有着多种模式。如离线蒸馏、在线蒸馏、自蒸馏、无数据蒸馏、多模型蒸馏和特权蒸馏。图7为知识蒸馏的三种基本学习方式分类结构示意图(T 为教师模型,S 为学习模型,下同)。

图7 学习方式分类结构示意图

不同蒸馏方法的优缺点如表4所示:

表4不同蒸馏方法的优缺点比较

离线蒸馏

离线蒸馏是指教师模型和学生模型分别独立训练,学生模型只使用教师模型的输出作为标签进行训练。离线蒸馏的优点是灵活可控、易于操作、成本较低,但缺点是无法满足多任务、多领域任务。离线蒸馏主要适用于单任务学习,安全隐私要求相对不高,教师模型可访问的场景。

在线蒸馏

在线蒸馏是指教师模型和学生模型同时参与训练和参数更新。在线蒸馏的优点是能够满足多任务、多领域任务,能够实时调整教师模型的知识提炼过程,但缺点是计算量大、时间成本高。在线蒸馏主要适用于多任务学习、安全隐私要求较高、教师模型无法访问的场景。在线蒸馏学习模式有互学习、共享学习和协同学习。

互学习。互学习的特点是将两个或多个学生模型一起训练并将他们的输出知识作为互相之间的学习目标。互学习的方法包括两个学生模型之间互相学习、多个学生模型互学习等,它们在不同的场景中都有着广泛的应用。互学习的优势在于模型之间可以相互促进实现互补。

共享学习。共享学习在多个训练模型中需要通过构建教师模型来收集和汇总知识,并将知识反馈给各个模型,以达到知识共享的目的。与互学习不同,共享学习的模型之间没有直接的相互作用,而是通过教师模型来进行知识的传递和共享。共享学习的方法包括分层共享、分支共享等。

协同学习。同学习类似于互学习,主要是在任务上训练多个独立的分支后实现知识集成与迁移并实现学生的同时更新。与互学习不同的是,协同学习的模型之间没有直接的相互作用,而是通过任务的分支来进行知识的传递和共享。协同学习的方法包括分支协同、任务协同等。

自蒸馏

自蒸馏学习是指学生模型不依赖于外在模型而是利用自身信息进行蒸馏学习。自蒸馏的优点是不需要预先训练大型教师模型,能够在没有教师模型指导的条件下达到学生模型性能的自我提升,但缺点是需要较长的训练时间和更多的计算资源。自蒸馏主要适用于单任务学习、教师模型无法访问的场景。

无数据蒸馏

无数据蒸馏是指在没有训练数据的情况下,通过对教师模型的分析和理解,直接将其知识传递给学生模型的一种蒸馏方法,也叫零样本蒸馏。这种方法可以在不需要额外标注数据的情况下,提高模型的泛化能力和鲁棒性。无数据蒸馏的优点在于不需要额外的标注数据,可以节省时间和成本。但是需要注意的是,无数据蒸馏的效果可能会受到已有模型的质量和输出的影响。

图无数据蒸馏需要通过噪声合成等效样本同时将知识传递给学生模型,传统知识蒸馏模型和无数据知识蒸馏的结构对比如图8:

图8 传统知识蒸馏模型和无数据知识蒸馏的结构对比

多模型蒸馏

多模型蒸馏是指在蒸馏过程中有多个模型参与,各自集成其他模型输出的知识后进行学习。这种方法可以提高模型的鲁棒性和泛化能力,同时也可以减少过拟合的风险。值得注意的是,多模型蒸馏需要更多的计算资源和时间,因此需要在实际应用中进行权衡。可分为多教师模型和集成学习的多模型蒸馏方式。

多教师蒸馏。多教师蒸馏的研究重点在于设计合适的知识组合策略用于指导学生,学习多个教师的优点而摒弃不足。多教师蒸馏对于多任务、多模态学习等有很重要的指导意义,可以解决传统端到端训练方式面临的许多困难。

集成学习。集成学习类似于多教师蒸馏,关键在于多个模型的知识集成策略的设计,使其达到优势互补的效果。不同的是,集成学习没有严格意义上的教师模型参与,所有学生模型都同时学习和更新参数。并且,它通常采用多个完全同构的模型,因此对中间层特征的利用度很高。

特权蒸馏

特权蒸馏主要用于一些隐私保护的场景,教师模型可以利用特权信息,而学生模型可以间接地通过蒸馏学习获得这些信息,从而提升学生的学习效果,降低训练难度。特权蒸馏的知识传递形式主要是以软标签信息为主,学习形式没有严格约束。特权蒸馏的结构是特殊的,特权数据只能教师模型访问,学生模型无法直接访问,学生模型需要通过教师模型来学习,如图9所示。特权蒸馏方法的实现需要考虑如何保护特权信息的安全性,同时也需要考虑如何提高知识的传递效率和学生模型的泛化能力。

图9 特权蒸馏结构

学习目的

模型压缩

模型压缩是知识蒸馏提出的最初目的,它可以通过减少模型的参数数量、计算复杂度等方式来提高模型的效率和泛化能力。常见的模型压缩方法包括剪枝、量化、低秩分解、高效结构设计以及知识蒸馏等,图10展示三种主要模型压缩方法的原理示意图。这些方法可以单独使用,也可以结合使用,以达到更好的压缩效果。模型压缩在实际应用中具有广泛的应用前景,可以帮助深度学习模型在移动设备、嵌入式设备等资源受限的环境下实现高效的计算和预测。

图10 三种主要模型压缩方法的原理示意图(箭头左边为原始模型,右侧为压缩模型)

跨模态/跨领域

跨模态数据的存在形式称为模态,它可以是不同领域的数据,如视觉、文本、语音等。跨模态学习可以建立不同数据之间的关系,从而使得学习效果得到改进。同时,跨领域数据也是一种常见的数据形式,它可以是不同领域的数据,如医疗、金融、交通等。跨领域学习可以将不同领域的知识进行迁移,从而提高模型的泛化能力和效率。跨领域学习可以在不同领域之间共享知识,从而提高模型的性能和应用效果。跨模态/跨领域学习在深度学习中具有广泛的应用前景,可以帮助深度学习模型更好地理解和学习任务,从而提高模型的性能和应用效果。

跨模态 / 跨领域知识蒸馏模型结构如图11所示:

图11 跨领域和跨模态模型结构对比

隐私保护

传统的深度学习模型很容易受到隐私攻击,例如攻击者可以从模型参数或目标模型中恢复个体的敏感信息。因此,出于隐私或机密性的考虑,大多数数据集都是私有的,不会公开共享。特别是在处理生物特征数据、患者的医疗数据等方面,而且企业通常也不希望自己的私有数据被潜在竞争对手访问。因此,模型获取用于模型训练优质数据,并不现实。知识蒸馏可以通过教师学生结构的知识蒸馏来隔离的数据集的访问,让教师模型学习隐私数据,并将知识传递给外界的模型。例如,Gao等人提出的知识转移结合了隐私保护策略,这个过程中教师模型访问私有的敏感数据并将学习到的知识传递给学生,而学生模型不能公开获取数据但是可以利用教师模型的知识来训练一个可以公开发布的模型,以防止敏感的训练数据直接暴露给应用。因此,知识蒸馏是一种有效的隐私保护方法,可以帮助深度学习模型在保护隐私的同时实现高效的计算和预测。

持续学习

持续学习是指一个学习系统能够不断地从新样本中学习新的知识,并且保存大部分已经学习到的知识,其学习过程也十分类似于人类自身的学习模式。但是持续学习需要面对一个非常重要的挑战是灾难性遗忘,即需要平衡新知识与旧知识之间的关系。知识蒸馏能够将已学习的知识传递给学习模型实现“知识迁移”,从而在持续学习中起到重要的作用。因此,知识蒸馏是一种有效的持续学习方法,可以帮助深度学习模型在不断学习新知识的同时保留旧知识,从而提高模型的泛化能力和效率。

交叉领域

生成对抗网络

生成对抗网络(GAN)是一种深度学习模型,它由两个神经网络组成:生成器和判别器。生成器的目标是生成与真实数据相似的假数据,而判别器的目标是区分真实数据和假数据。通过不断地训练,生成器可以逐渐生成更加逼真的假数据,而判别器也可以逐渐提高对真假数据的判别能力。生成对抗网络在图像生成、图像修复、图像转换等方面具有广泛的应用,是深度学习领域的一个重要研究方向。知识蒸馏结合GANs压缩还存在着不易训练、不可解释等方面的挑战。

图12展示了生成对抗网络结合知识蒸馏结构示意图:

图12 生成对抗网络结合知识蒸馏结构示意图(T 为教师模型,S为学生模型,D为生成器,G 为判别器

强化学习

强化学习(LR)又称为增强学习,它通过智能体与环境的交互来学习最优的行为策略,如图13所示。在强化学习中,智能体通过观察环境的状态,采取相应的行动,并根据环境的反馈获得奖励或惩罚。通过不断地试错和学习,智能体可以逐渐学习到最优的行为策略,从而实现任务的最优化。强化学习在游戏、机器人控制、自然语言处理等领域具有广泛的应用,是深度学习领域的一个重要研究方向。

图13 强化学习原理图(智能体在环境中根据观察的状态作为决策,采取相应的行为并期望获得最大的奖励)

知识蒸馏与深度强化相结合的过程有两种方式,策略蒸馏和双策略蒸馏,深度强化教师模型将经验值存到记忆重播池中,学生模型从策略池中学习教师模型的经验.双策略模型的两个模型从环境中学习经验并互相蒸馏知识。如图14所示:

图14 强化学习中的知识蒸馏示意图

元学习

元学习(Meta Learning)的目标是学习如何学习。元学习的核心思想是通过学习一些基本的学习算法或策略,来快速适应新的任务或环境。元学习可以帮助机器学习模型在少量样本的情况下快速适应新的任务,从而提高模型的泛化能力。近年来,元学习在少样本分类、强化学习等领域得到了广泛的应用和研究。

元学习知识蒸馏结构如图15所示:

图15 元学习知识蒸馏结构图(在教师和学生模型中构建 “元知识”用于辅助学生训练)

知识蒸馏结合的元学习作为小样本环境下提高性能的手段,在知识迁移过程中也会面临着一些挑战,诸如过拟合、结构不匹配、新旧任务不关联等问题。

自动机器学习

自动机器学习(AutoML)是通过自动化特征工程、模型构建和超参数优化等过程,来实现机器学习的自动化。AutoML可以帮助非专业人士快速构建和优化机器学习模型,从而降低了机器学习的门槛。在AutoML中,神经结构搜索(NAS)和超参数优化(HPO)是两个重要的技术方向。NAS通过搜索最优的神经网络结构来提高模型的性能,而HPO则是通过自动化搜索最优的超参数组合来提高模型的性能。,NAS结合知识蒸馏的过程中,还有一些需要解决的挑战的难题,包括结构不匹配、搜索空间复杂、鲁棒性不足等问题。AutoML在图像分类等计算机视觉领域有着广泛的应用。

传统模型学习与自动机器学习对比如图16所示:

图16 传统模型学习与自动机器学习对比图

自监督学习

自监督学习(SSL)是一种预训练微调的方法,它通过构建辅助任务来训练模型,并将得到的预训练模型通过微调的方式应用于下游任务。监督学习和自监督学习蒸馏结构对比如图17所示。自监督学习的核心思想是利用大量的无标签数据来训练模型,从而提高模型的泛化能力。自监督学习可以帮助机器学习模型在少量标签数据的情况下快速适应新的任务,从而降低了数据标注的成本。但是自监督学习的缺点在于学习辅助任务和目标任务时只能使用同构模型或者其中的一部分,这也导致了目前绝大部分自监督学习的方法在预训练和微调时都是使用的相同架构。

图17 监督学习和自监督学习蒸馏结构对比图(传统的监督学习的蒸馏在标签数据集上构建预训练模型(标签任务),而自监督学习蒸馏则是在无标签数据集上训练并‘总结’出知识(辅助任务),用于目标模型的训练。

主要应用

计算机视觉

应用知识蒸馏的视觉研究主要集中在视觉检测和视觉分类上。视觉检测主要有目标检测、人脸识别、行人检测、姿势检测;而视觉分类的研究热点主要是语义分割,如表5所示。另外,视觉中还有视频分类、深度估计和光流/场景流估计等。

表5 计算机视觉主要蒸馏方法应用与对比

注:‘A’表示离线蒸馏,‘B’表示在线蒸馏,‘C’表示自蒸馏,‘D’表示无数据蒸馏,‘E’表示多模型蒸馏,‘F’表示特权蒸馏;‘L’表示标签知识,‘I’表示中间层知识,‘P’表示参数知识,‘S’表示结构知识;‘M’表示模型压缩,‘K’表示跨模态/领域,‘H’表示隐私保护,‘J’表示持续学习,下同。

自然语言处理

结合知识蒸馏较为广泛的自然语言处理(NLP)任务主要有机器翻译(Neural Machine Translation, NMT),问答系统(Question Answer System, QAS)等领域。表6列举了知识蒸馏结合机器翻译和问答系统的代表性的研究工作。另外,BERT模型近年来被广泛应用于NLP的各个领域,表6中一并列举。

表6 自然语言处理的主要蒸馏方法应用与对比

推荐系统

推荐系统(Recommender Systems, RS)被广泛应用于电商、短视频、音乐等系统中,对各个行业的发展起到了很大的促进作用。推荐系统通过分析用户的行为,从而得出用户的偏好,为用户推荐个性化的服务。因此,推荐系统在相关行业中有很高的商业价值。深度学习应用于推荐系统也面临着模型复杂度和效率的问题。表7中整理了目前关于推荐系统和知识蒸馏工作的相关文献,可供参考。

表7 推荐系统中的主要蒸馏方法应用与对比

文章转自微信公众号@算法进阶