今天小编分享的科学经验:模型知识蒸馏新SOTA!告别传统散度蒸馏,欢迎阅读。
用大模型 " 蒸馏 " 小模型,有新招了!
甚至能在不同类型和架构的 LLMs(大语言模型)上达到新 SOTA。
这就是来自中科大、腾讯优图实验室提出的一种基于 Sinkhorn 距离的知识蒸馏方法,能把大的、复杂的教师模型的知识 " 蒸馏 " 到小的、简单的学生模型中,从而让小模型也能像大模型一样工作。
之所以提出新方法,主要是现有的知识蒸馏(KD)方法都有各自的局限性:
当两个模型的输出差异较大时,它们就不太管用了。
KL 散度:会导致学生模型的输出变得过于平滑,失去了区分性;
RKL 散度:会让学生的输出变得太简单,不能很好地模仿教师模型;
JS 散度:会让学生模型低估稀有事件的概率;
而基于 Sinkhorn 距离的新方法能更准确地衡量和缩小教师模型和学生模型之间的差异,从而提高了学生模型的性能。
此外,研究还提出了一种基于批量的重构方法,从而在高维空间中捕捉跨样本分布的几何复杂性。
最终,通过在两个流行的自然语言处理测试集(GLUE 和 SuperGLUE)上测试,新方法在编码器、编码器 - 解码器以及解码器等不同架构的所有类型 LLMs 上均优于当前的最先进方法。
研究背景
知识蒸馏的提出是为了通过对齐教师模型的软目标(例如输出 logits 和中间层表示)来将教师模型内在固有的知识传递给学生模型。
给定训练集中的一个样本 x_i 及其真实标签 ∈ ℝ,来自教师模型和学生模型的输出 logits ∈ ℝ和 ∈ ℝ可以由以下式子得到:
其中为 softmax 函数, τ 是温度参数 , d 是输出 logits 的维度。基于 logit 的知识蒸馏的目标是 σΤ 最小化测量散度 J(,)以实现知识传递。
研究动机
现有研究已经尝试使用 Kullback-Leibler(KL)散度、反 Kullback-Leibler(RKL)散度和 Jensen-Shannon(JS)散度。
所有这些度量都可以被视为f- 散度度量的变体,而 f- 散度度量在量化缺乏实质性交集的任何两个分布时都存在明显局限性。
此外,每种度量都有其自身的缺陷:
KL 蒸馏会导致模式平均,使学生学习到一个过于平滑的分布,涵盖了教师的整个支撑集;
RKL 会引起模式塌陷,学生仅关注教师分布中高概率的显著区網域,而忽视了其余部分;
JS 蒸馏会产生模式低估,由于惩罚不足,学生会低估稀有事件的概率。
为了解决传统散度度量的问题,研究做出了以下贡献:
提出了一种知识蒸馏方法 SinKD,采用 Sinkhorn 距离作为散度度量。它不仅解决了 KL、RKL 和 JS 散度在极端场景下的局限性,而且避免了计算 Wasserstein 距离的负担。
深入探讨了 Sinkhorn 距离的性质,并将 SinKD 重新 reformulated 为 batch-wise OT,扩展了它在 NLP 任务中的适用性。
通过大量的可比性、有效性和泛化性实验证明了 SinKD 相较于目前最先进的方法的优越性。并为实际应用提供了使用 SinKD 进行蒸馏的实用指导方针。
传统散度度量的缺陷
首先,KL 散度是不对称的,表现为 JKL(,)≠ JKL(,),这一性质违反了距离度量的对称性特性,从而引入了一些不一致性。
其次,由于使用 KL 损失进行优化,学生模型试图对教师模型的多模态分布进行平均化,从而导致对这些模式的拟合不足。这被称为 " 模式平均问题 "(mode-averaging problem)。
因此,学生模型无法捕获数据中的所有关键模式,最终影响模型性能。
第三,KL 散度对应的是一个非平滑函数,这为优化过程带来了挑战。
与 KL 散度一样,具有内在的不对称性,从而导致在捕捉分布差异时出现不一致性。
此外,优化的学生模型倾向于仅关注教师分布中概率较高的事件,这被称为" 模式崩塌问题 "(mode-collapsing)。
如果教师对某个事件赋予零概率,学生模型也被迫做出相同的预测。
其中 m = 1/2(+)受制于非平滑性,JS 损失在优化过程中面临挑战。
另外,由于 JS 损失在低概率区網域的匹配上惩罚不足,学生模型可能会过度低估稀有事件的概率。
对于分布之间重叠较少甚至完全不重叠的情况退化为常数时,还存在梯度消失的风险。
最优传输距离的优势
Wasserstein 距离通过求解两个分布之间的最优传输计划来量化它们的差异。
直观地看,它可以被认为是将一个分布(即学生的 logits 分布)转换为另一个分布(即教师的 logits 分布)所需的最小 " 代价 ",其中 " 代价 " 可以定义为被移动的质量与移动距离的乘积。
与传统的散度度量相比,Wasserstein 距离作为蒸馏的成本函数更为合理,因为它不依赖于对被测量分布的隐式假设。此外,它几乎处处可微,从而便于优化。
另外,现有的散度度量只能独立处理每个样本对,进行逐一 logit 的匹配,对于一批样本,这些方法无法定位来自同一样本的教师和学生的 logits 对,从而无法实现整体距离的最小化。
由于计算 Sinkhorn 距离的过程可以实现来自同一样本的两个输出之间的精确逐元素匹配,研究提出了" 批量化 " 的 SinKD 方法(batchified SinKD)。
通过这种方式,即使通过低维观测,也能够捕捉复杂且隐式分布的几何结构。
方法介绍
这里简要介绍 SinKD 的核心方法,详细推导过程可以参阅原论文。
批量重构的 Sinkhorn 距离
对于本问题,Wasserstein 距离的定义如下:
其中,
Wasserstein 距离本身在解析计算上存在困难,其计算成本对于蒸馏大型语言模型来说高得难以承受。
在这种情况下,研究使用Sinkhorn 距离作为一种高效的近似方法。它不仅保留了 Wasserstein 距离的所有优点,同时也大大缓解了其在在线蒸馏中所面临的成本问题。
Sinkhorn 距离的定义如下:
逐样本蒸馏将每个实例独立处理,但忽略了一个批次样本中的整体趋势。
研究摒弃了仅在每对教师 - 学生样本对上工作的逐样本知识蒸馏方法,转而在教师和学生样本组上执行知识蒸馏。
一个包含 b 个样本的批次会整体参与散度度量。通过批量重构,这种方法有效地增加了 " 观测 " 空间的维度,特别是在 d 远小于 b 的情况下表现尤为显著。
对于常规分类任务的蒸馏,研究使用如下 "batchified" 代价函数:
并初始化如下候选传输矩阵:
通过重构和化简,研究可以使用如下迭代式计算最优传输矩阵(具体推导过程参见论文):
由此,可以算出最优传输距离:
SinKD 的变体
拓展到回归任务:对于回归任务,模型不会为每个选项生成概率,而是仅生成一个标量(d=1)。对于一个包含 b 个样本的批次,教师模型和学生模型的输出分别表示为 ∈ ℝ bx1 和 ∈ ℝ bx1。
为了计算教师和学生之间的批量化 Sinkhorn 距离,成本矩阵的元素由 " 批量化 " 回归输出之间的绝对差值确定:
拓展到独热标签微调:SinKD 方法也适用于仅有独热(one-hot)标签且无法获取教师模型 logits 的模型微调。
在这种情况下,可以将单热标签视为 " 假想 " 的单热教师模型的 logits。由于单热 logits 中以零为主,传统的散度度量(例如 KL 散度)在处理这种极端情况下的散度量化时显得无能为力。
实验与分析
(1)数值结果。与基线和 SOTA 方法对比,论文方法在大部分任务上均取得了更好的性能。
(2)消融实验。得出的结论如下:
Sinkhorn 损失在所有损失中对学生模型的收益最大
批量化的 SinKD 优于逐样本的 SinKD
SinKD 超越了基于 f- 散度变体的蒸馏方法
(3)生成式大语言模型实验。SinKD 可以推广到生成式大语言模型,并在基于类 GPT 架构的模型的蒸馏上取得不俗的成绩表现。
但同时研究也观察到,蒸馏效果的影响会随着 PROMPT 模板的变化而改变。
这意味着,同样的任务設定下,更加合理的 PROMPT 设计能够更充分地利用教师模型的固有知识。
(4)可视化结果如下。
为了增强内在评估,研究还进行了以下附加分析:
隐藏状态的表示
注意力机制的模式
层级性能分析
(5)拓展到独热标签微调。与现有的散度度量方法(例如 KL 散度)不同,SinKD 方法还可以扩展用于使用独热标签 ( one-hot label ) 微调语言模型。
(6)拓展到计算机视觉领網域深度网络。SinKD 在所有测试的配置中均稳定地超越了所有基线方法。
总结
研究引入了 SinKD 以解决现有蒸馏方法的局限性。此外,作者们提出了基于批次的重构方法,以捕捉高维空间中样本分布的几何复杂性。最后,研究在各类任务、数据集和模型架构上进一步验证 SinKD 的有效性。
更多细节欢迎查阅原论文。
COLING 2024 会议论文:
https://arxiv.org/abs/2402.17110
IEEE TNNLS 期刊论文:
https://hal.science/hal-04803835
— 完 —
投稿请发邮件到:
标题注明【投稿】,告诉我们:
你是谁,从哪来,投稿内容
附上论文 / 项目主页链接,以及联系方式哦
我们会(尽量)及时回复你
点这里关注我,记得标星哦~
一键三连「分享」、「点赞」和「在看」
科技前沿进展日日相见 ~
>