大模型推理加速实战:从 KV Cache、量化到连续批处理的性能优化路径
做大模型推理优化,最容易掉进一个坑:只盯着“模型参数量”,却忽略了真正决定线上体验的,往往是 首 Token 延迟、单请求尾延迟、吞吐量、显存占用 这几个指标之间的拉扯。
我第一次系统做这件事时,以为“上更强的 GPU + 开 FP16”就够了,结果上线后发现两个现实问题:
- 用户一多,吞吐立刻掉;
- 长上下文请求一来,显存压力和延迟一起飙升。
这篇文章不讲太空泛的概念,而是按一条更实用的路径走一遍:
- 先理解 KV Cache 为什么能显著降低解码成本;
- 再看 量化 如何在显存和吞吐之间换空间;
- 最后进入更偏工程的 连续批处理(Continuous Batching),把 GPU 真正“喂饱”。
如果你正在做 LLM 服务化、API 网关、RAG 推理节点,或者只是想把本地推理跑快,这条路线很值得掌握。
背景与问题
大模型推理通常分成两个阶段:
- Prefill 阶段:把整段输入 prompt 跑一遍,建立上下文;
- Decode 阶段:每次生成一个新 token,再继续生成下一个。
这两个阶段的瓶颈不一样:
- Prefill 更像大矩阵并行计算,吃算力;
- Decode 更像频繁的小步迭代,容易受显存带宽、调度和 batch 策略影响。
一个常见现象是:
- 短 prompt、短输出:看起来速度还行;
- 长 prompt、多并发:性能迅速恶化。
根因通常有三个:
1. Attention 计算重复太多
如果每生成一个 token,都把历史上下文重新算一遍,代价会越来越高。上下文越长,越慢。
2. 模型太大,显存吃紧
参数、激活值、KV Cache 都要占显存。显存一紧张,batch 放不上去,吞吐就上不来。
3. 请求调度粗糙
很多服务还是“凑整批再算”的静态批处理模式。问题是,LLM 请求长度差异很大,一个慢请求会把整批拖住,形成典型的“木桶短板”。
所以真正可落地的优化顺序,通常不是“上来就魔改模型”,而是:
flowchart LR
A[识别瓶颈] --> B[启用 KV Cache]
B --> C[做权重量化]
C --> D[优化批处理策略]
D --> E[监控 TTFT/TPOT/吞吐/显存]
E --> F[按业务目标迭代]
这里提到几个关键指标:
- TTFT:Time To First Token,首 token 延迟
- TPOT:Time Per Output Token,平均每个输出 token 时间
- Throughput:吞吐量,tokens/s 或 req/s
- Memory Footprint:显存占用
后面我们会围绕这些指标来讲优化路径。
前置知识与环境准备
为了能自己跑一遍示例,建议准备:
- Python 3.10+
- PyTorch 2.x
- CUDA 环境(有 GPU 最好,没有也可以先理解代码)
transformers- 可选:
bitsandbytes用于 8bit / 4bit 量化
安装示例:
pip install torch transformers accelerate
pip install bitsandbytes
如果你在 Apple Silicon 或 CPU 环境上,量化部分可能要稍作调整,但 KV Cache 的原理和代码结构是通用的。
核心原理
这一部分是全文的关键。理解后,你会知道为什么有些优化“看起来只是一个参数”,实际上是在改推理的成本结构。
一、KV Cache:把历史注意力结果存起来
在 Transformer 的自注意力里,每一层都会生成:
- Query(Q)
- Key(K)
- Value(V)
生成第 t 个 token 时,理论上它只需要拿当前 token 的 Q,去和历史所有 token 的 K、V 做注意力计算。
如果不使用 KV Cache,每次生成新 token 都要把历史 token 再经过一遍投影,重复生成 K、V。这种重复工作在长上下文里非常浪费。
用了 KV Cache 之后:
- 旧 token 的 K、V 存起来;
- 新 token 来了,只算它自己的 K、V;
- 再把它拼到缓存里。
于是 Decode 阶段的开销显著下降。
sequenceDiagram
participant U as 用户请求
participant M as 模型
participant C as KV Cache
U->>M: 输入 Prompt
M->>C: 计算并保存历史 K/V
M-->>U: 输出第一个 token
loop 每生成一个 token
U->>M: 继续生成
M->>C: 读取历史 K/V
M->>M: 仅计算当前 token 的 K/V
M->>C: 追加新 K/V
M-->>U: 输出下一个 token
end
KV Cache 的收益与代价
收益:
- 降低 Decode 阶段重复计算;
- 长输出时收益尤其明显;
- 提升 tokens/s。
代价:
- 显存增加:因为每层、每个 head、每个 token 的 K/V 都要存;
- 上下文越长,并发越高,KV Cache 越大。
所以 KV Cache 不是“白嫖优化”,它本质上是 拿显存换速度。
二、量化:拿精度换容量,再换吞吐
量化的目标是把参数从高精度表示变成低精度表示。
常见形式:
- FP16 / BF16:半精度
- INT8:8bit 量化
- INT4:4bit 量化
为什么量化对推理加速有帮助?
1. 参数更小
比如一个 7B 模型:
- FP16:参数大约需要 14GB
- INT8:大约减半
- INT4:还可以继续下降
参数占用变小后,你能做的事就多了:
- 模型能放进单卡;
- 同一张卡能放更大的 batch;
- 给 KV Cache 留出更多显存;
- 减少显存带宽压力。
2. 吞吐通常会上升,但不总是线性
量化不一定永远让延迟变短,因为实际效果还取决于:
- 算子是否有高效实现;
- GPU 架构是否适配;
- 是否引入额外反量化开销;
- batch 是大还是小。
所以量化最核心的工程价值,经常不是“某个算子快了多少”,而是:
同样显存下,我能放进更多请求,整体吞吐更高。
三、连续批处理:动态把 GPU 利用起来
静态批处理的问题是:一批请求必须一起开始、一起结束。
但 LLM 推理里,请求长度差异极大:
- 有人只问一句;
- 有人让模型写 1000 字;
- 有人 prompt 很长;
- 有人 prompt 很短。
如果把这些请求硬塞进一个静态 batch,就会出现:
- 短请求已经完成,但 GPU 还得等长请求;
- 某些 batch slot 空了也不能及时补新请求;
- GPU 利用率不稳定。
连续批处理的思路是:
- 每个推理 step 都重新组织 batch;
- 完成的请求及时移出;
- 新请求随时补进空位;
- 尽量维持 GPU 始终忙碌。
flowchart TB
A[请求进入队列] --> B[调度器]
B --> C{是否有空闲 batch slot}
C -- 有 --> D[加入当前迭代批次]
C -- 无 --> E[继续排队]
D --> F[执行一个 decode step]
F --> G{请求是否结束}
G -- 是 --> H[释放 slot]
G -- 否 --> I[保留在活跃批次]
H --> B
I --> B
这就是很多高性能推理引擎的关键设计之一。它不只是“批处理”,而是 面向 token step 的动态调度。
优化路径:从“能跑”到“跑得稳、跑得快”
如果你现在有一个基础推理服务,我建议按这个顺序优化:
第一步:先打开 KV Cache
适用场景:
- 几乎所有 Decoder-only LLM
- 尤其是长输出、对话场景
优先级非常高,因为它往往是最直接的收益来源。
第二步:做权重量化
适用场景:
- 显存紧张
- 想提高并发数
- 想让更大模型在更小设备上运行
如果你是线上服务,建议优先试:
- BF16 / FP16 作为基线
- INT8 作为稳妥量化方案
- INT4 适合资源更紧的环境,但要仔细验证质量损失
第三步:引入连续批处理
适用场景:
- 多用户并发服务
- 请求长度分布差异大
- GPU 利用率起伏明显
这一项对单请求加速未必最明显,但对 整体吞吐和成本 非常关键。
实战代码(可运行)
下面用一个尽量简洁、能直接运行的示例,演示三件事:
- 如何启用 KV Cache;
- 如何对比量化/非量化加载方式;
- 如何做一个“简化版连续批处理调度器”。
说明:这里以 Hugging Face
transformers为例。真正生产环境往往会用更专业的推理引擎,但这份代码足够帮助你验证原理。
示例一:基线推理与 KV Cache
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_NAME = "distilgpt2" # 演示用小模型,方便快速跑通
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
prompt = "请用简单的话解释什么是 KV Cache:"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
start = time.time()
outputs = model.generate(
**inputs,
max_new_tokens=50,
use_cache=True, # 启用 KV Cache
do_sample=False
)
end = time.time()
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(text)
print(f"Elapsed: {end - start:.3f}s")
这个例子最关键的参数就是:
use_cache=True
在支持的模型中,这会让生成阶段复用历史 K/V,而不是每步都重算历史。
示例二:手动解码,观察 cache 的工作方式
如果你想更明确看到 past_key_values 是怎么流动的,可以自己写解码循环。
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_NAME = "distilgpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
prompt = "The role of KV cache in transformer inference is"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
generated = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
past_key_values = None
max_new_tokens = 30
with torch.no_grad():
start = time.time()
for step in range(max_new_tokens):
if past_key_values is None:
outputs = model(
input_ids=generated,
attention_mask=attention_mask,
use_cache=True
)
else:
outputs = model(
input_ids=generated[:, -1:],
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True
)
logits = outputs.logits[:, -1, :]
next_token = torch.argmax(logits, dim=-1, keepdim=True)
past_key_values = outputs.past_key_values
generated = torch.cat([generated, next_token], dim=-1)
attention_mask = torch.cat(
[attention_mask, torch.ones((attention_mask.shape[0], 1), device=DEVICE, dtype=attention_mask.dtype)],
dim=-1
)
end = time.time()
print(tokenizer.decode(generated[0], skip_special_tokens=True))
print(f"Elapsed: {end - start:.3f}s")
这里有两个观察点:
- 第一次前向会处理完整 prompt;
- 后续每一步只喂最后一个 token,并复用
past_key_values。
这就是 Decode 阶段被加速的根本原因。
示例三:8bit 量化加载
如果环境支持 bitsandbytes,可以试试 8bit 量化。
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
MODEL_NAME = "distilgpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
quant_config = BitsAndBytesConfig(
load_in_8bit=True
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=quant_config,
device_map="auto"
)
model.eval()
prompt = "Explain why quantization can help LLM inference throughput:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
start = time.time()
outputs = model.generate(
**inputs,
max_new_tokens=50,
use_cache=True,
do_sample=False
)
end = time.time()
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(f"Elapsed: {end - start:.3f}s")
注意:
- 小模型上你不一定能看到非常明显的时间优势;
- 但在大模型上,量化带来的显存释放往往能转化成更高 batch、更高并发和更稳定的服务能力。
示例四:一个简化版连续批处理调度器
这不是工业级实现,但能帮助你建立正确的思维模型:按 step 调度,而不是整请求调度。
from collections import deque
class Request:
def __init__(self, req_id, prompt_len, max_new_tokens):
self.req_id = req_id
self.prompt_len = prompt_len
self.max_new_tokens = max_new_tokens
self.generated = 0
self.finished = False
def step(self):
if not self.finished:
self.generated += 1
if self.generated >= self.max_new_tokens:
self.finished = True
class ContinuousBatchScheduler:
def __init__(self, max_batch_size):
self.max_batch_size = max_batch_size
self.waiting_queue = deque()
self.active_batch = []
def submit(self, req):
self.waiting_queue.append(req)
def refill(self):
while len(self.active_batch) < self.max_batch_size and self.waiting_queue:
self.active_batch.append(self.waiting_queue.popleft())
def run_one_step(self):
self.refill()
if not self.active_batch:
return False
print(f"\n[STEP] active={[r.req_id for r in self.active_batch]}")
for req in self.active_batch:
req.step()
print(f" req={req.req_id}, generated={req.generated}/{req.max_new_tokens}, finished={req.finished}")
self.active_batch = [r for r in self.active_batch if not r.finished]
return True
if __name__ == "__main__":
scheduler = ContinuousBatchScheduler(max_batch_size=3)
scheduler.submit(Request("A", prompt_len=20, max_new_tokens=2))
scheduler.submit(Request("B", prompt_len=100, max_new_tokens=5))
scheduler.submit(Request("C", prompt_len=50, max_new_tokens=3))
scheduler.submit(Request("D", prompt_len=10, max_new_tokens=4))
scheduler.submit(Request("E", prompt_len=80, max_new_tokens=1))
while scheduler.run_one_step():
pass
运行后你会看到类似行为:
- 第一轮 A/B/C 进入活跃批次;
- A 或 C 提前结束;
- 空位立刻由 D、E 补上。
这就是连续批处理的本质:让 batch 槽位持续流动,而不是等整批全部结束。
逐步验证清单
如果你准备把这些优化真正落地,建议按下面顺序验证,不要一口气全开。否则出了问题很难定位。
验证 1:只打开 KV Cache
观察:
- TTFT 是否变化不大或略变;
- TPOT 是否明显下降;
- 长输出是否更快。
验证 2:引入量化
观察:
- 显存是否显著下降;
- 模型输出质量是否可接受;
- 单请求延迟是否变化;
- 最大并发是否上升。
验证 3:引入连续批处理
观察:
- GPU 利用率是否更平稳;
- 高峰期吞吐是否增加;
- p95 / p99 延迟是否改善;
- 是否出现短请求被长请求拖慢的问题缓解。
验证 4:做组合测试
要测下面几种组合:
- FP16 + KV Cache
- INT8 + KV Cache
- FP16 + KV Cache + 连续批处理
- INT8 + KV Cache + 连续批处理
很多时候,最佳方案不是“最激进量化”,而是 在质量、吞吐、稳定性之间最平衡的那一组。
常见坑与排查
这一段我尽量写得贴近实战,因为这些坑真的很常见。
坑一:开了 KV Cache,显存却爆了
原因通常是:
- 上下文太长;
- 并发太高;
- batch 太大;
- 模型层数/head 数本身就多。
排查思路:
- 降低
max_new_tokens; - 限制最大上下文长度;
- 降低并发或 batch;
- 打印显存曲线,确认是不是 KV Cache 增长导致。
可以用下面代码查看显存:
import torch
if torch.cuda.is_available():
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
坑二:量化后反而变慢
这很正常,不必惊讶。
可能原因:
- 当前 GPU 对量化算子支持一般;
- 模型太小,量化收益不明显;
- batch 太小,吞吐优势体现不出来;
- 框架存在反量化额外开销。
排查建议:
- 不只看单请求耗时,也看 吞吐和最大并发;
- 对比不同 batch size;
- 在同样质量要求下评估 INT8 与 FP16;
- 如果 INT4 质量掉得明显,先回到 INT8。
坑三:连续批处理后,首 token 延迟变差
这是调度策略引起的常见副作用。
如果调度器过于追求吞吐,可能会:
- 新请求等待太久才被加入;
- Prefill 阶段和 Decode 阶段互相争抢资源。
排查方向:
- 将 Prefill 和 Decode 分开看;
- 给新请求保留一定接入优先级;
- 限制单次调度窗口,避免老请求长期占满。
一句话总结:连续批处理不是只看吞吐,还要兼顾 TTFT。
坑四:不同长度请求混跑,尾延迟很难看
原因:
- 长 prompt 的 Prefill 很重;
- 长输出请求在 decode 阶段停留太久;
- 调度没有做长度感知。
可尝试的办法:
- 按 prompt 长度分桶;
- 按预估输出长度做优先级;
- 将超长请求转入单独队列。
这类优化很像“排队论 + GPU 调度”的结合,不是单纯模型问题。
坑五:线上效果和离线压测差别很大
离线压测常常是规则输入,而线上流量通常更复杂:
- 上下文长度波动大;
- 多租户混部;
- 采样参数不同;
- 工具调用、RAG 检索等引入额外延迟。
解决方式:
- 压测数据尽量模拟真实长度分布;
- 区分纯模型时间与端到端时间;
- 将 TTFT、TPOT、上下文长度、输出长度一起打点。
安全/性能最佳实践
这一节我把“稳定上线”最重要的建议放在一起。
1. 为上下文长度设硬限制
不要让请求无限长。否则:
- KV Cache 无限膨胀;
- 单请求拖垮整机;
- 尾延迟急剧恶化。
建议:
- 设置最大输入 token;
- 超限直接截断或拒绝;
- 对系统 prompt、检索内容做预算分配。
2. 给输出长度设上限
max_new_tokens 一定要限制。
否则某些请求会长期占据 decode slot,影响整批吞吐。对外服务时,最好根据业务类型给不同上限,比如:
- 聊天问答:128~512
- 摘要改写:256~1024
- 代码生成:按场景单独评估
3. 量化前先定义质量红线
不要只看“能不能跑”,还要看:
- 事实性是否下降;
- 代码生成是否退化;
- 中文、多轮对话是否受影响。
建议至少保留一套基线评测集,比较:
- 正确率
- 输出稳定性
- 幻觉率
- 特定业务任务表现
4. 监控必须区分 Prefill 和 Decode
很多系统只看总耗时,这不够。
最好分别监控:
- Prefill latency
- TTFT
- TPOT
- tokens/s
- active batch size
- KV Cache 使用量
- GPU memory / utilization
这样你才能知道问题是在“输入太长”,还是“解码调度不行”。
classDiagram
class Metrics {
+prefill_latency
+ttft
+tpot
+throughput_tokens_per_s
+gpu_utilization
+gpu_memory
+kv_cache_usage
+active_batch_size
}
5. 连续批处理要有公平性策略
如果只追求吞吐,可能让短请求、刚进入的新请求体验很差。
建议至少考虑:
- 新请求最大等待时间;
- 长请求限速或分级;
- 不同租户配额;
- 超长任务降级到低优先级队列。
6. 谨慎清理与复用缓存
KV Cache 很有价值,但也要注意边界:
- 不同请求之间不要错误复用;
- 会话结束要及时释放;
- 多租户场景避免缓存串用导致数据泄漏。
这是性能问题,也是安全问题。
7. 不要迷信单一优化
在真实系统里,最有效的通常不是某一个点,而是组合拳:
- KV Cache 降低重复计算;
- 量化释放显存;
- 连续批处理提高利用率;
- 限长和监控保证稳定性。
如果只做其中一个,收益常常有限;如果组合得好,才会出现比较明显的跃迁。
一个实用的决策表
如果你要快速判断“先做什么”,可以参考这张表。
| 现象 | 主要瓶颈 | 优先优化 |
|---|---|---|
| 长输出越来越慢 | Decode 重复计算 | KV Cache |
| 显存放不下模型或 batch 很小 | 参数占用高 | 量化 |
| 高并发时吞吐差、GPU 忙闲不均 | 调度粗糙 | 连续批处理 |
| 首 token 很慢 | Prefill 重、排队多 | 限长 + 调度优化 |
| p99 延迟很差 | 长请求拖累 | 长度分桶 + 连续批处理 |
总结
把大模型推理优化讲得再复杂,落到工程上,最重要的其实是三件事:
- KV Cache:减少 Decode 阶段重复计算,是最基础也最常见的提速手段;
- 量化:核心价值不只是“更快”,更是“更省显存,从而能支撑更高并发和更大 batch”;
- 连续批处理:解决多请求场景下 GPU 利用率不稳的问题,是服务化部署的关键能力。
如果你问我一个最实用的落地顺序,我会建议:
- 单机先跑通 KV Cache;
- 再评估 INT8/INT4 量化 对质量和显存的影响;
- 最后在并发场景上引入 连续批处理,并配套监控 TTFT、TPOT、吞吐、显存。
边界条件也要记住:
- 显存非常紧时,KV Cache 可能成为新瓶颈;
- 量化并不保证单请求一定更快;
- 连续批处理如果调不好,会牺牲首 token 延迟。
所以,真正成熟的方案不是“某个参数一开就万事大吉”,而是建立一条 可观测、可回退、可分阶段验证 的优化路径。
如果你现在手上正有一个 LLM 服务,最值得立刻做的一件事是:
先把 TTFT、TPOT、上下文长度、输出长度、显存占用打点起来。
没有这些数据,优化只能靠感觉;有了这些数据,KV Cache、量化、连续批处理该怎么排优先级,通常一眼就清楚了。