跳转到内容
123xiao | 无名键客

《大模型推理加速实战:从 KV Cache、量化到连续批处理的性能优化路径》

字数: 0 阅读时长: 1 分钟

大模型推理加速实战:从 KV Cache、量化到连续批处理的性能优化路径

做大模型推理优化,最容易掉进一个坑:只盯着“模型参数量”,却忽略了真正决定线上体验的,往往是 首 Token 延迟、单请求尾延迟、吞吐量、显存占用 这几个指标之间的拉扯。

我第一次系统做这件事时,以为“上更强的 GPU + 开 FP16”就够了,结果上线后发现两个现实问题:

  1. 用户一多,吞吐立刻掉;
  2. 长上下文请求一来,显存压力和延迟一起飙升。

这篇文章不讲太空泛的概念,而是按一条更实用的路径走一遍:

  • 先理解 KV Cache 为什么能显著降低解码成本;
  • 再看 量化 如何在显存和吞吐之间换空间;
  • 最后进入更偏工程的 连续批处理(Continuous Batching),把 GPU 真正“喂饱”。

如果你正在做 LLM 服务化、API 网关、RAG 推理节点,或者只是想把本地推理跑快,这条路线很值得掌握。


背景与问题

大模型推理通常分成两个阶段:

  1. Prefill 阶段:把整段输入 prompt 跑一遍,建立上下文;
  2. Decode 阶段:每次生成一个新 token,再继续生成下一个。

这两个阶段的瓶颈不一样:

  • Prefill 更像大矩阵并行计算,吃算力;
  • Decode 更像频繁的小步迭代,容易受显存带宽、调度和 batch 策略影响。

一个常见现象是:

  • 短 prompt、短输出:看起来速度还行;
  • 长 prompt、多并发:性能迅速恶化。

根因通常有三个:

1. Attention 计算重复太多

如果每生成一个 token,都把历史上下文重新算一遍,代价会越来越高。上下文越长,越慢。

2. 模型太大,显存吃紧

参数、激活值、KV Cache 都要占显存。显存一紧张,batch 放不上去,吞吐就上不来。

3. 请求调度粗糙

很多服务还是“凑整批再算”的静态批处理模式。问题是,LLM 请求长度差异很大,一个慢请求会把整批拖住,形成典型的“木桶短板”。

所以真正可落地的优化顺序,通常不是“上来就魔改模型”,而是:

flowchart LR
    A[识别瓶颈] --> B[启用 KV Cache]
    B --> C[做权重量化]
    C --> D[优化批处理策略]
    D --> E[监控 TTFT/TPOT/吞吐/显存]
    E --> F[按业务目标迭代]

这里提到几个关键指标:

  • TTFT:Time To First Token,首 token 延迟
  • TPOT:Time Per Output Token,平均每个输出 token 时间
  • Throughput:吞吐量,tokens/s 或 req/s
  • Memory Footprint:显存占用

后面我们会围绕这些指标来讲优化路径。


前置知识与环境准备

为了能自己跑一遍示例,建议准备:

  • Python 3.10+
  • PyTorch 2.x
  • CUDA 环境(有 GPU 最好,没有也可以先理解代码)
  • transformers
  • 可选:bitsandbytes 用于 8bit / 4bit 量化

安装示例:

pip install torch transformers accelerate
pip install bitsandbytes

如果你在 Apple Silicon 或 CPU 环境上,量化部分可能要稍作调整,但 KV Cache 的原理和代码结构是通用的。


核心原理

这一部分是全文的关键。理解后,你会知道为什么有些优化“看起来只是一个参数”,实际上是在改推理的成本结构。

一、KV Cache:把历史注意力结果存起来

在 Transformer 的自注意力里,每一层都会生成:

  • Query(Q)
  • Key(K)
  • Value(V)

生成第 t 个 token 时,理论上它只需要拿当前 token 的 Q,去和历史所有 token 的 K、V 做注意力计算。

如果不使用 KV Cache,每次生成新 token 都要把历史 token 再经过一遍投影,重复生成 K、V。这种重复工作在长上下文里非常浪费。

用了 KV Cache 之后:

  • 旧 token 的 K、V 存起来;
  • 新 token 来了,只算它自己的 K、V;
  • 再把它拼到缓存里。

