目录


1. KV Cache 核心概念

KV Cache 是一种应用于大语言模型推理 (Inference) 阶段的关键优化技术。

Screenshot-[00:22]

  • 1. 应用于推理阶段:

    • KV Cache 不用于训练,仅在模型生成内容时使用。
    • 核心前提是:在自回归生成过程中,对于已经生成的 token,其 K (Key)V (Value) 向量是不会改变的。
  • 2. 存在于 Decoder 结构中:

    • KV Cache 是 Transformer 模型中 Decoder(解码器)架构的一部分。
    • Encoder(编码器)架构的模型(如 BERT)无法使用 KV Cache。
    • 对于 Encoder-Decoder 架构,KV Cache 也只作用于 Decoder 部分。
  • 3. 主要目的:

    • 提升 Attention 计算速度,尤其是在 QKV 的两次矩阵相乘过程中,从而加快 token 的生成速度。 Screenshot-[01:01]
  • 4. 副作用:

    • 增加内存占用。Cache 本质是缓存,需要将大量的 K 和 V 向量存储在 GPU 内存中,这会导致显存消耗显著增加。

2. Self-Attention 计算回顾

首先回顾一下 Attention 机制的核心计算公式:

  • 计算流程:
    1. Query (Q) 与 Key 的转置 (K^T) 进行矩阵相乘,得到初始的注意力分数。
    2. 对分数进行缩放(除以 )。
    3. 应用 Mask(掩码),例如在 Decoder 中防止未来的 token 影响当前 token。
    4. 通过 Softmax 函数将分数归一化为概率分布。
    5. 将归一化后的分数与 Value (V) 矩阵相乘,得到最终的输出。

Screenshot-[01:44]

  • 上下文长度的影响:
    • 在实际应用中,上下文长度 (Context Length) 可能非常长(如 4k, 8k, 甚至 100万)。
    • 这导致 QK 矩阵的行数(序列长度)远大于列数(特征维度 d_model)。
    • 因此,注意力分数矩阵(QK^T)会变成一个巨大的方阵(例如,10k x 10k),这成为计算瓶颈和优化的主要目标。 Screenshot-[02:56]

3. 推理过程中的冗余计算问题

大语言模型的生成过程是自回归 (Auto-regressive) 的,即逐个 token 生成。

  • 生成示例 (不使用 KV Cache):
    • 用户输入: “中华人民”
    • 第1步: 模型输入 “中华人民”,预测出 “共”。
    • 第2步: 模型输入 “中华人民共”,预测出 “和”。
    • 第3步: 模型输入 “中华人民共和”,预测出 “国”。
    • … 以此类推,直到生成结束。

Screenshot-[04:18]

  • 核心问题:

    • 在每一步生成中,模型都会将完整的、已有的序列重新输入并进行完整的 Attention 计算。
    • 例如,在计算第 10 个 token 时,前 9 个 token(如 “中华人民共…“)的 Attention 计算被完全重复了一遍。
    • 这种重复计算随着序列长度的增加而急剧增长,导致推理速度变慢。 Screenshot-[05:17]
  • 矩阵计算视角:

    • 假设输入为 “中华人民”,在预测第 4 个 token “民” 之后的内容时:
    • 不使用 Cache: QKV 都是 4xd_model 的矩阵。计算 QK^T 会得到一个 4x4 的注意力矩阵。
    • 在预测下一个 token 时,输入序列变为 5 个 token,QKV 变为 5xd_model,需要计算一个 5x5 的注意力矩阵。
    • 这个新的 5x5 矩阵的左上角 4x4 部分,其数值与上一步计算的 4x4 矩阵是完全相同的。这就是计算冗余所在。 Screenshot-[09:51]

4. KV Cache 的工作原理与优化

KV Cache 的核心思想是:缓存并重用已经计算过的 Key 和 Value 向量,避免重复计算。

Screenshot-[10:17]

  • 工作流程 (使用 KV Cache):
    1. 第 1 步 (t=1): 输入 “中”,计算出其 K_1, V_1 向量,并将它们存入 Cache 中。
    2. 第 2 步 (t=2): 输入 “华”,只计算 “华” 对应的 Q_2, K_2, V_2
      • Query: 使用最新的 Q_2
      • Key: 从 Cache 中取出 K_1,与新计算的 K_2 拼接成 [K_1, K_2]
      • Value: 从 Cache 中取出 V_1,与新计算的 V_2 拼接成 [V_1, V_2]
      • K_2, V_2 存入 Cache。
    3. 第 3 步 (t=3): 输入 “人”,只计算 “人” 对应的 Q_3, K_3, V_3
      • Query: 使用最新的 Q_3
      • Key: 从 Cache 中取出 [K_1, K_2],与 K_3 拼接成 [K_1, K_2, K_3]
      • Value: 从 Cache 中取出 [V_1, V_2],与 V_3 拼接成 [V_1, V_2, V_3]
      • K_3, V_3 存入 Cache。
    • … 以此类推。

