本文中有较多Latex数学公式,博客上有一些数学公式格式渲染不正确,可以查看flash_attention简要笔记

优化效果

原来,attention部分的计算量和中间激活占用显存的复杂度都是$O(N^2)$

计算量部分原来QK矩阵乘和attn_score@V矩阵乘的计算量,复杂度都是$O(N^2)$;中间激活因为中间有一个attn_score,所以复杂度也是$O(N^2)$

现在,attention部分的中间激活占用显存的复杂度变为$O(N)$,计算量的复杂度没有变但是通过减少访存加快了计算速度,而且fa与原attention完全等价

具体过程

flash-attention还是基于kernel融合的思想,将QK矩阵乘法、mask、softmax、dropout合并成一个kernel,这样不仅减少了中间变量对显存的占用,而且也减少了计算过程中的访存

一些符号表示:

  • $S_{ij}=Q_i \times K_j^T$,Q分块和K分块的乘积,形状为$[B_r, B_c]$

  • $\widetilde{m}{ij}=rowmax(S{ij})$:对分块$S_{ij}$而言,得到其每行的最大值,形状为$[B_r, 1]$

  • $\widetilde{P}{ij}=e^{S{ij}-\widetilde{m}{ij}}=e^{S{ij}-rowmax(S_{ij})}$:每个分块$S_{ij}$减去其局部rowmax $\widetilde{m}_{ij}$,形状为$[B_r, B_c]$

  • $\widetilde{l}{ij}=rowsum(\widetilde{P}{ij})=rowsum(e^{S_{ij}-rowmax(S_{ij})})$:对$\widetilde{P}_{ij}$而言,按行求和,形状为$[B_r, 1]$

  • $m^{new}i=max(\widetilde{m}{i0}, \widetilde{m}{i1}, … , \widetilde{m}{ij})=rowmax(concat(S_{i0}, S_{i1}, … , S_{ij}))$:即$contcat(S_{i0}, S_{i1}, … , S_{ij})$这j+1个分块的每行的最大值,形状为$[Br, 1]$

  • $m_i$:$m_i^{new}$位于SRAM上,将$m_i^{new}$写回到HBM就是$m_i$,初始化$m=-\infty$

  • $l^{new}i=e^{m_i-m_i^{new}}l_i + e^{\widetilde{m}{ij}-m_i^{new}} \widetilde{l}{ij}=rowsum[e^{S{00}-max(\widetilde{m}{00},…,\widetilde{m}{0j})}] + … + rowsum[e^{S_{0j}-max(\widetilde{m}{00},…,\widetilde{m}{0j})}]$:

  • $l_i$:$l_i^{new}$位于SRAM上,将$l_i^{new}$写回到HBM就是$l_i$,初始化$l=0$

如果不使用flash-attention,具体过程为:

  1. $S = Q K ^T $
  2. $P = softmax(S+mask)$
  3. $O = P V$

如果使用flash-attention,前向过程为:

image-20240916230932000

大致过程为:

