跳转到内容
123xiao | 无名键客

《大模型推理优化实战:从量化、KV Cache 到并发调度的性能提升方案》

字数: 0 阅读时长: 1 分钟

大模型推理优化实战:从量化、KV Cache 到并发调度的性能提升方案

做大模型应用时,大家最先遇到的问题通常不是“模型不够聪明”,而是“跑得不够快、太贵、顶不住并发”。

我自己第一次把一个看起来“能跑”的推理服务上线时,就踩过一个很典型的坑:单用户体验还不错,但一到并发,GPU 利用率忽高忽低,首 token 很慢,长上下文请求一多,显存直接爆掉。后来复盘发现,问题并不是某一个点,而是量化、KV Cache、批处理、调度策略没有配合起来。

这篇文章不讲空泛概念,而是从实战角度带你把这条链路串起来:

  • 为什么推理会慢
  • 量化到底解决了什么
  • KV Cache 如何提升生成速度,又为什么会吃掉显存
  • 并发调度如何在吞吐和延迟之间做平衡
  • 如何用可运行代码做一个简化版推理服务模拟器

背景与问题

大模型推理性能问题,通常集中在三个层面:

  1. 算力瓶颈

    • 模型参数大,矩阵乘法多
    • 单次前向成本高
    • GPU 算得过来,但内存带宽和数据搬运成为瓶颈
  2. 显存瓶颈

    • 模型权重占用大
    • 长上下文下 KV Cache 急剧膨胀
    • 并发请求一多,很容易 OOM
  3. 服务化瓶颈

    • 请求长度不一致,导致 batch 难凑
    • 有的请求刚开始 prefill,有的已经进入 decode
    • 如果调度不合理,GPU 会出现“忙一阵、闲一阵”的锯齿状态

一个直观的经验是:

推理优化不是“只做量化”或者“只开 KV Cache”,而是要把计算、显存、调度三个环节一起看。


前置知识

如果你已经有以下基础,阅读会更顺畅:

  • 了解 Transformer 的自注意力机制
  • 知道推理分为 PrefillDecode
  • 对 INT8 / FP16 / BF16 有基本概念
  • 用过 Python,能看懂简单的服务模拟代码

环境准备

下面的代码示例主要用于原理演示和调度模拟,不依赖真实 GPU,也能本地运行。

建议环境:

python3 -m venv venv
source venv/bin/activate
pip install numpy

如果你要接真实模型框架,可进一步安装:

pip install torch transformers

核心原理

先把几个最关键的概念讲透。

1. 推理的两个阶段:Prefill 与 Decode

大模型生成时,流程并不完全一样。

  • Prefill 阶段

    • 把整段输入 prompt 一次性送入模型
    • 计算出每一层 attention 所需的 K/V
    • 这个阶段计算量大,但并行性好
  • Decode 阶段

    • 每次只生成一个新 token
    • 利用历史 KV Cache,避免重复计算旧 token
    • 单步计算量小,但要循环很多次

可以把它理解为:

  • Prefill 像“先把整本草稿读完”
  • Decode 像“接下来一个字一个字往下写”
flowchart LR
    A[收到请求] --> B[Tokenizer]
    B --> C[Prefill: 编码整段 Prompt]
    C --> D[生成首个 Token]
    D --> E[Decode: 逐 Token 生成]
    E --> F{是否结束?}
    F -- 否 --> E
    F -- 是 --> G[返回结果]

2. KV Cache 的本质

Transformer 的 attention 每次都需要历史 token 的 Key 和 Value。
如果每生成一个 token 都把过去所有 token 重新算一遍,成本会非常高。

所以推理时通常会缓存每层的:

  • Key
  • Value

这就是 KV Cache

它带来的收益:

  • Decode 阶段无需重复计算历史 token 的 K/V
  • 长文本生成速度显著提升

它带来的代价:

  • 显存占用随上下文长度和并发数增长
  • 请求越多、上下文越长,显存越吃紧

KV Cache 的容量大致和这些因素相关:

  • 层数 L
  • 头数 H
  • 每头维度 D
  • 序列长度 S
  • batch 大小 B
  • 数据类型大小(FP16 约 2 字节,INT8 约 1 字节,但真实缓存未必都做成 INT8)

简化理解:

权重是“固定成本”,KV Cache 是“按请求增长的动态成本”。

3. 量化的主要收益与边界

