818 words
4 minutes
【八股】KV Cache
1. 啥是 KV cache
KV cache 是 Transformer 在自回归生成(decoder-only 或 encoder-decoder 的 decoder 侧)做推理加速的核心机制:把历史 token 在每一层注意力里算出来的 Key/Value 张量缓存起来,后续每生成一个新 token 时直接复用这些 K/V,而不是把整段上下文再前向一遍去重算一遍,是一种空间换时间的优化。
2. KV cache 是怎么来的
既然 KV cache 是 decoder 推理时的优化,我们不妨看看优化前的推理过程有哪些不足的地方。
自回归解码(decoder-only 或 decoder 部分)在第 步:给定前缀 ,需要计算当前位置的表示并预测 。
在某一层 self-attention(先忽略多头,或把它理解成对每个 head 独立做同样的事),对该层输入隐状态矩阵
做三次线性投影:
注意力输出为:
其中 是 causal mask(上三角为 ),保证只能看历史。
不加 KV cache 的问题在于当从 生成到 时,历史部分 在同一条生成轨迹里并不会变化,但你仍会在每一层、每一步重复计算它们对应的 和 (以及相关张量拼接),这就是大量冗余。
KV cache 的想法是只缓存未来还会反复被用到的量,也就是历史 token 的 。
把第 步拆开写:
- 只对新 token 的隐状态 投影:
- 维护缓存(逐步 append):
- 第 步输出(只需要当前 query 对全历史 keys 做匹配):
那为什么只对 进行缓存而不缓存 ?
因为第 步只用到当前 去查历史 ,过去的 不会在后续步骤再被用到,因此缓存它没有复用价值。
3. 空间占用分析
这部分参考了 satsuki26681534 的 【大模型】什么是KV cache
可以这样算 KV cache 所占用的空间:
举个栗子:
层数 L = 32隐藏维度 d = 4096头数 h = 32每个头维度 d_h = 128数据类型: float16 (2字节)
每个token每层缓存大小 = 2 × (h × d_h) × 2字节 = 2 × 4096 × 2 = 16,384字节 ≈ 16KB
每个token总缓存 = 16KB × 32层 = 512KB
生成1000个token的缓存 = 512KB × 1000 = 512MB4. 代码实现(带 KV cache 的 MHA)
下面代码在之前笔记中的 MHA 代码基础上修改。
class MultiHeadAttention(nn.Module): def __init__(self, input_dim, num_heads, dim_qk, dim_v): super().__init__() self.num_heads = num_heads self.head_dim_qk = dim_qk // num_heads self.head_dim_v = dim_v // num_heads
self.q = nn.Linear(input_dim, dim_qk) self.k = nn.Linear(input_dim, dim_qk) self.v = nn.Linear(input_dim, dim_v)
self.scale = sqrt(self.head_dim_qk)
self.out = nn.Linear(dim_v, input_dim)
#新增,初始化 cache self.cache_k = None self.cache_v = None
def forward(self, x, use_cache = False): batch, seq = x.shape[:2]
q = self.q(x) k = self.k(x) v = self.v(x)
q = q.view(batch, seq, self.num_heads, self.head_dim_qk).transpose(1, 2) k = k.view(batch, seq, self.num_heads, self.head_dim_qk).transpose(1, 2) v = v.view(batch, seq, self.num_heads, self.head_dim_v).transpose(1, 2)
#新增,如果使用缓存且缓存非空则拼接 KV if use_cache and self.cache_k is not None: k = torch.cat([self.cache_k, k], dim = -2) v = torch.cat([self.cache_v, v], dim = -2) #新增,更新 cache if use_cache: self.cache_k = k self.cache_v = v
scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale weights = torch.softmax(scores, dim = -1)
out = torch.matmul(weights, v) out = out.transpose(1, 2).contiguous().view(batch, seq, -1) out = self.out(out)
return out 【八股】KV Cache
https://fuwari.vercel.app/posts/note/kvcache/