于是 Decode 阶段的开销显著下降。

sequenceDiagram
    participant U as 用户请求
    participant M as 模型
    participant C as KV Cache

    U->>M: 输入 Prompt
    M->>C: 计算并保存历史 K/V
    M-->>U: 输出第一个 token

    loop 每生成一个 token
        U->>M: 继续生成
        M->>C: 读取历史 K/V
        M->>M: 仅计算当前 token 的 K/V
        M->>C: 追加新 K/V
        M-->>U: 输出下一个 token
    end

KV Cache 的收益与代价

收益:

  • 降低 Decode 阶段重复计算;
  • 长输出时收益尤其明显;
  • 提升 tokens/s。

代价:

  • 显存增加:因为每层、每个 head、每个 token 的 K/V 都要存;
  • 上下文越长,并发越高,KV Cache 越大。

所以 KV Cache 不是“白嫖优化”,它本质上是 拿显存换速度


二、量化:拿精度换容量,再换吞吐

量化的目标是把参数从高精度表示变成低精度表示。

常见形式:

  • FP16 / BF16:半精度
  • INT8:8bit 量化
  • INT4:4bit 量化

为什么量化对推理加速有帮助?

1. 参数更小

比如一个 7B 模型:

  • FP16:参数大约需要 14GB
  • INT8:大约减半
  • INT4:还可以继续下降

参数占用变小后,你能做的事就多了:

  • 模型能放进单卡;
  • 同一张卡能放更大的 batch;
  • 给 KV Cache 留出更多显存;
  • 减少显存带宽压力。

2. 吞吐通常会上升,但不总是线性

量化不一定永远让延迟变短,因为实际效果还取决于:

  • 算子是否有高效实现;
  • GPU 架构是否适配;
  • 是否引入额外反量化开销;
  • batch 是大还是小。

所以量化最核心的工程价值,经常不是“某个算子快了多少”,而是:

同样显存下,我能放进更多请求,整体吞吐更高。


三、连续批处理:动态把 GPU 利用起来

静态批处理的问题是:一批请求必须一起开始、一起结束。

但 LLM 推理里,请求长度差异极大:

  • 有人只问一句;
  • 有人让模型写 1000 字;
  • 有人 prompt 很长;
  • 有人 prompt 很短。

如果把这些请求硬塞进一个静态 batch,就会出现:

  • 短请求已经完成,但 GPU 还得等长请求;
  • 某些 batch slot 空了也不能及时补新请求;
  • GPU 利用率不稳定。

连续批处理的思路是:

  • 每个推理 step 都重新组织 batch;
  • 完成的请求及时移出;
  • 新请求随时补进空位;
  • 尽量维持 GPU 始终忙碌。
flowchart TB
    A[请求进入队列] --> B[调度器]
    B --> C{是否有空闲 batch slot}
    C -- 有 --> D[加入当前迭代批次]
    C -- 无 --> E[继续排队]
    D --> F[执行一个 decode step]
    F --> G{请求是否结束}
    G -- 是 --> H[释放 slot]
    G -- 否 --> I[保留在活跃批次]
    H --> B
    I --> B

这就是很多高性能推理引擎的关键设计之一。它不只是“批处理”,而是 面向 token step 的动态调度


优化路径:从“能跑”到“跑得稳、跑得快”

如果你现在有一个基础推理服务,我建议按这个顺序优化:

第一步:先打开 KV Cache

适用场景:

  • 几乎所有 Decoder-only LLM
  • 尤其是长输出、对话场景

优先级非常高,因为它往往是最直接的收益来源。

第二步:做权重量化

适用场景:

  • 显存紧张
  • 想提高并发数
  • 想让更大模型在更小设备上运行

如果你是线上服务,建议优先试:

  • BF16 / FP16 作为基线
  • INT8 作为稳妥量化方案
  • INT4 适合资源更紧的环境,但要仔细验证质量损失

第三步:引入连续批处理

适用场景:

  • 多用户并发服务
  • 请求长度分布差异大
  • GPU 利用率起伏明显

这一项对单请求加速未必最明显,但对 整体吞吐和成本 非常关键。


实战代码(可运行)

