AIxiv专栏是机器之心发布学术和技术内容的专栏。几年来,机器之心AIxiv专栏已收到2000余篇报道,覆盖全球各大高校和企业的顶级实验室,有效促进了学术交流和传播。如果您有优秀的作品想要分享,请随时投稿或联系我们进行举报。投稿邮箱:liyazhou@jiqizhixin.com; zhaoyunfeng@jiqizhixin.com
论文共同第一作者张金涛和黄浩峰分别来自清华大学计算机系和交叉信息研究院。论文通讯作者陈建飞副教授及其他共同作者均来自清华大学计算机系。
在大型模型中,线性层的低比特量化已经逐渐实现。然而,对于注意力模块,目前几乎每个模型仍然使用高精度(例如 FP16 或 FP32)注意力操作来进行训练和推理。而且,随着大型模型需要处理的序列长度不断增加,Attention(注意力操作)的时间开销逐渐成为主要开销。
此前,清华大学陈剑飞团队提出了一种8位即插即用的Attention(SageAttention),将Attention中的QK^T量化为INT8,保持PV为FP16精度,并使用FP16精度的矩阵乘法累加器。同时提出Smooth K技术保持量化注意力的准确性,实现比FlashAttention2 2倍的加速,在各种大型模型上保持端到端的准确性。
目前,SageAttention已被业界和社区广泛应用于各种开源和商业模型中,如CogvideoX、Mochi、Flux、Llama3、Qwen等。
近日,陈剑飞团队进一步提出了4-Bit即插即用Attention(SageAttention2),相比FlashAttention2和xformers分别实现了3倍和4.5倍的即插即用推理加速,并在视频、图像、文本方面实现了效果- 即使生成大型模型,也能保持端到端的精度性能。
即插即用示例
SageAttention2实现了高效的Attention算子,可以实现即插即用的推理加速。输入任意Q、K、V矩阵,SageAttention2都可以快速返回Attention Output(O)。
具体来说,SageAttention2使用起来非常方便。克隆存储库(git clone)并执行 python setup.py install 后,只需一行代码即可获得 Attention 输出。您可以使用此接口轻松替换任何模型中的Attention函数:
效果方面,以开源视频生成模型CogvideoX-1.5-5B为例,使用SageAttention2可以端到端加速1.8倍,生成的视频无损。
更重要的是,SageAttention2提供了比SageAttention更广泛的硬件支持。除了 RTX 4090 上 FlashAttention 的 3 倍加速之外,L20、L40 和 L40S 上的 2 倍加速,以及 A100、A800 和 A6000 上的 1.45-1.6 倍加速(基于 SageAttention)。
接下来,研究团队将从前言、挑战、方法和实验结果四个方面介绍SageAttention2(整体流程图如下)。
前言
随着大型模型需要处理的序列长度越来越长,Attention的速度优化变得越来越重要。下图展示了标准Transformer模型中每个操作随着序列长度变化的时间比例:
为了方便在attention操作中引用矩阵,我们先回顾一下attention的计算公式:
尽管SageAttention建议将Q和K量化为INT8,但将P和V保持在FP16精度并使用FP16矩阵乘法累加器来加速Attention。但这样做的缺点是:1)INT8矩阵乘法只能达到INT4矩阵乘法速度的一半,2)使用FP16乘法累加器对FP16矩阵乘法的加速仅对RTX4090和RTX3090显卡有效。
为了克服上述缺点,SageAttention2提出将Q和K量化为INT4,将P和V量化为FP8来加速Attention。然而,这样做的挑战是巨大的。
4-Bit 注意力量化有什么问题?
研究团队发现,直接量化 INT4 的注意力操作中的 Q 和 K 会导致几乎所有模型和任务的结果极差。例如,在CogVideoX Vincent视频模型中,会得到完全模糊的视频; Llama2-7B 在四项选择题上取得了 25% 的准确率。
经过仔细分析,研究团队发现导致量化注意力不准确的主要原因有两个:
(1) INT4的数值范围相对于INT8来说非常小,导致当Q、K矩阵中出现一些异常值时,其量化误差变得非常明显。大多数模型都会在 Q、K 维度异常值中显示较大的通道。这大大降低了QK^⊤矩阵乘法的精度。
(2)研究团队发现,在Nvidia显卡上,FP8矩阵乘法指令(mma.f32.f8.f8.f32)的乘法累加器并不是官方公布的FP32精度,而只有FP22精度,结果在PV矩阵中,乘法中会出现较大的累积误差。
技术方案
为了解决上述两个挑战,研究团队提出了相应的解决方案。
(1) 在保留SageAttention中K的平滑的同时,提出对Q进行平滑:Q –mean(Q)。其中mean (Q) 是沿通道维度的平均向量。完成平滑操作后,在Attention计算过程中,需要将mean(Q)和K^T的向量和矩阵相乘的结果补偿到S中。
这与直接量化Q、K到INT4相比,使得精度发生了质的变化。下表显示了该方法和直接量化Q、K到INT4在Cogvideo和Llama3.1上的端到端性能。
平滑矩阵Q前后的数据分布可视化结果如下。可以发现平滑后的Q更好地利用了INT4数据范围:
(2) 对 Q 和 K 进行逐线程量化。对于矩阵 Q 和 K,SageAttention2 采用根据 mma 指令进行矩阵内存排列的要求,并根据 GPU 线程对 Q 和 K 中的 Token 进行分组,使得量化粒度为 16比 SageAttention 中的 per-block 精细 1 倍,这极大地提高了 4Bit QK^⊤ 乘法的精度,而不会引入任何额外的开销。
具体来说,在SageAttention中,Q的每个块将被分为c_w个段,这些段将由GPU流处理器(SM)中的c_w个GPU扭曲进行处理。每个包含 32 个线程的 warp 然后使用 NVIDIA 的 mma.m16n8k64 PTX 指令来执行 QK^⊤ 操作。根据该指令的布局要求,研究团队发现一个warp内的Q[8×(n%8)]可以共享一个量化缩放参数,而K[8×(n%8)]和K[8× (n%8+1)]还可以共享量化缩放参数,其中n是令牌索引。
这种量化方法更加细粒度,并且不会增加额外的开销。这是因为它根据MMA指令的布局将不同的GPU线程分配给不同的量化Token组,并且每个线程只对应一个量化缩放参数进行反量化。与 Per-token 量化不同,每个线程对应多个量化缩放参数。
如下表所示,可以发现per-thread量化的精度远高于SageAttention中使用的per-block量化,并且精度与per-token量化几乎没有差别。
(3)对于FP8的PV矩阵乘法,FP32寄存器用于累加每个FlashAttention块粒度的PV的FP22乘法结果。该方法可以有效防止FP22乘法累加器沿序列长度累加过多的误差,将FP22累加器引起的误差控制在FlashAttention块的粒度内,提高FP8的PV乘法的精度。
(4)对于P和V,研究团队比较了多种定量数据类型,发现使用E4M3数据格式的FP8的精度最为准确,基本接近FP16的精度。因此,P和V被量化为E4M3。
下图展示了SageAttention2的算法流程:
SageAttention2 实现了两种类型的内核。区别在于 Q 和 K 是用 INT4 还是 INT8 量化:
此外,SageAttention2还提出了一种可选的矩阵V平滑技术,可以进一步提高PV矩阵乘法的精度。具体来说,当某些模型中的V矩阵存在通道维度的偏移时,可以通过用V减去其通道维度的均值(V)来去除偏移,然后进行正常的量化Attention操作。只需要在最终的 Attention Output 中添加平均值(V)即可保持计算的准确性。
这种方法提高准确率的原因如下图所示。在FP22的表示范围内,数值越大,相对于FP32的误差越大。 P的范围在0到1之间。当V矩阵的列有较大的数值偏移时,PV FP22累加器的精度会变差。通过平滑V消除偏移后,可以强化PV矩阵。乘法的准确性。
实验效果
SageAttention实现了底层GPU CUDA Kernel,在算子速度和每个模型的端到端精度方面都有非常好的表现。
具体来说,算子速度比FlashAttention2和xformers大约快3倍和4.5倍:
算子的精度也比 Q 和 K 的 SmoothQuant 和 Hadamard 变换更准确:
在真实场景下各个模型的端到端准确率表现中,各个模型在视频、图像、文本生成等大型模型上都保持了端到端的准确率表现:
下图是混元视频中的可视化示例:
下图是Cogvideo中的可视化示例:
下表展示了SageAttention2在各种语言、视频和图像生成模型中的端到端准确率表现:
在端到端速度性能方面,两种SageAttention2 Kernel实现都可以有效加速长序列模型。例如CogVideoX1.5-5B可以端到端加速1.8倍,其他型号也可以加速1.6到1.8倍。 。