简单有效!差分Transformer竟能消除注意力噪声
Transformer 的强大实力已经在诸多大型语言模型(LLM)上得到了证明,但该架构远非完美,也有很多研究者致力于改进这一架构,比如 Reformer 和 Infini-Transformer。
今天我们又将介绍另一种新型 Transformer 架构:Differential Transformer(差分 Transformer,简称 Diff Transformer)。该架构来自微软研究院和清华大学,有四位共一作者:Tianzhu Ye、Li Dong、Yuqing Xia、Yutao Sun。
- 开源代码:https://github.com/Jaykef/ai-algorithms/blob/main/DIFF_Transformer.ipynb
- 论文标题:Differential Transformer
- 论文地址:https://arxiv.org/pdf/2410.05258
在 Hacker News 及 Twitter 等社交网络上,该论文都反响热烈,有网友表示差分 Transformer 提出的改进简单又美丽,而带来的提升又非常显著。
甚至已有开发者做出了差分 Transformer 的轻量实现!
那么差分 Transformer 弥补了原生 Transformer 的哪些问题呢?如下图所示,Transformer 往往会过度关注不相关的上下文,该团队将此称为注意力噪声(attention noise)。而差分 Transformer 则能放大对答案范围的注意力并消除噪音,从而增强上下文建模的能力。这就要用到该团队新提出的差分注意力机制(differential attention mechanism)了。
差分注意力机制可以消除注意力噪声,鼓励模型重点关注关键信息。该方法有些类似于电气工程中的降噪耳机和差分放大器。
下面我们就来详细了解一下差分 Transformer 的设计思路。
差分 Transformer
差分 Transformer 是一种用于序列建模的基础模型架构。为了方便说明,他们使用了仅解码器(decoder-only)模型作为示例来描述该架构。
该模型堆叠了 L 个 Diff Transformer 层。给定一个输入序列 x,将输入嵌入打包成 X^0。输入会被进一步上下文化来获得输出 X^L。每一层都由两个模块组成:一个差分注意力模块和之后的前向网络模块。
相比于 Transformer,差分 Transformer 的主要差别在于使用差分注意力替换了传统的 softmax 注意力,同时保持整体宏观布局不变。此外,他们也参考 LLaMA 采用了 pre-RMSNorm 和 SwiGLU 这两项改进措施。
差分注意力
差分注意力机制的作用是将查询、键和值向量映射成输出。这里使用查询和键向量来计算注意力分数,然后计算值向量的加权和。
此处的关键设计是使用一对 softmax 函数来消除注意力分数的噪声。具体来说,给定输入 X,首先将它们投射成查询、键和值 Q_1、Q_2、K_1、K_2、V。然后差分注意力算子 DiffAttn (・) 通过以下方式计算输出:
其中 W^Q、W^K 、W^V 是参数,λ 是可学习的标量。为了同步学习动态,将标量 λ 重新参数化为:
其中 λ_q1、λ_k1、λ_q2、λ_k2 是可学习的向量,λ_init ∈ (0, 1) 是用于初始化 λ 的常数。该团队通过经验发现,设置 λ_init = 0.8 − 0.6 × exp (−0.3・(l − 1)) 在实践中效果很好,其中 l ∈ [1, L] 表示层索引。它在实验中被用作默认策略。
他们也探索了另一种初始化策略:对所有层使用相同的 λ_init(例如 0.8)。如后面消融研究所示,使用不同的初始化策略时,性能相对稳健。
差分注意力利用两个 softmax 注意力函数之间的差来消除注意力噪声。这个想法类似于电气工程中提出的差分放大器,其中两个信号之间的差用作输出,这样就可以消除输入的共模噪声。此外,降噪耳机的设计也基于类似的想法。
- 多头差分注意力机制
该团队也为差分注意力使用了多头机制。令 h 表示注意力头的数量。他们对各个头使用不同的投影矩阵 W^Q_i 、W^K_i 、W^V_i ,i ∈ [1, h]。标量 λ 在同一层内的头之间共享。然后对头输出执行归一化,并投射成最终结果,如下所示:
其中 λ_init 是 (2) 式中的常数标量,W^O 是可学习的投影矩阵,LN (・) 是对每个头使用 RMSNorm,Concat (・) 的作用是沿通道维度将头连接在一起。这里使用一个固定乘数(1 − λ_init)作为 LN (・) 的缩放尺度,以使梯度与 Transformer 对齐。
- 逐头归一化
图 2 使用了 GroupNorm (・) 来强调 LN (・) 独立应用于每个 head。由于差分注意力往往具有更稀疏的模式,因此头之间的统计信息更加多样化。为了改进梯度的统计情况,LN (・) 算子会在连接操作之前对每个头进行归一化。
整体架构
其整体架构会堆叠 L 层,其中每层包含一个多头差分注意力模块和一个前向网络模块。如此,便可将差分 Transformer 层描述为:
其中 LN (・) 是 RMSNorm,SwiGLU (X) = (swish (XW^G) ⊙ XW_1) W_2,且 W^G、W_1、W_2 是可学习的矩阵。
实验
该团队从以下角度评估了差分 Transformer 在 LLM 中的应用,包括对比评估、应用评估和消融研究。这里我们仅关注实验结果,更多实验过程请访问原论文。
语言建模评估
该团队评估了差分 Transformer 的语言建模能力。为此,他们使用 1T token 训练了一个 3B 大小的差分 Transformer 语言模型,并与之前的 Transformer 语言模型做了比较。
结果见表 1,其中报告的是在 LM Eval Harness 基准上的零样本结果。
可以看到,3B 规模下,差分 Transformer 语言模型的表现优于之前的 Transformer 语言模型。此外,实验也表明差分 Transformer 在多种任务上都胜过 Transformer,详见原论文附录。
与 Transformer 的可扩展性比较
该团队也比较了新旧 Transformer 的可扩展性。结果见图 3,其中 a 比较了模型规模方面的可扩展性,而 b 则是训练 token 数量方面的可扩展性。
可以看到,在这两个方面,差分 Transformer 的可扩展性均优于常规 Transformer:仅需后者 65% 左右的模型大小或训练 token 数量就能达到相媲美的性能。
长上下文评估
当 3B 模型上下文长度增长至 64K,模型的表现又如何呢?又使用另外 1.5B token 训练了 3B 版本的检查点模型之后,该团队发现随着上下文长度的增加,累积平均负对数似然(NLL)持续下降。差分 Transformer 得到的 NLL 值低于常规 Transformer。见图 4,这样的结果表明,差分 Transformer 可以有效地利用不断增加的上下文。
关键信息检索
为了检验差分 Transformer 检索关键信息的能力,该团队执行了 Needle-In-A-Haystack(草堆找针)测试。
表 2 给出了 4K 上下文长度的情况,其中 N 是针的数量,R 是查询引用的数量。可以看到,差分 Transformer 的多针检索准确度高于常规 Transformer,尤其是当针数量较多时,差分 Transformer 的优势会更加明显。
那么当上下文长度提升至 64K 时,又会如何呢?结果见图 5,这里使用的上下文长度在 8K 到 64K 之间,使用了 N = 8 和 R = 1 的设置。
可以看到,在不同的上下文长度下,差分 Transformer 能够保持相对稳定的性能。而当上下文长度越来越大时,常规 Transformer 的性能会逐渐下降。
另外,表 3 展示了分配给关键信息检索任务的答案范围和噪声上下文的注意力分数。该分数可代表模型保留有用信息、抵抗注意力噪声的能力。
可以看到,相比于常规 Transformer,差分 Transformer 能为答案范围分配更高的注意力分数,同时为注意力噪声分配更低的注意力分数。
上下文学习能力评估
该团队从两个角度评估模型的上下文学习能力,包括多样本分类和上下文学习的稳健性。
图 6 展示了新旧 Transformer 模型的多样本分类结果。结果表明,在不同的数据集和不同的演示样本数量上,差分 Transformer 均稳定地优于 Transformer。此外,差分 Transformer 的平均准确度优势也很明显,从 5.2% 到 21.6% 不等。
图 7 则展示了两种模型的上下文学习稳健性结果。该分析基于 TREC 数据集,并且采用了两种提示词格式:示例随机排列(图 7a)和按类别交替排列(图 7b)。
在这两种设置下,差分 Transformer 的性能方差要小得多。结果表明,新方法在上下文学习任务中更为稳健。相比之下,Transformer 容易受到顺序排列的影响,导致最佳结果与最差结果之间差距巨大。
上下文幻觉评估
该团队基于文本摘要和问答任务评估了模型的上下文幻觉现象。结果见表 4。
可以看到,相比于常规 Transformer,差分 Transformer 在摘要和问答任务上的上下文幻觉更低。该团队表示,原因可能是差分 Transformer 能更好地关注任务所需的基本信息,而不是无关上下文。
激活异常值分析
在 LLM 中,一部分激活值明显大于大多数激活值的现象被称为激活异常值(activation outliers)。异常值导致训练和推理过程中模型量化困难。实验表明差分 Transformer 可以降低激活异常值的幅度,从而可能实现更低的量化位宽。
表 5 展示了两个训练得到 Transformer 和差分 Transformer 模型的激活值统计情况。这里分析了两种类型的激活,包括注意力 logit(即 pre-softmax 激活)和隐藏状态(即层输出)。可以看到,尽管中位数相似,但与 Transformer 相比,差分 Transformer 的较大激活值要低得多。这表明新方法产生的激活异常值较少。
图 8 则展示了将注意力 logit 量化到更低位的情况。这里使用的方案是:使用 absmax 量化的动态后训练量化。其中,16 位配置表示未经量化的原始结果。模型逐步量化为 8 位、6 位和 4 位。这里报告的是在 HellaSwag 上的零样本准确度,但该团队也指出在其它数据集上也有类似表现。
从图中可知,即使降低位宽,差分 Transformer 也能保持较高性能。相较之下,常规 Transformer 的准确度在 6 位和 4 位量化时会显著下降。这一结果表明,差分 Transformer 本身就能缓解注意力分数中的激活异常值问题,从而可为低位 FlashAttention 的实现提供新机会。
最后,该团队也进行了消融实验,证明了各个新设计的有效性。
文章转自微信公众号@算法进阶