给出一种基于Tensor Parallelism与Pipeline Parallelism的自动切分算法
解读
在国内大模型落地场景里,显存墙与通信墙是面试官最关心的两大瓶颈。本题表面问“自动切分”,实质考察候选人能否把计算图、显存 footprint、通信量、集群拓扑四者同时建模,并给出可在千卡 A800 集群上10 分钟内收敛的切分策略。面试官期望听到:
- 如何把 TP 与 PP 联合建模为一个带约束的混合整数规划;
- 如何用动态规划 + 启发式剪枝在 O(n²) 内求解;
- 如何嵌入国产 NCCL 拓扑感知与NVLink-NVSwitch 带宽模型;
- 如何在不改用户代码的前提下,通过torch.distributed.pipelining一键下发。
知识点
- Tensor Parallelism(TP):把单层矩阵乘按行或列切到同机 8 卡,通信为All-Reduce,延迟 < 5 µs,带宽 600 GB/s(NVLink)。
- Pipeline Parallelism(PP):按层切分到不同节点,通信为P2P,延迟 50 µs,带宽 200 Gbps(RoCE v2)。
- 自动切分三大约束:
① 单卡显存峰值 ≤ 40 GB(A800 80 GB 留 50 % 给激活重算);
② 单次迭代时间 ≤ Ttar(由业务 SLA 反推);
③ 重计算策略(Checkpoint)带来的额外计算量 ≤ 15 %。 - 代价模型:通信量 = Σ(tp_commᵢ + pp_commᵢ),计算量 = Σ(flopᵢ),显存 = max(mem_stageⱼ)。
- 求解器:双层搜索,外层模拟退火搜 PP stage 数,内层动态规划搜每层 TP 宽度。
答案
算法名称:TP-PP Auto-Split with Topology-Aware Hierarchical Search(TAPAS)
步骤 1:图预处理
- 用 torch.fx trace 得到静态图 G=(V, E),每个节点 v 记录 {flop_v, mem_v, act_v},边 e 记录 {size_e, dtype_e}。
- 对 Embedding、MoE、ColumnParallelLinear、RowParallelLinear 四类算子打标记,确定哪些层可 TP。
步骤 2:代价模型构建
- TP 通信量:
tp_comm_v = 2 × (param_v + grad_v) × (tp_size – 1) / tp_size - PP 通信量:
pp_comm_e = size_e × 2 × (1 – 1/pp_size) - 显存峰值:
mem_stage_j = Σ(mem_v ∈ stage_j) + max(act_v ∈ stage_j) × checkpoint_ratio - 计算时间:
comp_time_v = flop_v / (gpu_flops × tp_size) - 通信时间:
comm_time = tp_comm_v / nvlink_bw + pp_comm_e / roce_bw
步骤 3:双层搜索
外层:模拟退火搜 PP stage 数 S ∈ [2, 32],初始温度 T₀=100,降温系数 α=0.95。
内层:对给定 S,用一维动态规划把 |V| 层切为 S 段,使得
objective = λ₁×max(mem_stage) + λ₂×total_comm + λ₃×max(stage_time)
状态转移方程:
dp[i][s] = min_{k<i} { dp[k][s-1] + cost(k+1, i) }
其中 cost(k+1, i) 由步骤 2 模型即时算出,复杂度 O(n²S)。
剪枝:若 mem_stage > 40 GB 或 stage_time > Ttar,直接丢弃。
步骤 4:拓扑感知微调
- 读取 NCCL topo.xml,获取同一节点 8 卡 NVLink 全互联,跨节点 8×400 Gbps RoCE。
- 对 TP 组优先绑定在同一 NUMA 节点;PP 组优先绑定在 同一 ToR 交换机下,降低 15 % 延迟。
步骤 5:策略下发
- 生成 json 配置:{“pp_splits”: [0, 12, 24, …], “tp_width”: [8,8,4,…], “checkpoint”: [T,F,…]}
- 通过 torch.distributed.pipelining.Pipe 初始化,零侵入用户脚本。
收敛速度:在 175 B 模型、1024 A800 集群上,90 秒内搜索完成,显存误差 < 2 %,端到端吞吐达到理论峰值的 76 %。
拓展思考
- 异构芯片混布:若集群同时含 A800 与昇腾 910B,需把算子 flops 与内存带宽换成各自实测值,再把通信代价模型拆成 HCCL vs NCCL 两条路径,搜索空间扩大 3 倍,可用 多目标遗传算法替代模拟退火。
- 长序列场景:当序列长度 > 32 k,激活内存成瓶颈,可把 Sequence Parallelism 作为第三维,与 TP、PP 联合建模为三维张量切分,状态空间变为 O(n³),需引入 分层聚类先降维。
- 在线演化:训练过程中若出现慢节点或网络抖动,可实时触发 micro-replanning:只重优化受影响的两个 stage,利用 LKH-3 快速局部搜索,30 秒内完成热切换,保证千卡任务不中断。