下面用一个尽量简洁、能直接运行的示例,演示三件事:

  1. 如何启用 KV Cache;
  2. 如何对比量化/非量化加载方式;
  3. 如何做一个“简化版连续批处理调度器”。

说明:这里以 Hugging Face transformers 为例。真正生产环境往往会用更专业的推理引擎,但这份代码足够帮助你验证原理。


示例一:基线推理与 KV Cache

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "distilgpt2"  # 演示用小模型,方便快速跑通
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

prompt = "请用简单的话解释什么是 KV Cache:"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    start = time.time()
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        use_cache=True,   # 启用 KV Cache
        do_sample=False
    )
    end = time.time()

text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(text)
print(f"Elapsed: {end - start:.3f}s")

这个例子最关键的参数就是:

use_cache=True

在支持的模型中,这会让生成阶段复用历史 K/V,而不是每步都重算历史。


示例二:手动解码,观察 cache 的工作方式

如果你想更明确看到 past_key_values 是怎么流动的,可以自己写解码循环。

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "distilgpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

prompt = "The role of KV cache in transformer inference is"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

generated = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
past_key_values = None

max_new_tokens = 30

with torch.no_grad():
    start = time.time()

    for step in range(max_new_tokens):
        if past_key_values is None:
            outputs = model(
                input_ids=generated,
                attention_mask=attention_mask,
                use_cache=True
            )
        else:
            outputs = model(
                input_ids=generated[:, -1:],
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True
            )

        logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        past_key_values = outputs.past_key_values

        generated = torch.cat([generated, next_token], dim=-1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), device=DEVICE, dtype=attention_mask.dtype)],
            dim=-1
        )

    end = time.time()

print(tokenizer.decode(generated[0], skip_special_tokens=True))
print(f"Elapsed: {end - start:.3f}s")

这里有两个观察点:

  • 第一次前向会处理完整 prompt;
  • 后续每一步只喂最后一个 token,并复用 past_key_values

这就是 Decode 阶段被加速的根本原因。


示例三:8bit 量化加载

如果环境支持 bitsandbytes,可以试试 8bit 量化。

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "distilgpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

quant_config = BitsAndBytesConfig(
    load_in_8bit=True
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,
    device_map="auto"
)
model.eval()

prompt = "Explain why quantization can help LLM inference throughput:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    start = time.time()
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        use_cache=True,
        do_sample=False
    )
    end = time.time()

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(f"Elapsed: {end - start:.3f}s")

注意:

  • 小模型上你不一定能看到非常明显的时间优势;
  • 但在大模型上,量化带来的显存释放往往能转化成更高 batch、更高并发和更稳定的服务能力。

示例四:一个简化版连续批处理调度器

这不是工业级实现,但能帮助你建立正确的思维模型:按 step 调度,而不是整请求调度

from collections import deque

class Request:
    def __init__(self, req_id, prompt_len, max_new_tokens):
        self.req_id = req_id
        self.prompt_len = prompt_len
        self.max_new_tokens = max_new_tokens
        self.generated = 0
        self.finished = False

    def step(self):
        if not self.finished:
            self.generated += 1
            if self.generated >= self.max_new_tokens:
                self.finished = True

class ContinuousBatchScheduler:
    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.waiting_queue = deque()
        self.active_batch = []

    def submit(self, req):
        self.waiting_queue.append(req)

    def refill(self):
        while len(self.active_batch) < self.max_batch_size and self.waiting_queue:
            self.active_batch.append(self.waiting_queue.popleft())

    def run_one_step(self):
        self.refill()

        if not self.active_batch:
            return False

        print(f"\n[STEP] active={[r.req_id for r in self.active_batch]}")

        for req in self.active_batch:
            req.step()
            print(f"  req={req.req_id}, generated={req.generated}/{req.max_new_tokens}, finished={req.finished}")

        self.active_batch = [r for r in self.active_batch if not r.finished]
        return True

if __name__ == "__main__":
    scheduler = ContinuousBatchScheduler(max_batch_size=3)

    scheduler.submit(Request("A", prompt_len=20, max_new_tokens=2))
    scheduler.submit(Request("B", prompt_len=100, max_new_tokens=5))
    scheduler.submit(Request("C", prompt_len=50, max_new_tokens=3))
    scheduler.submit(Request("D", prompt_len=10, max_new_tokens=4))
    scheduler.submit(Request("E", prompt_len=80, max_new_tokens=1))

    while scheduler.run_one_step():
        pass

