目录

核心主题与目标

视频主要围绕 Flash Attention 的工作原理展开,旨在解决标准 Transformer Attention 机制在处理长序列时面临的内存占用高计算速度慢的问题。观众看完视频后,应该能够理解 Flash Attention 是如何利用 GPU 硬件特性(特别是 SRAM)进行优化,其核心的 Tiling (分块)Online Softmax 算法,以及它带来的显著性能优势。

关键知识点梳理

1. 标准 Attention 的瓶颈

标准的自注意力机制计算过程如下:

  1. 矩阵相乘: 计算查询 Q 和键 K 的转置 K^T 的乘积,得到注意力分数矩阵 S
  2. 缩放: 将 S 除以维度的平方根 sqrt(d_k)
  3. Softmax: 对缩放后的分数矩阵按行进行 Softmax 归一化。
  4. Dropout: 应用 Dropout 防止过拟合。
  5. 与 V 相乘: 将 Softmax 的结果与值 V 矩阵相乘,得到最终输出 O

Screenshot-[00:16]

瓶颈分析:

  • 内存瓶颈: 在处理长序列 (如 4K, 8K, 甚至 100K) 时,中间产生的 S 矩阵(尺寸为 N x N,N 为序列长度)会变得极其巨大,消耗大量 GPU 内存。
  • 速度瓶颈 (I/O Bound): 整个计算过程中,MaskSoftmaxDropout 等操作都需要从 GPU 的主内存 (HBM) 中读取数据,计算后再写回,这些内存读写操作(I/O)远比计算本身耗时,成为性能瓶颈。

Screenshot-[04:22]

2. GPU 存储器层次结构

理解 Flash Attention 的关键在于了解 GPU 的内存体系。

  • HBM (High-Bandwidth Memory): 我们通常所说的“显存”。
    • 特点: 容量大(如 A100 有 40GB/80GB),带宽较高(约 1.5 TB/s)。
    • 瓶颈: 相比于片上内存,其访问速度依然是瓶颈。Out of Memory (OOM) 错误通常发生在此。
  • SRAM (Static RAM): 片上内存(On-chip Memory)。
    • 特点: 容量极小(如 A100 的每个 SM 仅有 192KB),但速度极快(约 19 TB/s),比 HBM 快约20倍。
    • 组成: GPU 由多个流式多处理器(Streaming Multiprocessor, SM)构成,每个 SM 内部都有自己的 SRAM。

Screenshot-[01:22]

3. Flash Attention 的核心思想

Flash Attention 的核心思想是 Kernel Fusion (内核融合),利用快速但微小的 SRAM 来规避缓慢的 HBM 读写。

它将 Attention 计算中的多个步骤(矩阵乘、Mask、Softmax、与V相乘)融合成一个单一的 GPU kernel。这个 kernel 将计算任务分解成小块,并将这些小块加载到 SRAM 中完成所有中间计算,最后才将最终结果写回 HBM。这样就极大地减少了与 HBM 的数据交换次数,从而实现加速。

Screenshot-[05:05]

4. Tiling (分块) 计算策略

由于 SRAM 容量非常小,无法容纳整个 Q, K, V 矩阵,Flash Attention 采用 Tiling (分块)Blocking 的策略。

  1. Q, K, V 矩阵在序列长度(N)维度上切分成小块 (Block)。
  2. 迭代地将 Q 的一个块 (Block of Q) 和 K 的一个块 (Block of K) 加载到 SRAM 中。
  3. 在 SRAM 内部计算这两个块之间的注意力分数(QK^T)。
  4. 紧接着,在不将中间结果写回 HBM 的情况下,在 SRAM 中直接进行 Softmax 计算,并与从 HBM 加载的相应 V 块相乘。
  5. 得到一个局部的输出结果,并将其累加到 HBM 中的最终输出矩阵 O 中。
  6. 重复此过程,直到遍历完所有块。

Screenshot-[07:39]

分块尺寸选择: 块的大小(B_rB_c)需要精心选择,以最大化利用 SRAM。其大小通常由 SRAM 容量 M 和头维度 d 决定,一个简化的公式是 B_c = M / (4*d),以确保 Q, K, V, O 四个矩阵的块都能装入 SRAM。

Screenshot-[17:49]

5. Tiled Softmax 的挑战与解决方案

Tiling 策略最大的挑战在于 Softmax 的计算。

  • 挑战: Softmax 的计算需要知道整行所有元素的和作为分母。但在分块计算时,每个 kernel 只能看到行的一部分数据,无法直接计算出正确的分母。

