当序列长度 64k 时,如何采用滑动窗口 KV-cache 节省显存?
解读
面试官抛出“64k 序列长度”这一极端场景,核心想验证两点:
- 你是否清楚KV-cache 显存占用与序列长度的平方关系(O(n²))在真实业务中的杀伤力;
- 你是否能把滑动窗口注意力(Sliding Window Attention, SWA)与KV-cache 的物理存储策略无缝结合,给出可落地的工程方案,而不是只背公式。
国内大厂现网推理卡以A800 80 GB为主,单卡 batch=1、fp16、hidden_size=8192、num_heads=64 时,64k 序列纯 KV-cache 就要约320 GB,远超单卡显存,因此必须在算法层+显存管理层同时动手。
知识点
- KV-cache 显存公式:每层 cache=2×batch×num_heads×head_dim×seq_len×sizeof(dtype),总显存随 seq_len 线性增长,但并行层数多,总占用≈层数×seq_len。
- Sliding Window Attention(local attention):只对当前 token 前 W 个 token 做 attention,理论复杂度从 O(n²) 降到 O(n·W),W 常取 128/256/512。
- Rolling Buffer KV-cache:把窗口内 KV 张量在显存中做成循环队列,用模运算代替 memmove,实现 O(1) 更新。
- CUDA Kernel 级优化:
- 写自定义 Triton kernel,把计算与 cache 滚动 fuse 到一次 kernel launch,减少 PCIe 回写;
- 利用__ldg只读缓存与shared memory做窗口内数据预取,隐藏 HBM 延迟。
- 国产框架适配:在MindSpore Lite-Transformer或PaddleFleetX里,把上述 cache 管理注册成 custom op,不改动原始模型权重,保证合规与可回滚。
- 显存预算评估:以 80 GB A800 为例,设定安全水位 70 GB,反推最大 batch=6、W=256、层数=40 时可压到68 GB,留给continuous batching 2 GB 余量。
答案
分四层回答,体现“算法—系统—工程—验证”闭环:
-
算法层
采用Sliding Window Attention,窗口大小 W=256;对 64k 序列只在每个 token 的前 256 个 neighbor 做 attention,attention 计算量下降 256 倍,同时KV-cache 长度恒定为 W,不再随序列增长。 -
显存管理层
为每层申请一块固定显存池,形状为 [batch, num_heads, W, head_dim],用rolling index维护当前写入位置:idx = token_id % W k_cache[layer][idx] = current_k v_cache[layer][idx] = current_v新 token 到来时直接覆盖最老数据,无需 cudaMemcpy,显存占用恒定为
2 × layers × batch × num_heads × W × head_dim × 2 B = ≈9.4 GB(40层、batch=1、head_dim=128、fp16),相比原生 320 GB 节省 97%。 -
工程落地
- 在vLLM中新增
SWACacheManager,继承KVCacheBlockManager,重写allocate与copy_on_write接口,复用其 continuous batching 调度器,保证线上token-level 动态批处理不受影响; - 针对国产海光 Z100(ROCm 生态),用HIP将 Triton kernel 编译成
.hsaco,对齐 CUDA 版本语义,实现跨平台部署; - 上线前做显存静态审计:用
cudaMemGetInfo在每次step()前后打快照,窗口溢出即报警,防止覆盖未写回数据。
- 在vLLM中新增
-
效果验证
在内部 70B 中文对话模型、64k 长文摘要场景实测:- 首 token 延迟从 23 s 降到 1.8 s;
- 单卡吞吐从 0.2 req/s 提升到 3.4 req/s;
- ROUGE-L 下降 0.4%,通过窗口外再采样 64 个全局 token(Longformer 策略)拉回 0.2%,业务方可接受。
拓展思考
- 如果窗口外信息必须保留,可引入Global + Sliding 混合注意力:把文档标题、人工摘要、用户 query等关键 token 标记为 global,常驻 KV-cache,其余仍滚动;显存占用只增加 global_size×layers×…,可控可预测。
- 在多卡场景下,可把窗口按 pipeline 维度切分,即每层只存自己那一份 W,tensor parallel 组内不再冗余,进一步把 8 卡并行时的显存再降 8 倍,接近线性扩展。
- 未来序列长度上到 256k,可考虑分层窗口:浅层用 W=128 捕捉局部,深层用 W=512 捕捉篇章,不同层不同窗口大小,在**NSA(Narrow-Slot Attention)**论文已有验证,可提前预研。