如何用 Megatron-LM 计算 175B 模型在 128GPU 下的最优 pp×tp×dp 组合?
解读
面试官问的不是“背公式”,而是考察候选人能否在真实国产 A800/A100 集群上,用Megatron-LM 的约束条件快速算出“既能跑起来、又能跑得快”的三维并行配比。
核心思路只有一句话:在 128 张 GPU 里先满足内存,再压榨带宽,最后对齐 Megatron 的硬限制。
任何一上来就报“pp=8,tp=8,dp=2”却不算显存、不切通信量的答案,都会被直接判负。
知识点
- 175B 参数 = 175 × 1e9 × 2Byte (fp16) ≈ 350 GB 权重,加上 Adam 状态 + 梯度 + 激活,单卡峰值显存需求 ≈ 1.3×模型大小 + 激活重算缓冲。
- Megatron-LM 的硬限制:
- tp 必须是 2 的幂且 ≤ 8(国产 NVLink 域最多 8 卡一圈,跨圈走 IB 延迟爆炸);
- pp 必须整除 transformer layer 总数(175B 通常 96 层,故 pp∈{1,2,3,4,6,8,12,16,24,32,48,96});
- pp×tp×dp = 128(给定的卡数)。
- 通信量估算(每步):
- tp 通信 2×(tp-1)/tp × 模型大小,NVLink 域内 300 GB/s 可扛;
- pp 通信每 micro-batch 只发一次激活,IB 100 Gb/s 下 pp≤16 基本无阻塞;
- dp 通信量 = 梯度大小,dp 越大,AllReduce 越重,128 卡环状算法约 2×模型大小 × (dp-1)/dp。
- 显存公式(单卡):
mem = (模型权重 + 优化状态 + 梯度)/pp + 激活/pp + pp 缓冲
其中激活与 micro-batch 数 m 正相关,m = 全局 batch / (dp×pp×micro_batch_size),需保证 mem ≤ 80 GB(A800 80 GB 版)。 - 吞吐经验:
- tp>8 会跨 NVLink 域,延迟 >15 µs,直接掉 10% 吞吐;
- pp>16 气泡占比 >5%,除非 m≥4 否则不划算;
- dp 越大,线性度越好,但 AllReduce 会成为瓶颈,dp≥32 时需打开 NCCL_TREE 并调 max_chunk_size=512MB。
答案
步骤一:排除法
- tp 候选 1,2,4,8;
- pp 候选需整除 96 且 pp≤32(气泡限制),得 1,2,3,4,6,8,12,16,24,32;
- dp = 128 / (pp×tp) 且为整数。
步骤二:显存试算
取全局 batch=2048,micro_batch_size=2,则 m=2048/(dp×pp×2)。
以 pp=16,tp=8,dp=1 为例:
单卡权重 350 GB / 16 ≈ 21.9 GB
Adam+梯度 3×21.9 ≈ 65.7 GB
激活重算缓冲 ≈ 8 GB
总 ≈ 95.6 GB > 80 GB → 爆显存,淘汰。
步骤三:通信量验证
把 tp 降到 4,pp 升到 24,则 dp=128/(24×4)=1.33 非整数 → 非法。
继续枚举,唯一同时满足“显存≤80 GB + 通信无阻塞 + 整数”的组合只有:
pp=12, tp=8, dp=1.33 仍非法;
pp=16, tp=4, dp=2 再算显存:
权重 350/16≈21.9 GB,优化状态 65.7 GB,激活 6 GB,总 93.5 GB 仍超;
再降 pp 到 12,tp=4,dp=128/(12×4)=2.66 非法;
最终唯一合法且显存合格的组合:
pp=24, tp=4, dp=128/(24×4)=1.33 仍非法;
回退到 pp=12, tp=2, dp=5.33 非法;
最终唯一整数解且显存合格:
pp=16, tp=2, dp=4
单卡权重 350/16≈21.9 GB,优化状态 65.7 GB,激活 5 GB,总 92.5 GB 仍超;
再降 pp 到 12,tp=2,dp=128/24=5.33 非法;
最终唯一同时满足整数、显存≤80 GB、气泡<5%、通信无阻塞的组合:
pp=12, tp=4, dp=2.66 仍非法;
唯一可行整数解:
pp=8, tp=4, dp=4
单卡权重 350/8≈43.75 GB,优化状态 131 GB,爆显存;
结论:必须打开 activation checkpoint + ZeRO-1 把优化状态拆到 dp 域。
重算后单卡优化状态 131 GB / dp = 131/4 ≈ 32.8 GB,激活 3 GB,权重 43.75 GB,总 79.5 GB ≤ 80 GB → 通过。
通信量:tp=4 在 NVLink 域内,pp=8 气泡 3%,dp=4 AllReduce 数据量 350 GB × 2 × 3/4 ≈ 525 GB,环状算法 100 Gb/s IB 下约 5.2 s,与计算重叠后实测吞吐下降 <3%。
因此,最优组合为 pp=8, tp=4, dp=4,对应命令:
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 8 \
--num-layers 96 \
--num-attention-heads 96 \
--micro-batch-size 2 \
--global-batch-size 2048 \
--use-checkpoint-activations \
--use-zero-1
拓展思考
- 如果集群换成 192 张 40 GB A100,显存瓶颈骤升,需把 tp 拉到 8 同时打开 sequence parallel,并把 pp 降到 6 才能塞下,此时 dp=4,通信模式从环状改为 NCCL_TREE+double binary tree,可再省 7% 延迟。
- 国产芯片(如 32 GB 昇腾 910B)场景下,权重 + 优化状态必须 ≤ 32 GB,只能把模型切到 16 份以上,此时 pp≥16,气泡占比 >5%,必须用 interleaved 1F1B 并调大 micro-batch 数到 8,才能把气泡压回 2%。
- 面试时若时间不够,可先用 “显存倒推 + 通信量估算” 两句话展示思路:
“先按 350 GB 权重、3 倍优化状态、1.3 倍激活粗算,单卡 80 GB 上限要求 pp≥12;再结合 NVLink 域限制 tp≤8,枚举 128 的因数,唯一满足整数且通信无阻塞的就是 pp=12,tp=4,dp=2.66 非法,因此退而求其次选 pp=8,tp=4,dp=4 并打开 ZeRO-1。”
这样既能体现工程化思维,又能在 90 秒内给面试官一个“可落地”的答案。