目录
- 1. KV Cache 核心概念
- 2. Self-Attention 计算回顾
- 3. 推理过程中的冗余计算问题
- 4. KV Cache 的工作原理与优化
- 5. KV Cache 的性能与内存权衡 (Trade-off)
- 6. 实践中的 KV Cache
- 7. 基于 KV Cache 的未来优化方向
- AI 总结
1. KV Cache 核心概念
KV Cache 是一种应用于大语言模型推理 (Inference) 阶段的关键优化技术。
Screenshot-[00:22]
-
1. 应用于推理阶段:
- KV Cache 不用于训练,仅在模型生成内容时使用。
- 核心前提是:在自回归生成过程中,对于已经生成的 token,其 K (Key) 和 V (Value) 向量是不会改变的。
-
2. 存在于 Decoder 结构中:
- KV Cache 是 Transformer 模型中 Decoder(解码器)架构的一部分。
- 纯 Encoder(编码器)架构的模型(如 BERT)无法使用 KV Cache。
- 对于 Encoder-Decoder 架构,KV Cache 也只作用于 Decoder 部分。
-
3. 主要目的:
- 提升 Attention 计算速度,尤其是在
QKV的两次矩阵相乘过程中,从而加快 token 的生成速度。 Screenshot-[01:01]
- 提升 Attention 计算速度,尤其是在
-
4. 副作用:
- 增加内存占用。Cache 本质是缓存,需要将大量的 K 和 V 向量存储在 GPU 内存中,这会导致显存消耗显著增加。
2. Self-Attention 计算回顾
首先回顾一下 Attention 机制的核心计算公式:
- 计算流程:
- Query (
Q) 与 Key 的转置 (K^T) 进行矩阵相乘,得到初始的注意力分数。 - 对分数进行缩放(除以 )。
- 应用 Mask(掩码),例如在 Decoder 中防止未来的 token 影响当前 token。
- 通过
Softmax函数将分数归一化为概率分布。 - 将归一化后的分数与 Value (
V) 矩阵相乘,得到最终的输出。
- Query (
Screenshot-[01:44]
- 上下文长度的影响:
- 在实际应用中,上下文长度 (
Context Length) 可能非常长(如 4k, 8k, 甚至 100万)。 - 这导致
Q和K矩阵的行数(序列长度)远大于列数(特征维度d_model)。 - 因此,注意力分数矩阵(
QK^T)会变成一个巨大的方阵(例如,10k x 10k),这成为计算瓶颈和优化的主要目标。 Screenshot-[02:56]
- 在实际应用中,上下文长度 (
3. 推理过程中的冗余计算问题
大语言模型的生成过程是自回归 (Auto-regressive) 的,即逐个 token 生成。
- 生成示例 (不使用 KV Cache):
- 用户输入: “中华人民”
- 第1步: 模型输入 “中华人民”,预测出 “共”。
- 第2步: 模型输入 “中华人民共”,预测出 “和”。
- 第3步: 模型输入 “中华人民共和”,预测出 “国”。
- … 以此类推,直到生成结束。
Screenshot-[04:18]
-
核心问题:
- 在每一步生成中,模型都会将完整的、已有的序列重新输入并进行完整的 Attention 计算。
- 例如,在计算第 10 个 token 时,前 9 个 token(如 “中华人民共…“)的 Attention 计算被完全重复了一遍。
- 这种重复计算随着序列长度的增加而急剧增长,导致推理速度变慢。 Screenshot-[05:17]
-
矩阵计算视角:
- 假设输入为 “中华人民”,在预测第 4 个 token “民” 之后的内容时:
- 不使用 Cache:
Q、K、V都是 4xd_model的矩阵。计算QK^T会得到一个 4x4 的注意力矩阵。 - 在预测下一个 token 时,输入序列变为 5 个 token,
Q、K、V变为 5xd_model,需要计算一个 5x5 的注意力矩阵。 - 这个新的 5x5 矩阵的左上角 4x4 部分,其数值与上一步计算的 4x4 矩阵是完全相同的。这就是计算冗余所在。 Screenshot-[09:51]
4. KV Cache 的工作原理与优化
KV Cache 的核心思想是:缓存并重用已经计算过的 Key 和 Value 向量,避免重复计算。
Screenshot-[10:17]
- 工作流程 (使用 KV Cache):
- 第 1 步 (t=1): 输入 “中”,计算出其
K_1,V_1向量,并将它们存入 Cache 中。 - 第 2 步 (t=2): 输入 “华”,只计算 “华” 对应的
Q_2,K_2,V_2。- Query: 使用最新的
Q_2。 - Key: 从 Cache 中取出
K_1,与新计算的K_2拼接成[K_1, K_2]。 - Value: 从 Cache 中取出
V_1,与新计算的V_2拼接成[V_1, V_2]。 - 将
K_2,V_2存入 Cache。
- Query: 使用最新的
- 第 3 步 (t=3): 输入 “人”,只计算 “人” 对应的
Q_3,K_3,V_3。- Query: 使用最新的
Q_3。 - Key: 从 Cache 中取出
[K_1, K_2],与K_3拼接成[K_1, K_2, K_3]。 - Value: 从 Cache 中取出
[V_1, V_2],与V_3拼接成[V_1, V_2, V_3]。 - 将
K_3,V_3存入 Cache。
- Query: 使用最新的
- … 以此类推。
- 第 1 步 (t=1): 输入 “中”,计算出其
Screenshot-[12:14]
- 为什么不缓存 Q (Query)?
Q向量代表当前正在处理的 token 的“查询”视角,它需要与所有历史 token(包括它自己)的K向量进行交互以计算注意力。- 在每一步生成新的 token 时,这个“查询”都是全新的,因此
Q必须重新计算,不能被缓存。 K和V向量代表的是 token 本身的“身份”或“内容”,一旦生成,它们在后续的上下文中保持不变,因此可以被缓存。 Screenshot-[13:00]
5. KV Cache 的性能与内存权衡 (Trade-off)
KV Cache 用空间(内存)换时间(计算速度)。
- 性能提升: 极大地减少了计算量,将 Attention 计算的复杂度从序列长度的二次方级降低到线性级,从而显著提升了推理速度。
- 内存消耗: 随着序列长度的增加,缓存的 K 和 V 向量会线性增长,导致 GPU 内存占用非常大。
Screenshot-[16:43]
-
内存占用计算公式:
2: 代表 K 和 V 两个矩阵。batch_size: 批处理大小。sequence_length: 上下文长度。num_layers: 模型的层数(每一层都有自己的 KV Cache)。hidden_dim: 模型的隐藏层维度 (num_heads*head_dim)。precision_bytes: 参数精度占用的字节数(如 FP16 为 2 字节,FP32 为 4 字节)。
-
实例分析 (Llama 3 7B on A10 GPU):
- GPU: NVIDIA A10,拥有 24GB 显存。
- 模型加载: Llama 3 7B 模型(FP16)本身约占 14GB 显存。
- 可用显存: 剩余约 10GB 用于推理。
- KV Cache 限制: 这 10GB 显存主要被 KV Cache 占用,理论上最大可支持约 20k token 的上下文长度(在
batch_size=1的情况下)。 - 多用户场景: 如果有 2 个并发请求(
batch_size=2),最大上下文长度减半至 10k。如果有 100 个并发请求,则骤减至 200 左右。 Screenshot-[18:05]
6. 实践中的 KV Cache
-
Hugging Face 实现:
- 在
transformers库中,只需在生成配置或generate方法中设置use_cache=True即可启用。 - 性能对比: 生成 1000 个 token 的实验中:
- 使用 KV Cache: 平均耗时约 1 秒。
- 不使用 KV Cache: 平均耗时约 56 秒。
- 速度提升了数十倍。 Screenshot-[21:45]
- 在
-
手动 PyTorch 实现:
- 创建一个列表或嵌套列表来存储每一层的 K 和 V 缓存。
- 在生成循环中:
- 从当前 token 的
hidden_state计算出新的Q_new,K_new,V_new。 - 将
K_new,V_new分别拼接到对应层的K_cache,V_cache之后。 - 使用
Q_new和更新后的完整K_cache,V_cache进行 Attention 计算。 - 将更新后的
K_cache,V_cache传递给下一次循环。 Screenshot-[22:52]
- 从当前 token 的
7. 基于 KV Cache 的未来优化方向
KV Cache 的内存瓶颈催生了许多后续的优化技术,主要目标是减少内存占用和计算量。
Screenshot-[23:35]
- Multi-Head Attention (MHA): 标准的多头注意力机制。
- Grouped-Query Attention (GQA): 将多个 Q 头分组,共享同一组 K 和 V 头。这显著减少了需要缓存的 K 和 V 向量的数量,从而节省内存。
- Multi-Query Attention (MQA): GQA 的一个极端形式,所有 Q 头共享唯一的一组 K 和 V 头,最大化地节省内存。
- Infini-attention: Google 提出的新技术,用于处理无限长的上下文。
- 量化 (Quantization): 降低 KV Cache 的存储精度(如从 FP16 降至 INT8 或 INT4),以减少内存占用。
AI 总结
KV Cache 是一种针对大语言模型在推理阶段的关键优化技术。其核心原理是通过缓存并重用已经计算过的 Token 的 Key (K) 和 Value (V) 向量,来避免在自回归生成过程中对历史上下文进行重复的 Attention 计算。
这种方法将计算复杂度从序列长度的二次方级降低到线性级,从而极大地提升了 Token 的生成速度。然而,它也带来了显著的副作用:随着上下文序列变长,缓存的数据会线性增长,导致巨大的 GPU 内存(显存)消耗。
因此,KV Cache 本质上是一种用“空间换时间”的策略。为了缓解其内存压力,社区发展出了如 GQA (Grouped-Query Attention)、MQA (Multi-Query Attention) 和量化等一系列后续优化技术。理解 KV Cache 是掌握现代大语言模型推理优化的基础。