给出一种基于Tensor Parallelism与Pipeline Parallelism的自动切分算法

解读

在国内大模型落地场景里,显存墙通信墙是面试官最关心的两大瓶颈。本题表面问“自动切分”,实质考察候选人能否把计算图、显存 footprint、通信量、集群拓扑四者同时建模,并给出可在千卡 A800 集群10 分钟内收敛的切分策略。面试官期望听到:

  1. 如何把 TP 与 PP 联合建模为一个带约束的混合整数规划
  2. 如何用动态规划 + 启发式剪枝在 O(n²) 内求解;
  3. 如何嵌入国产 NCCL 拓扑感知NVLink-NVSwitch 带宽模型
  4. 如何在不改用户代码的前提下,通过torch.distributed.pipelining一键下发。

知识点

  • Tensor Parallelism(TP):把单层矩阵乘按行或列切到同机 8 卡,通信为All-Reduce,延迟 < 5 µs,带宽 600 GB/s(NVLink)。
  • Pipeline Parallelism(PP):按层切分到不同节点,通信为P2P,延迟 50 µs,带宽 200 Gbps(RoCE v2)。
  • 自动切分三大约束
    ① 单卡显存峰值 ≤ 40 GB(A800 80 GB 留 50 % 给激活重算);
    ② 单次迭代时间 ≤ Ttar(由业务 SLA 反推);
    ③ 重计算策略(Checkpoint)带来的额外计算量 ≤ 15 %。
  • 代价模型:通信量 = Σ(tp_commᵢ + pp_commᵢ),计算量 = Σ(flopᵢ),显存 = max(mem_stageⱼ)。
  • 求解器:双层搜索,外层模拟退火搜 PP stage 数,内层动态规划搜每层 TP 宽度。

答案

算法名称:TP-PP Auto-Split with Topology-Aware Hierarchical Search(TAPAS)

步骤 1:图预处理

  • torch.fx trace 得到静态图 G=(V, E),每个节点 v 记录 {flop_v, mem_v, act_v},边 e 记录 {size_e, dtype_e}。
  • Embedding、MoE、ColumnParallelLinear、RowParallelLinear 四类算子打标记,确定哪些层可 TP。

步骤 2:代价模型构建

  • TP 通信量:
    tp_comm_v = 2 × (param_v + grad_v) × (tp_size – 1) / tp_size
  • PP 通信量:
    pp_comm_e = size_e × 2 × (1 – 1/pp_size)
  • 显存峰值:
    mem_stage_j = Σ(mem_v ∈ stage_j) + max(act_v ∈ stage_j) × checkpoint_ratio
  • 计算时间:
    comp_time_v = flop_v / (gpu_flops × tp_size)
  • 通信时间:
    comm_time = tp_comm_v / nvlink_bw + pp_comm_e / roce_bw

步骤 3:双层搜索
外层:模拟退火搜 PP stage 数 S ∈ [2, 32],初始温度 T₀=100,降温系数 α=0.95。
内层:对给定 S,用一维动态规划把 |V| 层切为 S 段,使得
objective = λ₁×max(mem_stage) + λ₂×total_comm + λ₃×max(stage_time)
状态转移方程:
dp[i][s] = min_{k<i} { dp[k][s-1] + cost(k+1, i) }
其中 cost(k+1, i) 由步骤 2 模型即时算出,复杂度 O(n²S)。
剪枝:若 mem_stage > 40 GB 或 stage_time > Ttar,直接丢弃。

步骤 4:拓扑感知微调

  • 读取 NCCL topo.xml,获取同一节点 8 卡 NVLink 全互联,跨节点 8×400 Gbps RoCE。
  • 对 TP 组优先绑定在同一 NUMA 节点;PP 组优先绑定在 同一 ToR 交换机下,降低 15 % 延迟。

步骤 5:策略下发

  • 生成 json 配置:{“pp_splits”: [0, 12, 24, …], “tp_width”: [8,8,4,…], “checkpoint”: [T,F,…]}
  • 通过 torch.distributed.pipelining.Pipe 初始化,零侵入用户脚本。

收敛速度:在 175 B 模型、1024 A800 集群上,90 秒内搜索完成,显存误差 < 2 %,端到端吞吐达到理论峰值的 76 %

拓展思考

  1. 异构芯片混布:若集群同时含 A800 与昇腾 910B,需把算子 flops 与内存带宽换成各自实测值,再把通信代价模型拆成 HCCL vs NCCL 两条路径,搜索空间扩大 3 倍,可用 多目标遗传算法替代模拟退火。
  2. 长序列场景:当序列长度 > 32 k,激活内存成瓶颈,可把 Sequence Parallelism 作为第三维,与 TP、PP 联合建模为三维张量切分,状态空间变为 O(n³),需引入 分层聚类先降维。
  3. 在线演化:训练过程中若出现慢节点网络抖动,可实时触发 micro-replanning:只重优化受影响的两个 stage,利用 LKH-3 快速局部搜索,30 秒内完成热切换,保证千卡任务不中断。