如果你在调整vLLM的--max-num-seqs参数,或者发现并发请求一多系统就OOM,或者不理解为什么输入越长服务越容易崩——这篇文章解释背后发生了什么。
KV Cache是大模型推理里最重要的工程机制。不理解它,你就没法真正理解推理系统的性能瓶颈在哪里,也没法做出正确的配置和扩容决策。
从注意力计算说起
大模型在生成每个新token时,需要"看到"之前所有已经生成的token。具体的计算方式是注意力机制:对每个已有token,计算它的Key(K)和Value(V)向量,然后用当前token的Query(Q)和所有K做点积,得到注意力权重,最后用这些权重对所有V加权求和。
这里有个关键点:已经生成的token,它的K和V向量是固定的,不会变。每次生成新token,这些K和V都需要重新使用,但值本身不需要重新计算。
如果不缓存这些K和V:
# 没有KV Cache的朴素生成(概念示意)defgenerate_without_cache(model,prompt_tokens,max_new_tokens):generated=list(prompt_tokens)forstepinrange(max_new_tokens):# 每一步都把完整序列输入模型# 模型内部每一层都重新计算所有token的K和V# 序列每长一个token,这一步的计算量就多一份all_tokens=torch.tensor([generated])logits=model(all_tokens)next_token=logits[0,-1].argmax()generated.append(next_token.item())returngenerated# 如果prompt是1000个token,要生成200个token:# 第1步:处理1001个token的完整序列# 第2步:处理1002个token的完整序列# ...# 第200步:处理1200个token的完整序列# 总计算量 ∝ 1001 + 1002 + ... + 1200 ≈ 220,100次token计算# 而实际上1000个prompt token的K和V计算了200次,完全是浪费有了KV Cache:
# 有KV Cache的生成(概念示意)defgenerate_with_cache(model,prompt_tokens,max_new_tokens):# Prefill阶段:处理完整prompt,计算并缓存所有K和Vlogits,kv_cache=model.prefill(prompt_tokens)generated=[]last_token=logits[-1].argmax().item()generated.append(last_token)forstepinrange(max_new_tokens-1):# Decode阶段:只输入最新的一个token# 从缓存直接读取之前所有token的K和V# 只计算新token自己的K和V,然后追加到缓存logits,kv_cache=model.decode(token=last_token,kv_cache=kv_cache# 直接复用)last_token=logits[-1].argmax().item()generated.append(last_token)returngenerated# 计算量:# Prefill:1000个token,计算一次# 每个Decode步骤:只计算1个新token的K和V# 总计算量 ≈ 1000 + 200 × 1 = 1200次token计算# 节省了约99%的重复计算这就是KV Cache的本质:用显存换计算。把已经算好的K和V存在显存里,避免重复计算。
KV Cache占多大显存?这才是工程上的关键
KV Cache的大小是可以精确计算的:
defcompute_kv_cache_size(num_layers:int,num_kv_heads:int,# 注意:是KV heads,不是总heads数head_dim:int,sequence_length:int,dtype_bytes:int=2,# FP16=2字节,INT8=1字节batch_size:int=1)->dict:""" 计算KV Cache的显存占用 公式: KV Cache大小 = 2 × 层数 × KV头数 × 头维度 × 序列长度 × 数据类型字节数 × 批大小 (2是因为K和V各一份) """bytes_per_token=(2*# K和Vnum_layers*num_kv_heads*head_dim*dtype_bytes)total_bytes=bytes_per_token*sequence_length*batch_sizereturn{"bytes_per_token":bytes_per_token,"bytes_per_token_human":f"{bytes_per_token}bytes/token","total_bytes":total_bytes,"total_mb":round(total_bytes/1024**2,2),