Screenshot-[11:32]

  • 解决方案 (Online Softmax): Flash Attention 采用了一种巧妙的“在线”算法来解决这个问题。其核心是:在遍历一行中的所有块时,维护并更新三个关键的统计量:

    1. m_i: 到目前为止,当前行 i 的最大值。
    2. l_i: 到目前为止,当前行 iSoftmax 分母的累加和。
    3. O_i: 到目前为止,当前行 i 的累加输出结果。

    当处理一个新的块时,算法会:

    1. 计算新块的最大值,并更新全局最大值 m_i
    2. 利用代数技巧,根据新的最大值对先前计算出的 l_iO_i 进行重新缩放 (rescale)
    3. 计算新块的 Softmax 相关值,并将其累加到已缩放的 l_iO_i 上。

    这个过程保证了在遍历完所有块后,得到的最终 O_i 与对整行进行标准 Softmax 计算的结果完全一致,且整个过程都在 SRAM 中完成,具有数值稳定性。

Screenshot-[21:31]

6. Flash Attention 完整算法流程 (伪代码解析)

Flash Attention 的算法可以概括为两个嵌套的循环:

Screenshot-[29:46]

  • 外层循环 (Outer Loop): 遍历 KV 矩阵的列块 j
    • K_jV_j 块从 HBM 加载到 SRAM。
  • 内层循环 (Inner Loop): 遍历 Q 矩阵的行块 i
    • Q_i 块从 HBM 加载到 SRAM。
    • On-chip 计算:
      1. S_ij = Q_i * K_j^T
      2. 计算当前块的行最大值 m_ijSoftmax 分子 P_ij
      3. 计算当前块的分母和 l_ij
      4. 更新全局行最大值 m_i_new
      5. 核心更新步骤: 使用 rescaling trick 更新全局分母 l_i 和全局输出 O_i
      • O_i_new = f(O_i_old, l_i_old, m_i_old, P_ij, V_j, ...)
    • 将更新后的 O_i, l_i, m_i 写回 SRAM,准备内层循环的下一次迭代。
  • 内层循环结束后,将最终计算好的 O_i 块从 SRAM 写回 HBM。

这个过程通过在 SRAM 中保留和更新运行状态,避免了 O(N^2) 级别的内存写入。

7. Flash Attention V1 vs V2

  • FlashAttention 1:

    • Tri Dao 在斯坦福大学读博期间提出。
    • 实现了基本的 Tiling 和 Kernel Fusion 思想。
    • 针对前向传播 (Forward Pass) 和反向传播 (Backward Pass) 采用了不同的分块策略(行优先 vs 列优先)来优化。
  • FlashAttention 2:

    • 是一个重大升级,进一步优化了 kernel,减少了非矩阵乘法运算的开销,并改进了线程调度。
    • 性能比 V1 更快。
    • 原生支持 MQA (Multi-Query Attention)GQA (Grouped-Query Attention)
    • 已被 Hugging Face 等主流框架广泛集成。

Screenshot-[40:41]

8. 优缺点分析

优势:

  1. 速度快: 相比标准实现,有数倍的速度提升。
  2. 内存高效: 内存占用从 O(N^2) 降低到 O(N),因为它无需在 HBM 中实例化巨大的 N×N 注意力矩阵,从而支持更长的上下文窗口。
  3. 结果精确: 它不是近似算法。其计算结果与标准 Attention 完全相同,没有精度损失。

Screenshot-[43:54]

缺陷/限制:

  1. 硬件依赖: 对 GPU 架构有特定要求,主要支持 NVIDIA 近代 GPU(如 Ampere, Hopper 架构),对早期 GPU 支持不佳。
  2. 安装复杂: 需要针对特定的 GPU 和 CUDA 版本进行编译,环境配置有时会比较麻烦。
  3. 反向传播的 Recomputation: 为了节省内存,Flash Attention 在反向传播时不存储前向传播的中间注意力矩阵,而是在需要时重新计算。这是一种典型的 “以时间换空间” 的策略,虽然节省了大量内存,但会增加训练的计算量。

AI 总结

Flash Attention 是一种针对 Transformer 模型中注意力机制的高效实现。其核心思想是利用 GPU 内部速度极快但容量很小的 SRAM 存储器,通过 Kernel Fusion (内核融合) 技术,将多个内存密集型操作(如 Softmax)合并到一个计算核心中执行,从而大幅减少与主显存 (HBM) 的数据交换次数。

为解决大矩阵无法放入 SRAM 的问题,它采用 Tiling (分块) 策略,将矩阵切成小块处理。其关键算法创新是 Online Softmax,它允许在只看到部分数据的情况下,通过迭代更新和重新缩放(rescaling)技术,精确地计算出全局 Softmax 的结果。

最终,Flash Attention 在不损失任何计算精度的情况下,实现了显著的速度提升和内存优化(内存占用从 O(N^2) 降至 O(N)),使得大模型能够处理更长的上下文序列。