FA
  1. 首先对QKV进行分块,K、V分块方法相同(V的分块图中没画出来),首先可以计算$S_{ij}=Q_i\times K_j^T$。因为对QKV进行了分块,所以每次SRAM上能保留$S_{ij}$和$\widetilde{P}_{ij}$(橙黄色表示存储在SRAM上;橙红色表示虽然也存储在SRAM上,但是这些部分每次outer loop会写回到HBM中)
  2. 如果有mask,此时对$S_{ij}$进行mask
  3. 使用一个局部变量$\widetilde{m}{ij}$和一个全局变量$m$(或者说$m^{new}$,$m^{new}$的值在SRAM上,但是每次outer loop会写回到HBM中)来记录分块$S{ij}$局部rowmax和中间遍历过的分块$S_{i:}$的历史rowmax
  4. 然后基于分块$S_{ij}$计算局部的safe softmax的分子部分,即$e^{S_{ij}-rowmax(S_{ij})}$,safe softmax的分子部分累加就是分母部分,这样,就得到了一个针对分块$S_{ij}$的、局部的safe softmax的分母$\widetilde{l}{ij}$,和 一个 遍历过的历史分块$S{i:}$的 safe softmax分子部分的 累加和$l^{new}$(注意断句,写公式有点晦涩难懂,用语言描述又不太好描述),局部的$\widetilde{l}{ij}$就是用来更新全局的$l$(或者说$l^{new}$,$l^{new}$的值在SRAM上,但是每次outer loop会写回到HBM中),对$\widetilde{l}{ij}$举一个例子:
    • 当j=0,i=0时,$l_0^{new}=e^{m_0-m_0^{new}} l_0+e^{\widetilde{m}{00}-m_0^{new}} \widetilde{l}{00}=\widetilde{l}_{00}$
    • 当j=1,i=0时,$l_0^{new} = rowsum(e^{S_{00}-max⁡(\widetilde{m}{00}, \widetilde{m}{01})})+rowsum(e^{S_{01}-max⁡(\widetilde{m}{00}, \widetilde{m}{01})})$
  5. 然后对$\widetilde{P}_{ij}$进行dropout
  6. 然后相当于要进行$O+=\widetilde{P}_{ij} V_i$了,对于算法的第15行,可以使用分配律拆开看,其中有两个操作:
    1. 后半部分:对于当前的$\widetilde{P}{ij} V_i$相乘,$\widetilde{P}{ij}$中减去的是分块$S_{ij}$局部的rowmax,需要调整到 此时已经见过的、所有分块$S_{i:}$的rowmax,就是第15行后半部分中$e^{\widetilde{m}_{ij}-m_i^{new}}$的意思
    2. 前半部分:调整上一次的$O$,先乘旧的$l_i$恢复到safe softmax的分子部分,然后乘以$e^{m_i-m_i^{new}}$更新一下safe softmax分子部分中减去的全局rowmax,最后再除以当前的safe softmax的分母

(反向过程还是看别的博客吧)

简要分析

首先分析一下fa的FLOPs(只分析大块的矩阵乘法,其他小的操作就不计算了):

  • 一开始的$Q_i K^T_j$矩阵相乘,其中$Q_i$的形状为$[B_r, d]$,$K_j^t$的形状为$[d, B_c]$,此时FLOPs=$2d \times B_r \times B_c$
  • 后面计算O的时候有一个$\widetilde{P}{ij} V_i$矩阵相乘,其中$\widetilde{P}{ij}$的形状为$[B_r, B_c]$,$V_i$的形状为$[B_c, d]$,此时FLOPs=$2B_c \times B_r \times d$一共进行了$\frac{N}{B_r} \times \frac{N}{B_c}$次上面的循环,所以FLOPs=$4N^2d$,如果d远小于N,则计算复杂度就变成了$O(N^2)$,计算复杂度相比于standard attention没有变化

然后再分析一下显存占用(显存占用说的是HBM上的显存占用,假设计算精度为$w$ Bytes)

  • HBM上需要维护一个全局的rowmax和expsum,占用显存为$w\times N$
  • 然后还要存储一个最后的输出$O$,占用显存为$wNd$,但是这个部分是必须的
  • 因此,显存占用的复杂度为$O(Nd)$(或者$O(N)$,如果不考虑$O$的话)。standard attention需要保存中间的$S, P$,显存占用复杂度为$O(N^2)$

fa相对于standard attention一个优势,在于减小了计算过程中的访存量,最后来分析一下访存次数:

  • standard attention
    • 从HBM中读取Q,K(形状都是$[N, d]$),访存量=$wNd$,计算$S=QK^T$,然后向HBM中写回S(形状为$[N, N]$),访存量=$wN^2$
    • 从HBM中读取S,访存量=$w N^2$,计算$P=softmax(S)$,向HBM中写回P,访存量=$w N^2$
    • 从HBM中读取P(形状为$[N, N]$)、V(形状为$[N, d]$),访存量=$w N^2 + wNd$,计算$O=PV$,向HBM中写回O(形状为$[N, d]$),访存量=$wNd$
    • 总的访存量=$w(3Nd+4N^2)$,如果d远小于N,则访存量的复杂度变成了$O(N^2)$
  • flash attention(分析时将inner loop作为一个整体进行分析,就像上面示意图画的那样)
    • 从HBM中读取分块$Q_i, i=0, …, T_r -1$,读取分块$K_j$,访存量=$w(Nd+B_c d)$;后面$S_{ij}, \widetilde{P}_{ij}$不需要写回HBM;$m, l$只是一个向量,数据量很少,忽略;再后面读取和写入分块$O_i, i = 0, …,T_r =1$,访存量=$w(2\times Nd)$
    • outer loop共有$\frac{N}{B_c}=T_c$次,总的访存量=$w\times \frac{N}{B_c} \times (Nd + B_cd + 2Nd)=w(Nd+\frac{3N^2d}{B_c})=w(T_c+1)Nd$
    • 比如N=1024,d=64,B=64,standard_attention访存量-flash_attention访存量=$w(3Nd+4N^2-Nd-\frac{3N^2d}{B_c})=w(2Nd+(4-\frac{3d}{B_c})N^2)=w(2Nd+N^2)$,可以看出少了很多访存