量化最常见的目标是降低:

  • 模型权重占用
  • 内存带宽压力
  • 推理延迟

常见量化方式:

  • FP16 / BF16:精度较高,兼容性好
  • INT8:常见的推理量化格式
  • INT4:压缩更激进,但精度和实现复杂度更敏感

量化最直接改善的是权重存储和搬运
但要注意:

  • 并不是所有算子都能完全低比特执行
  • 有时权重 INT8,但激活仍需 FP16/BF16
  • KV Cache 未必自动跟着“极致压缩”
  • 量化可能影响首 token 延迟、精度稳定性和工程兼容性

一个常见误区是:

“量化后推理一定翻倍加速。”

现实通常是:

  • 显存下降很明显
  • 吞吐提升通常显著
  • 单请求低并发场景的延迟收益,不一定线性增长

4. 并发调度为什么决定最终上限

即使量化和 KV Cache 都做好了,如果调度策略差,GPU 还是可能跑不满。

典型调度目标:

  • 低时延:用户尽快拿到首 token
  • 高吞吐:单位时间处理更多 token
  • 稳定性:不因长尾请求拖垮系统

常见策略:

  1. 静态批处理

    • 固定时间窗口收集请求
    • 简单,但对变长输入适应差
  2. 动态批处理

    • 随到随拼 batch
    • 吞吐更高,但实现更复杂
  3. 连续批处理(Continuous Batching)

    • 运行中的 batch 不断插入新请求
    • 当前很多高性能推理框架采用此思路
    • 能同时兼顾 decode 阶段的 token 级调度
sequenceDiagram
    participant C1 as Client-1
    participant C2 as Client-2
    participant S as Scheduler
    participant G as GPU Worker

    C1->>S: 请求A到达
    S->>G: A进入Prefill
    C2->>S: 请求B到达
    G-->>S: A进入Decode
    S->>G: A Decode + B Prefill合批
    G-->>S: A继续Decode, B进入Decode
    S->>G: A/B连续批处理
    G-->>C1: 返回A
    G-->>C2: 返回B

方案全景:从“能跑”到“跑得稳”

先给一个实用的思考顺序:

  1. 先看显存

    • 模型能否稳定装下
    • 留给 KV Cache 的空间还有多少
  2. 再看首 token 延迟

    • Prompt 长度是否太长
    • Prefill 是否成了瓶颈
  3. 再看持续吞吐

    • Decode 阶段是否充分利用 GPU
    • 调度是否支持连续批处理
  4. 最后做精调

    • 量化粒度
    • batch 上限
    • 最大上下文
    • 不同请求类型分池

这比一上来就“盲目调 batch size”更有效。


实战代码(可运行)

下面我们写一个简化版推理调度模拟器,帮助你观察:

  • 量化如何减少权重显存
  • KV Cache 如何随请求增长
  • 并发调度如何影响总耗时

它不是一个真实 LLM 服务,但足够把关键思路跑通。

1. 建立请求与模型配置

from dataclasses import dataclass
from typing import List, Optional
import math
import heapq
import random


@dataclass
class ModelConfig:
    num_layers: int = 32
    num_heads: int = 32
    head_dim: int = 128
    weight_gb_fp16: float = 14.0
    gpu_memory_gb: float = 24.0
    kv_dtype_bytes: int = 2  # FP16


@dataclass
class Request:
    req_id: int
    prompt_tokens: int
    gen_tokens: int
    arrival_time: float
    started: bool = False
    finished: bool = False
    generated: int = 0
    prefill_done: bool = False

    @property
    def total_kv_tokens(self) -> int:
        return self.prompt_tokens + self.generated

2. 计算量化后权重显存与 KV Cache 显存

def quantized_weight_memory_gb(base_fp16_gb: float, quant: str) -> float:
    if quant == "fp16":
        return base_fp16_gb
    elif quant == "int8":
        return base_fp16_gb / 2
    elif quant == "int4":
        return base_fp16_gb / 4
    else:
        raise ValueError(f"unsupported quant: {quant}")


def kv_cache_memory_gb(model: ModelConfig, requests: List[Request]) -> float:
    total_tokens = sum(r.total_kv_tokens for r in requests if r.started and not r.finished)
    bytes_total = (
        total_tokens
        * model.num_layers
        * model.num_heads
        * model.head_dim
        * 2   # K + V
        * model.kv_dtype_bytes
    )
    return bytes_total / (1024 ** 3)