运行后你会看到类似行为:

  • 第一轮 A/B/C 进入活跃批次;
  • A 或 C 提前结束;
  • 空位立刻由 D、E 补上。

这就是连续批处理的本质:让 batch 槽位持续流动,而不是等整批全部结束。


逐步验证清单

如果你准备把这些优化真正落地,建议按下面顺序验证,不要一口气全开。否则出了问题很难定位。

验证 1:只打开 KV Cache

观察:

  • TTFT 是否变化不大或略变;
  • TPOT 是否明显下降;
  • 长输出是否更快。

验证 2:引入量化

观察:

  • 显存是否显著下降;
  • 模型输出质量是否可接受;
  • 单请求延迟是否变化;
  • 最大并发是否上升。

验证 3:引入连续批处理

观察:

  • GPU 利用率是否更平稳;
  • 高峰期吞吐是否增加;
  • p95 / p99 延迟是否改善;
  • 是否出现短请求被长请求拖慢的问题缓解。

验证 4:做组合测试

要测下面几种组合:

  • FP16 + KV Cache
  • INT8 + KV Cache
  • FP16 + KV Cache + 连续批处理
  • INT8 + KV Cache + 连续批处理

很多时候,最佳方案不是“最激进量化”,而是 在质量、吞吐、稳定性之间最平衡的那一组


常见坑与排查

这一段我尽量写得贴近实战,因为这些坑真的很常见。

坑一:开了 KV Cache,显存却爆了

原因通常是:

  • 上下文太长;
  • 并发太高;
  • batch 太大;
  • 模型层数/head 数本身就多。

排查思路:

  1. 降低 max_new_tokens
  2. 限制最大上下文长度;
  3. 降低并发或 batch;
  4. 打印显存曲线,确认是不是 KV Cache 增长导致。

可以用下面代码查看显存:

import torch

if torch.cuda.is_available():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

坑二:量化后反而变慢

这很正常,不必惊讶。

可能原因:

  • 当前 GPU 对量化算子支持一般;
  • 模型太小,量化收益不明显;
  • batch 太小,吞吐优势体现不出来;
  • 框架存在反量化额外开销。

排查建议:

  • 不只看单请求耗时,也看 吞吐和最大并发
  • 对比不同 batch size;
  • 在同样质量要求下评估 INT8 与 FP16;
  • 如果 INT4 质量掉得明显,先回到 INT8。

坑三:连续批处理后,首 token 延迟变差

这是调度策略引起的常见副作用。

如果调度器过于追求吞吐,可能会:

  • 新请求等待太久才被加入;
  • Prefill 阶段和 Decode 阶段互相争抢资源。

排查方向:

  • 将 Prefill 和 Decode 分开看;
  • 给新请求保留一定接入优先级;
  • 限制单次调度窗口,避免老请求长期占满。

一句话总结:连续批处理不是只看吞吐,还要兼顾 TTFT。


坑四:不同长度请求混跑,尾延迟很难看

原因:

  • 长 prompt 的 Prefill 很重;
  • 长输出请求在 decode 阶段停留太久;
  • 调度没有做长度感知。

可尝试的办法:

  • 按 prompt 长度分桶;
  • 按预估输出长度做优先级;
  • 将超长请求转入单独队列。

这类优化很像“排队论 + GPU 调度”的结合,不是单纯模型问题。


坑五:线上效果和离线压测差别很大

离线压测常常是规则输入,而线上流量通常更复杂:

  • 上下文长度波动大;
  • 多租户混部;
  • 采样参数不同;
  • 工具调用、RAG 检索等引入额外延迟。

解决方式:

  • 压测数据尽量模拟真实长度分布;
  • 区分纯模型时间与端到端时间;
  • 将 TTFT、TPOT、上下文长度、输出长度一起打点。

安全/性能最佳实践

这一节我把“稳定上线”最重要的建议放在一起。

1. 为上下文长度设硬限制

不要让请求无限长。否则:

  • KV Cache 无限膨胀;
  • 单请求拖垮整机;
  • 尾延迟急剧恶化。

建议:

  • 设置最大输入 token;
  • 超限直接截断或拒绝;
  • 对系统 prompt、检索内容做预算分配。

