很多团队在模型尺寸上来后,第一反应就是把Tensor Parallel打开。参数一拆,单卡显存马上松下来,原本放不进的hidden size和micro-batch也终于能跑。⚠️ 真正进入连续训练后,报表却经常变脸:显存压力下来了,单步耗时反而更长,MFU和有效吞吐一起回落。🎯
问题往往不在算力不够,而在切分后的通信时序被做坏了。一个线性层本来是一次大GEMM,做完就往下走;切成多卡后,每层前后都要插进all-gather、reduce-scatter或all-reduce。🔍 如果这些 collective 都堵在层边界,GPU 很多时间不是在算,而是在等上一轮通信把张量拼回来。🧠
掉效率的根子,常在切得太碎又收得太晚
最常见的误区,是把TP size当成越大越稳的扩容手柄。📦TP从2拉到4以后,单卡权重和激活确实更轻,但每层可并行的本地计算也更小;一旦GEMM被切到不足以盖住链路时延,通信开销就会开始反客为主。🚨 训练看起来像是在多卡并行,实际却在多卡排队。📉
第二个坑在重叠窗口。很多实现等本层反向完全结束,才统一发起all-reduce,再等下一层拿到完整张量继续前进。🛠️ 这会把原本能藏进计算尾部的通信,硬生生推到 step 主路径上。尤其当sequence length不长、micro-batch又小的时候,collective 的启动成本会被反复放大,链路利用率很高,训练效率却不高。📌
一组 32 B 回放里,决定结果的是重叠窗口不是 TP 开关
这次回放的是32 B解码器训练,硬件为16 x A100 80 GB,节点内NVLink,序列长度4096,BF16,全局 batch 固定。🧪 基线组使用TP = 2;第二组把TP提到4但维持默认层边界同步;第三组同样使用TP = 4,但把all-gather前移到下一层预读窗口,并限制每次在途通信 bucket。📊 结果很直接,TP 开得更大,不代表墙钟一定更差,关键看通信是不是被压进了可重叠区。✅
| 方案 | 峰值显存 | 单步耗时 | MFU | NVLink 利用率 |
|---|---|---|---|---|
TP = 2 | 71 GB | 1.34 s | 45% | 43% |
TP = 4默认同步 | 55 GB | 1.61 s | 36% | 68% |
TP = 4+ overlap window | 57 GB | 1.39 s | 44% | 64% |
这组数据最值得记住的点,是第二组并没有“配错 Tensor Parallel”,而是把 collective 都堆到了最慢的时刻。📍 当下一层输入张量必须等all-gather完成才出现时,GPU 看到的不是更多并行,而是更多阻塞;只有把通信 bucket 与线性层顺序绑在一起,让上一层的尾部计算覆盖下一层的拼接等待,TP 才会从显存工具变成吞吐工具。🔧
tp_config={"tensor_parallel_size":4,"tp_comm_overlap":True,"tp_overlap_window":2,"tp_reduce_bucket_mb":64,"sequence_parallel":True,"max_inflight_collectives":3,}生产里要把 TP 当成通信预算,不是容量魔法
更稳的做法,是先看每层FLOPs / bytes,再决定TP size,而不是看到显存紧就继续横向切。🧭 如果模型已经落到小micro-batch、短序列或高频 checkpoint 的组合,继续增大TP往往只会把NCCL wait拉长;这时更有效的补救,通常是把sequence parallel、激活检查点和 overlap 调度一起收敛,而不是单靠更大的张量切分。🔒
笔者认为,未来3 - 6个月分布式训练里更有价值的,不是继续神化TP size,而是把 collective 调度做成可观测、可自适应的系统能力。🚀 只要团队还在用“显存降了多少”代替“每层等待了多久”,Tensor Parallel 就很容易停在“能跑但不快”的阶段。你们现在盯的,是单卡终于不OOM,还是整条训练主路径的真实空转时间?💬