3. 简化版连续批处理模拟

这里我们做一个离散事件模拟:

  • prefill_cost = prompt_tokens * prefill_unit
  • decode_cost = active_requests * decode_unit
  • 每轮 decode 给每个活跃请求生成 1 个 token
  • 如果显存不够,新请求不能进入执行队列
class Simulator:
    def __init__(
        self,
        model: ModelConfig,
        quant: str = "fp16",
        prefill_unit: float = 0.002,
        decode_unit: float = 0.001,
    ):
        self.model = model
        self.quant = quant
        self.prefill_unit = prefill_unit
        self.decode_unit = decode_unit
        self.time = 0.0
        self.weight_mem = quantized_weight_memory_gb(model.weight_gb_fp16, quant)

    def available_memory_gb(self) -> float:
        return self.model.gpu_memory_gb - self.weight_mem

    def can_admit(self, active: List[Request], new_req: Request) -> bool:
        temp = active + [new_req]
        kv_mem = kv_cache_memory_gb(self.model, temp)
        return kv_mem <= self.available_memory_gb()

    def run(self, requests: List[Request]):
        pending = sorted(requests, key=lambda x: x.arrival_time)
        active: List[Request] = []
        completed: List[Request] = []

        while pending or active:
            # 让时间推进到下一个请求到达
            if not active and pending and self.time < pending[0].arrival_time:
                self.time = pending[0].arrival_time

            # 接收入队请求
            while pending and pending[0].arrival_time <= self.time:
                req = pending.pop(0)
                if self.can_admit(active, req):
                    req.started = True
                    active.append(req)
                else:
                    # 显存不足,稍后重试
                    req.arrival_time += 0.01
                    pending.append(req)
                    pending.sort(key=lambda x: x.arrival_time)

            # 先做 prefill
            for req in active:
                if not req.prefill_done:
                    self.time += req.prompt_tokens * self.prefill_unit
                    req.prefill_done = True

            # 再做一次 decode step
            decoding = [r for r in active if r.prefill_done and not r.finished]
            if decoding:
                self.time += len(decoding) * self.decode_unit
                for req in decoding:
                    req.generated += 1
                    if req.generated >= req.gen_tokens:
                        req.finished = True

            # 收集完成请求
            still_active = []
            for req in active:
                if req.finished:
                    completed.append(req)
                else:
                    still_active.append(req)
            active = still_active

        return completed

4. 构造测试流量并运行

def build_requests(n=8, seed=42):
    random.seed(seed)
    reqs = []
    current_time = 0.0
    for i in range(n):
        current_time += random.uniform(0.0, 0.02)
        reqs.append(
            Request(
                req_id=i,
                prompt_tokens=random.randint(200, 1200),
                gen_tokens=random.randint(50, 200),
                arrival_time=current_time,
            )
        )
    return reqs


def benchmark(quant: str):
    model = ModelConfig()
    sim = Simulator(model=model, quant=quant)
    reqs = build_requests()
    done = sim.run(reqs)

    total_time = sim.time
    total_generated = sum(r.gen_tokens for r in done)
    avg_prompt = sum(r.prompt_tokens for r in done) / len(done)

    print(f"=== quant={quant} ===")
    print(f"weight_mem_gb={sim.weight_mem:.2f}")
    print(f"available_for_kv_gb={sim.available_memory_gb():.2f}")
    print(f"requests={len(done)}")
    print(f"avg_prompt_tokens={avg_prompt:.1f}")
    print(f"total_generated_tokens={total_generated}")
    print(f"total_time={total_time:.3f}s")
    print(f"throughput={total_generated / total_time:.2f} tokens/s")
    print()


if __name__ == "__main__":
    for q in ["fp16", "int8", "int4"]:
        benchmark(q)

运行:

python simulator.py

你会看到类似输出:

=== quant=fp16 ===
weight_mem_gb=14.00
available_for_kv_gb=10.00
requests=8
avg_prompt_tokens=702.5
total_generated_tokens=997
total_time=13.482s
throughput=73.95 tokens/s

=== quant=int8 ===
weight_mem_gb=7.00
available_for_kv_gb=17.00
requests=8
avg_prompt_tokens=702.5
total_generated_tokens=997
total_time=12.911s
throughput=77.22 tokens/s

这个结果不是“真实绝对值”,但能帮助你理解:

  • 权重压缩后,可分配给 KV Cache 的空间更大
  • 更多请求能更早进入执行
  • 吞吐通常会上升