Screenshot-[12:14]

  • 为什么不缓存 Q (Query)?
    • Q 向量代表当前正在处理的 token 的“查询”视角,它需要与所有历史 token(包括它自己)的 K 向量进行交互以计算注意力。
    • 在每一步生成新的 token 时,这个“查询”都是全新的,因此 Q 必须重新计算,不能被缓存。
    • KV 向量代表的是 token 本身的“身份”或“内容”,一旦生成,它们在后续的上下文中保持不变,因此可以被缓存。 Screenshot-[13:00]

5. KV Cache 的性能与内存权衡 (Trade-off)

KV Cache 用空间(内存)换时间(计算速度)

  • 性能提升: 极大地减少了计算量,将 Attention 计算的复杂度从序列长度的二次方级降低到线性级,从而显著提升了推理速度。
  • 内存消耗: 随着序列长度的增加,缓存的 K 和 V 向量会线性增长,导致 GPU 内存占用非常大。

Screenshot-[16:43]

  • 内存占用计算公式:

    • 2: 代表 K 和 V 两个矩阵。
    • batch_size: 批处理大小。
    • sequence_length: 上下文长度。
    • num_layers: 模型的层数(每一层都有自己的 KV Cache)。
    • hidden_dim: 模型的隐藏层维度 (num_heads * head_dim)。
    • precision_bytes: 参数精度占用的字节数(如 FP16 为 2 字节,FP32 为 4 字节)。
  • 实例分析 (Llama 3 7B on A10 GPU):

    • GPU: NVIDIA A10,拥有 24GB 显存。
    • 模型加载: Llama 3 7B 模型(FP16)本身约占 14GB 显存。
    • 可用显存: 剩余约 10GB 用于推理。
    • KV Cache 限制: 这 10GB 显存主要被 KV Cache 占用,理论上最大可支持约 20k token 的上下文长度(在 batch_size=1 的情况下)。
    • 多用户场景: 如果有 2 个并发请求(batch_size=2),最大上下文长度减半至 10k。如果有 100 个并发请求,则骤减至 200 左右。 Screenshot-[18:05]

6. 实践中的 KV Cache

  • Hugging Face 实现:

    • transformers 库中,只需在生成配置或 generate 方法中设置 use_cache=True 即可启用。
    • 性能对比: 生成 1000 个 token 的实验中:
      • 使用 KV Cache: 平均耗时约 1 秒。
      • 不使用 KV Cache: 平均耗时约 56 秒。
      • 速度提升了数十倍。 Screenshot-[21:45]
  • 手动 PyTorch 实现:

    • 创建一个列表或嵌套列表来存储每一层的 K 和 V 缓存。
    • 在生成循环中:
      1. 从当前 token 的 hidden_state 计算出新的 Q_new, K_new, V_new
      2. K_new, V_new 分别拼接到对应层的 K_cache, V_cache 之后。
      3. 使用 Q_new 和更新后的完整 K_cache, V_cache 进行 Attention 计算。
      4. 将更新后的 K_cache, V_cache 传递给下一次循环。 Screenshot-[22:52]

7. 基于 KV Cache 的未来优化方向

KV Cache 的内存瓶颈催生了许多后续的优化技术,主要目标是减少内存占用和计算量。

Screenshot-[23:35]

  • Multi-Head Attention (MHA): 标准的多头注意力机制。
  • Grouped-Query Attention (GQA): 将多个 Q 头分组,共享同一组 K 和 V 头。这显著减少了需要缓存的 K 和 V 向量的数量,从而节省内存。
  • Multi-Query Attention (MQA): GQA 的一个极端形式,所有 Q 头共享唯一的一组 K 和 V 头,最大化地节省内存。
  • Infini-attention: Google 提出的新技术,用于处理无限长的上下文。
  • 量化 (Quantization): 降低 KV Cache 的存储精度(如从 FP16 降至 INT8 或 INT4),以减少内存占用。

AI 总结

KV Cache 是一种针对大语言模型在推理阶段的关键优化技术。其核心原理是通过缓存并重用已经计算过的 Token 的 Key (K) 和 Value (V) 向量,来避免在自回归生成过程中对历史上下文进行重复的 Attention 计算。

这种方法将计算复杂度从序列长度的二次方级降低到线性级,从而极大地提升了 Token 的生成速度。然而,它也带来了显著的副作用:随着上下文序列变长,缓存的数据会线性增长,导致巨大的 GPU 内存(显存)消耗

因此,KV Cache 本质上是一种用“空间换时间”的策略。为了缓解其内存压力,社区发展出了如 GQA (Grouped-Query Attention)、MQA (Multi-Query Attention) 和量化等一系列后续优化技术。理解 KV Cache 是掌握现代大语言模型推理优化的基础。