实际使用

接口返回值

flash-attention开源代码中,针对不同qkv、是否是varlen、是否需要kv_cache等不同需求封装了不同的接口,这里说一下返回值。这些接口的返回值都相同,除了返回输出的$O$之外,如果设置了return_attn_probs=True,还会返回softmax_lse和S_dmask:

  • softmax_lse(形状$[nheads, seqlen]$):在计算$S=\frac{QK^T}{scale}$之后,会得到形状为$[bs, seqlen, seqlen]$的方阵S,在计算softmax的过程中,需要按行求和,得到一个列向量,然后再取log,写成表达式即为:$softmax_lse=log[\sum_je^{S_{ij}}]$,注意不是$softmax_lse=log[\sum_je^{S_{ij}-rowmax(S_{ij})}]$,参考issue:What’s the exactly formula of softmax_lse? #404
  • S_dmask(形状$[bs, nheads, seqlen, seqlen]$):就是返回$P=softmax(\frac{QK^T}{scale}+mask)$的这个P矩阵

varlen attention

特别的,这里再说一下flash_attn_varlen_func等一些支持varlen的接口,其函数形参中还有cu_seqlens_qcu_seqlens_kmax_seqlen_qmax_seqlen_k等特有的参数。这里介绍一些varlen是什么。

varlen即变长序列,产生的背景是”数据拼接“,即LLM使用的训练数据集中,长度较短的序列占大多数,这些短序列为了能够符合Transformer固定长度的输入,就要进行padding,序列越短,padding越多,而我们不太想要padding,padding只是无奈之举。此时,我们可以使用varlen特性,简单来说就是将多个短序列拼接成一个长序列,但是还是每个短序列自己内部计算注意力,短序列之间是隔离的,这样减少了padding,节省计算量和显存。

这里举个例子(参考),比如一些短序列长度分别是:70,300,180, …,260,120,1200,…等,attention固定输入长度是4096,此时我们将这些短序列拼接起来,使用varlen_attn后,就像右图所示,每个短序列自己内部计算attention,短序列之间不计算attention(否则就像左图这样,白白多了很多浪费的计算)

XTuner

为了实现varlen特性,需要对接口有一些调整。比如不使用varlen的flash_attn接口中,传入的Q、K、V的形状一般为$[bs, seqlen, nheads, head_dim]$(K和V的nheads可以少于Q的nheads,此时就是GQA/MQA)。在使用varlen的flash_attn接口中,主要有两点变化:

  • Q、K、V的形状一般为$[total_seq, nheads, head_dim]$,这里将多个batch拼接起来,拼起来的长度为$total_seq$
  • 多了cu_seqlens_qcu_seqlens_kmax_seqlen_qmax_seqlen_k等特有的参数
    • cu_seqlens_q是对每个短序列的Q的长度的exclusive_scan,作用就是找到原来每个batch的起始点(offset),比如上面的例子,此时cu_seqlens_q=[0, 70, 370, 550, ... ],如果cu_seqlens_q的形状为$[batch_size+1]$,则需要在最后拼接上序列Q的总长度
    • max_seqlen_q好理解,就是短序列的Q的最长长度

在具体实现中,对每个序列的每个head分别launch kernel,来实现并行计算,这个过程中要通过cu_seqlens_q来确定对应Q的start_idx和end_idx。

参考:

Flash attention变长batching API使用

How did flash-attn compute attention for cu_seqlens #850

参考

图解大模型计算加速系列:FlashAttention V1,从硬件到计算逻辑

优质好文:

[Attention优化][2w字]🔥原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3