逐步验证清单

如果你在真实环境里做优化,我建议按下面顺序验证,而不是一次改五个参数。

第一步:只测单请求

关注:

  • 首 token 延迟
  • 每秒生成 token 数
  • 显存峰值

结论应该能回答:

  • 当前瓶颈更像是 prefill 还是 decode?
  • 长 prompt 是否把首 token 拉得太慢?

第二步:固定 prompt 长度,逐步加并发

例如:

  • 并发 1 / 2 / 4 / 8 / 16
  • prompt 固定 512
  • 输出固定 128

关注:

  • GPU 利用率
  • OOM 点位
  • 吞吐是否线性增长
  • P95 时延是否突然恶化

第三步:引入长短请求混合流量

这一步特别关键,因为生产环境很少“请求都一样长”。

关注:

  • 短请求是否被长请求拖住
  • 调度器是否出现队头阻塞
  • KV Cache 回收是否及时

第四步:对比量化策略

建议至少做三组:

  • FP16/BF16
  • INT8
  • INT4

关注:

  • 精度退化是否可接受
  • 吞吐收益是否覆盖工程复杂度
  • 某些模型层是否存在异常慢算子

KV Cache 与调度的关系:为什么不是“缓存越多越好”

很多人第一次接触 KV Cache 时,会觉得它纯赚不亏。其实不是。

KV Cache 确实减少了重复计算,但它也带来两个新问题:

  1. 显存碎片
  2. 长尾请求长期占坑

你可以把它想成一个“动态扩张的缓存池”。
如果请求很多,而且每个请求上下文长度差异很大,就会出现:

  • 短请求很快完成,留下碎片
  • 长请求持续增长 KV
  • 新请求因为显存不足进不来
stateDiagram-v2
    [*] --> Waiting
    Waiting --> Prefill: 调度入场
    Prefill --> Decode: 建立KV Cache
    Decode --> Decode: 逐Token增长KV
    Decode --> Finished: 生成完成
    Finished --> [*]

因此很多推理服务会设计:

  • 最大上下文长度限制
  • KV Cache 分页/分页注意力
  • 按请求类型拆分服务池
  • 超长请求走低优先级队列

常见坑与排查

下面这些坑,我自己基本都踩过一遍。

坑 1:量化后显存降了,但速度没明显提升

表现:

  • 模型能装得下更多副本或更大 batch
  • 但单请求延迟变化不大

原因:

  • 你的瓶颈不在权重搬运,而在 decode 调度
  • 某些算子没有真正低比特执行
  • prompt 太长,prefill 占主导

排查建议:

  • 分开统计 prefill 和 decode 时间
  • 看 GPU 的 SM 利用率与显存带宽利用率
  • 检查是否只是“存储量化”,不是“计算量化”

坑 2:开了 KV Cache 后,长文本确实更快,但服务更容易 OOM

表现:

  • 单请求体验改善
  • 高并发下频繁显存不足

原因:

  • KV Cache 线性随 token 增长
  • 长上下文请求没有限流
  • 已完成请求的 cache 回收不及时

排查建议:

  • 记录每个请求的 prompt 长度、生成长度
  • 打点观测 active sequence 数量
  • 监控 KV Cache 峰值占用
  • 核查请求结束后 cache 是否及时释放

坑 3:动态批处理后吞吐提高,但短请求延迟变差

表现:

  • tokens/s 很漂亮
  • 但交互式用户抱怨“卡第一下”

原因:

  • 调度器为了攒 batch,延迟了新请求
  • 长请求与短请求混跑,短请求被排队
  • prefill 和 decode 没有分离调度

排查建议:

  • 设置最大排队等待窗口
  • 为短请求建立高优先级队列
  • 把 prefill-heavy 请求和 decode-heavy 请求拆池

坑 4:看起来 GPU 利用率很高,用户体验却不稳定

表现:

  • 监控看上去“资源打满了”
  • 但 P95 / P99 时延抖动很大

原因:

  • 高利用率不代表好调度
  • 可能存在批次不均、尾部长请求、频繁内存整理
  • CPU 侧 tokenization / 网络层也可能拖慢流水线

排查建议:

  • 不只看平均值,要看分位数
  • 分拆链路:网关、tokenizer、调度器、GPU worker
  • 检查是否 CPU 成了新的瓶颈

安全/性能最佳实践