2. 给输出长度设上限

max_new_tokens 一定要限制。

否则某些请求会长期占据 decode slot,影响整批吞吐。对外服务时,最好根据业务类型给不同上限,比如:

  • 聊天问答:128~512
  • 摘要改写:256~1024
  • 代码生成:按场景单独评估

3. 量化前先定义质量红线

不要只看“能不能跑”,还要看:

  • 事实性是否下降;
  • 代码生成是否退化;
  • 中文、多轮对话是否受影响。

建议至少保留一套基线评测集,比较:

  • 正确率
  • 输出稳定性
  • 幻觉率
  • 特定业务任务表现

4. 监控必须区分 Prefill 和 Decode

很多系统只看总耗时,这不够。

最好分别监控:

  • Prefill latency
  • TTFT
  • TPOT
  • tokens/s
  • active batch size
  • KV Cache 使用量
  • GPU memory / utilization

这样你才能知道问题是在“输入太长”,还是“解码调度不行”。

classDiagram
    class Metrics {
      +prefill_latency
      +ttft
      +tpot
      +throughput_tokens_per_s
      +gpu_utilization
      +gpu_memory
      +kv_cache_usage
      +active_batch_size
    }

5. 连续批处理要有公平性策略

如果只追求吞吐,可能让短请求、刚进入的新请求体验很差。

建议至少考虑:

  • 新请求最大等待时间;
  • 长请求限速或分级;
  • 不同租户配额;
  • 超长任务降级到低优先级队列。

6. 谨慎清理与复用缓存

KV Cache 很有价值,但也要注意边界:

  • 不同请求之间不要错误复用;
  • 会话结束要及时释放;
  • 多租户场景避免缓存串用导致数据泄漏。

这是性能问题,也是安全问题。


7. 不要迷信单一优化

在真实系统里,最有效的通常不是某一个点,而是组合拳:

  • KV Cache 降低重复计算;
  • 量化释放显存;
  • 连续批处理提高利用率;
  • 限长和监控保证稳定性。

如果只做其中一个,收益常常有限;如果组合得好,才会出现比较明显的跃迁。


一个实用的决策表

如果你要快速判断“先做什么”,可以参考这张表。

现象主要瓶颈优先优化
长输出越来越慢Decode 重复计算KV Cache
显存放不下模型或 batch 很小参数占用高量化
高并发时吞吐差、GPU 忙闲不均调度粗糙连续批处理
首 token 很慢Prefill 重、排队多限长 + 调度优化
p99 延迟很差长请求拖累长度分桶 + 连续批处理

总结

把大模型推理优化讲得再复杂,落到工程上,最重要的其实是三件事:

  1. KV Cache:减少 Decode 阶段重复计算,是最基础也最常见的提速手段;
  2. 量化:核心价值不只是“更快”,更是“更省显存,从而能支撑更高并发和更大 batch”;
  3. 连续批处理:解决多请求场景下 GPU 利用率不稳的问题,是服务化部署的关键能力。

如果你问我一个最实用的落地顺序,我会建议:

  • 单机先跑通 KV Cache
  • 再评估 INT8/INT4 量化 对质量和显存的影响;
  • 最后在并发场景上引入 连续批处理,并配套监控 TTFT、TPOT、吞吐、显存。

边界条件也要记住:

  • 显存非常紧时,KV Cache 可能成为新瓶颈;
  • 量化并不保证单请求一定更快;
  • 连续批处理如果调不好,会牺牲首 token 延迟。

所以,真正成熟的方案不是“某个参数一开就万事大吉”,而是建立一条 可观测、可回退、可分阶段验证 的优化路径。

如果你现在手上正有一个 LLM 服务,最值得立刻做的一件事是:
先把 TTFT、TPOT、上下文长度、输出长度、显存占用打点起来。
没有这些数据,优化只能靠感觉;有了这些数据,KV Cache、量化、连续批处理该怎么排优先级,通常一眼就清楚了。


分享到:

上一篇
《区块链中智能合约安全审计实战:从常见漏洞识别到自动化检测流程搭建-354》
下一篇
《Java开发踩坑实战:排查并修复 Spring Boot 项目中的循环依赖、配置优先级与 Bean 初始化顺序问题》