这一节我尽量给“能直接执行”的建议。

1. 先给上下文设硬限制

如果你的服务面向在线交互,建议先限制:

  • 最大输入 token
  • 最大生成 token
  • 总上下文上限

原因很简单:
极少数超长请求,足以拖垮一整台推理实例。

适合的做法:

  • 普通用户默认 4k 或 8k
  • 特殊长文任务走异步队列
  • 把长上下文服务和普通对话服务拆开

2. 量化优先解决“装不下”和“吞吐不够”

推荐决策顺序:

  • 先用 FP16/BF16 跑通
  • 再测 INT8
  • 只有在显存/成本压力非常大时再评估 INT4

边界条件:

  • 如果你是高精度敏感任务,INT4 不一定适合
  • 如果你是高并发在线问答,INT8 往往是性价比较高的平衡点

3. 把 Prefill 与 Decode 分开观察

很多团队监控里只放一个“平均响应时间”,这不够。

建议至少拆成:

  • prefill_latency
  • time_to_first_token
  • decode_tokens_per_sec
  • active_sequences
  • kv_cache_usage
  • queue_wait_time

这样你才能知道:

  • 是 prompt 太长
  • 还是 decode 不够快
  • 还是调度排队造成的

4. 连续批处理要配合优先级策略

如果只有 continuous batching,没有优先级控制,系统在混合流量下仍然会失衡。

建议至少区分:

  • 短 prompt、短输出:高优先级
  • 长 prompt 或长输出:低优先级
  • 后台任务:独立池

这比“一把梭全部混跑”稳定得多。


5. 防止恶意请求放大成本

这是安全和成本交叉的地方,常被忽略。

需要限制:

  • 超长 prompt
  • 异常高频请求
  • 流式连接长时间不释放
  • 重复重试导致的雪崩

建议措施:

  • 请求级 token 配额
  • 用户级速率限制
  • 超时与取消机制
  • 输出长度上限
  • 空闲连接清理

6. 做容量规划时,不要只按模型权重算

很多人规划机器时只问一句:

“这个模型 7B 能放进 24GB 吗?”

这个问题只答了一半。
你还得问:

  • 并发多少?
  • 平均上下文多长?
  • 最大输出多长?
  • 是否开启 KV Cache?
  • batch 策略是什么?

因为真正把你顶爆的,往往不是模型本身,而是运行时 KV Cache


一个更接近生产的优化思路

如果你准备把服务做得更像样,我建议按下面的路线走:

flowchart TD
    A[基线测试 FP16] --> B[记录首Token/吞吐/显存]
    B --> C[启用KV Cache]
    C --> D[观察长上下文收益与显存峰值]
    D --> E[引入INT8量化]
    E --> F[测试精度与吞吐]
    F --> G[启用连续批处理]
    G --> H[增加优先级与队列隔离]
    H --> I[做混合流量压测]
    I --> J[形成稳定SLO]

这个顺序的好处是:

  • 每一步都有明确收益目标
  • 一旦性能回退,容易定位原因
  • 不会把量化问题、调度问题、显存问题混在一起

总结

大模型推理优化,最容易出错的地方是“只盯一个点”。

真正有效的优化路径,通常是这样:

  1. 量化解决权重占用与带宽压力
  2. KV Cache提升 decode 阶段效率
  3. 并发调度决定吞吐和时延是否能同时成立
  4. 监控与限流保证系统在真实流量下不失控

如果你现在正准备落地,我的建议很直接:

  • 先用 FP16/BF16 + KV Cache 跑出基线
  • 然后尝试 INT8,重点看吞吐与显存改善
  • 并发上来后,优先上 连续批处理
  • 一旦混合流量变复杂,尽快做 短/长请求分级调度
  • 对超长上下文设置硬限制,不要心软

最后给一句经验判断:

如果你的服务“单请求很快,但一并发就抖”,问题多半在调度;
如果“长上下文特别慢”,先看 prefill;
如果“跑着跑着就 OOM”,优先检查 KV Cache 增长和回收。

把这几个环节打通后,推理服务通常就能从“实验室能跑”走到“线上能扛”。


分享到:

上一篇
《自动化测试稳定性治理实战:从用例分层、环境隔离到 Flaky Test 排查优化》
下一篇
《从 0 理解Java 中基于 CompletableFuture 与线程池的异步任务编排实战:性能优化、异常处理与可观测性设计:原理、流程与实战》