Why distributed training
A Llama-70B model in fp16 has 140 GB of weights. Gradients are another 140 GB, optimizer states (Adam: m + v in fp32) another 560 GB — total ~840 GB just for state, before activations. An H100 has 80 GB HBM. You cannot train this on a single GPU. Distributed training is not optional; it's how every frontier model is made.
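The arithmetic is worth being able to do on a whiteboard. A minimal sketch, using the same per-parameter byte counts as above:

```python
def training_state_gb(n_params, weight_bytes=2, grad_bytes=2, optim_bytes=8):
    """Per-model training-state footprint in GB (1 GB = 1e9 bytes).
    optim_bytes=8 covers Adam's fp32 m and v; keeping an fp32 master
    copy of the weights adds 4 more bytes per parameter."""
    return n_params * (weight_bytes + grad_bytes + optim_bytes) / 1e9

print(training_state_gb(70e9))  # → 840.0  (the ~840 GB quoted above)
```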
Three dimensions of parallelism define the design space, and modern systems compose all three (3D parallelism):
- Data parallel (DP) — replicate the model, split the batch across replicas, all-reduce gradients.
- Tensor parallel (TP) — split each layer's matrices across devices within a node.
- Pipeline parallel (PP) — split the layers themselves across stages.
Anchor numbers: training Llama-3-70B on the public recipe took ~7M H100-hours across 16k H100s for roughly two weeks. Throughput was ~400 TFLOPS/GPU sustained (out of ~990 fp16 theoretical). Model FLOPs Utilization (MFU) of 40–45% is the realistic target.
Source cross-reference
Key papers: Megatron-LM (Shoeybi 2019) for tensor parallelism, GPipe (Huang 2019) and PipeDream (Narayanan 2019) for pipeline parallelism, ZeRO (Rajbhandari 2020) and ZeRO-Infinity for optimizer-state sharding, and PyTorch's FSDP (Zhao 2023). DeepSpeed and NVIDIA Megatron-Core are the production implementations.
Data parallelism, ZeRO, FSDP
Classic DP
All GPUs hold a full model copy. Each processes a different micro-batch. After backward, all_reduce the gradients. Simple, but memory blows up: every GPU stores full weights + grads + optimizer state.
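A toy single-process model of that gradient all-reduce (the real collective is NCCL's; here it is just an element-wise mean across simulated ranks):

```python
def all_reduce_mean(grads_per_rank):
    """Toy model of DDP's gradient all-reduce: after the collective,
    every rank holds the element-wise mean of all ranks' gradients."""
    n = len(grads_per_rank)
    mean = [sum(vals) / n for vals in zip(*grads_per_rank)]
    return [list(mean) for _ in range(n)]  # every replica gets the same result

# 2 ranks, each holding gradients from its own micro-batch
print(all_reduce_mean([[1.0, 2.0], [3.0, 4.0]]))  # → [[2.0, 3.0], [2.0, 3.0]]
```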
ZeRO (DeepSpeed)
Shard the state across DP ranks:
- ZeRO-1: shard optimizer states (Adam m, v). ~4× memory reduction for Adam fp32.
- ZeRO-2: also shard gradients. ~8× reduction.
- ZeRO-3: also shard weights. Only the shard needed for the current layer is gathered via all_gather and freed after use. Equivalent to FSDP.
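The per-GPU savings at each stage can be sketched with the ZeRO paper's own accounting (2 B fp16 weights + 2 B fp16 grads + K B optimizer state per parameter, K=12 for Adam with an fp32 master copy):

```python
def zero_per_gpu_gb(n_params, dp, stage, k=12):
    """Per-GPU training-state memory (GB) under ZeRO, following the
    ZeRO paper's accounting, sharded over `dp` data-parallel ranks."""
    if stage == 0:
        per_param = 2 + 2 + k            # plain data parallel
    elif stage == 1:
        per_param = 2 + 2 + k / dp       # shard optimizer states
    elif stage == 2:
        per_param = 2 + (2 + k) / dp     # + shard gradients
    else:
        per_param = (2 + 2 + k) / dp     # + shard weights (ZeRO-3 / FSDP)
    return n_params * per_param / 1e9

# The ZeRO paper's 7.5B-param, 64-GPU example: 120 GB -> under 2 GB per GPU
print(zero_per_gpu_gb(7.5e9, 64, 0), zero_per_gpu_gb(7.5e9, 64, 3))
```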
FSDP (PyTorch)
PyTorch's implementation of ZeRO-3. FSDP wraps modules into units; during forward, it all-gathers a unit's weights, runs it, and re-shards. During backward, it all-gathers again for the grad computation, then reduce-scatters gradients. Communication volume is 1.5× DDP's all-reduce, but per-GPU memory drops to 1/N.
FSDP is the default recipe for training LLMs in 2024–2025 PyTorch. A notable sibling: HSDP (Hybrid Sharded Data Parallel), which shards within a node and replicates across nodes, reducing inter-node traffic.
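A minimal FSDP wrapping sketch; it assumes a process group already initialized (e.g. by torchrun), and the layer sizes are illustrative, not a real recipe:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_sharded_model():
    # Each Linear becomes part of a flattened, sharded unit: its weights
    # are all-gathered just-in-time in forward/backward, then re-sharded.
    model = nn.Sequential(nn.Linear(1024, 4096),
                          nn.GELU(),
                          nn.Linear(4096, 1024))
    # Requires torch.distributed to be initialized before calling.
    return FSDP(model, use_orig_params=True)
```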
Tensor and pipeline parallelism
Tensor parallel (TP)
Split a matrix multiply Y = X @ W across devices. Column-wise split of W: each device holds a column slice and outputs a column slice of Y, with no communication. Row-wise split of W: each device holds a row slice (and the matching slice of X) and produces a partial sum; an all_reduce combines them. Megatron-LM pairs a column-split layer with a row-split layer, so each attention/MLP sub-block needs only one forward all-reduce instead of syncing after every matmul.
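Both splits can be checked in a single process with toy matrices: nested lists stand in for device shards, and the "all-reduce" is just an element-wise sum.

```python
def matmul(X, W):
    """Tiny dense matmul on nested lists (stand-in for a GPU GEMM)."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)]
            for row in X]

def column_parallel(X, W, tp=2):
    """Each 'rank' holds a column slice of W; output slices concatenate
    along columns -- no communication needed in forward."""
    cols = len(W[0]) // tp
    shards = [[row[r * cols:(r + 1) * cols] for row in W] for r in range(tp)]
    outs = [matmul(X, s) for s in shards]
    return [sum((o[i] for o in outs), []) for i in range(len(X))]

def row_parallel(X, W, tp=2):
    """Each 'rank' holds a row slice of W (and the matching column slice
    of X); partial sums are combined by an all-reduce (a plain sum here)."""
    rows = len(W) // tp
    partials = [matmul([x[r * rows:(r + 1) * rows] for x in X],
                       W[r * rows:(r + 1) * rows]) for r in range(tp)]
    return [[sum(p[i][j] for p in partials) for j in range(len(W[0]))]
            for i in range(len(X))]

X = [[1.0, 2.0], [3.0, 4.0]]
W = [[5.0, 6.0], [7.0, 8.0]]
print(column_parallel(X, W) == matmul(X, W) == row_parallel(X, W))  # → True
```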
TP is intra-node only in practice (NVLink at 900 GB/s between 8 GPUs). Cross-node TP (over IB at 400 Gb/s) collapses throughput. Rule: TP ≤ 8 (one node).
Pipeline parallel (PP)
Split the transformer into stages: GPU 0 has layers 0–7, GPU 1 has layers 8–15, etc. Micro-batches flow through like a factory line. Problem: the "bubble" — while the first micro-batches fill and the last ones drain the pipeline, stages sit idle. The 1F1B (one-forward-one-backward) schedule of PipeDream-Flush, or Megatron's interleaved schedule, shrinks the bubble.
PP is typically used across nodes when TP + FSDP isn't enough. Budget: 4–16 pipeline stages.
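The bubble cost has a standard closed form, (p − 1) / (m + p − 1) for p stages and m micro-batches, which is why you want many micro-batches per stage:

```python
def bubble_fraction(stages, micro_batches):
    """Idle fraction of a GPipe/1F1B pipeline: (p - 1) / (m + p - 1).
    1F1B doesn't shrink the bubble vs GPipe -- it shrinks activation
    memory; more micro-batches per stage is what shrinks the bubble."""
    p, m = stages, micro_batches
    return (p - 1) / (m + p - 1)

print(round(bubble_fraction(8, 8), 2))   # → 0.47  (nearly half the time idle)
print(round(bubble_fraction(8, 64), 2))  # → 0.1
```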
3D parallelism
The real recipe: TP=8 (intra-node) × PP=8 (inter-node) × DP=N (with ZeRO). A 1024-GPU cluster becomes TP8 × PP8 × DP16. This is how GPT-3, PaLM, Llama-3 were all trained.
flowchart TB
subgraph Node0[Node 0 - 8 H100s, TP=8]
G0[GPU 0 layer 0..7 slice] --- G1[GPU 1 layer 0..7 slice]
G1 --- G2[...]
end
subgraph Node1[Node 1 - PP stage 2, TP=8]
H0[GPU 0 layer 8..15] --- H1[GPU 1 layer 8..15]
end
Node0 -->|activations over IB| Node1
Node1 --> Next[...]
DP[Data-parallel replicas<br/>FSDP sharding across nodes]
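A sketch of how a flat rank index maps onto that TP × PP × DP grid. Making TP the fastest-varying axis keeps each TP group inside one NVLink node; the axis order is an assumption here, and real launchers let you choose it.

```python
def grid_coords(rank, tp=8, pp=8, dp=16):
    """Flat rank -> (tp, pp, dp) coordinates, TP fastest-varying so each
    TP group of 8 lands inside one node."""
    assert 0 <= rank < tp * pp * dp
    return (rank % tp, (rank // tp) % pp, rank // (tp * pp))

print(grid_coords(0))   # → (0, 0, 0)
print(grid_coords(63))  # → (7, 7, 0)  last GPU of the first DP replica
```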
Activation memory: checkpointing, sequence parallel
Even with sharded weights, activations (stored for backward) can dominate memory for long sequences. For a transformer at seq_len=8192, batch=4, layers=80, d_model=8192 in fp16, activations can exceed 100 GB per GPU.
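That ballpark can be reproduced with the per-layer estimate from Korthikanti et al. (2022), s·b·h·(34 + 5·a·s/h) bytes in fp16; flash attention avoids materializing the score term. The head count and TP degree below are illustrative assumptions.

```python
def activation_gb(seq, batch, layers, d_model, heads=64, tp=8, flash=True):
    """Per-GPU activation memory (GB, fp16, no recomputation), using the
    Korthikanti et al. per-layer estimate, divided across TP ranks.
    heads=64 and tp=8 are assumptions, not from the text."""
    attn_scores = 0 if flash else 5 * heads * seq / d_model
    per_layer_bytes = seq * batch * d_model * (34 + attn_scores)
    return layers * per_layer_bytes / tp / 1e9

# The text's setting: seq 8192, batch 4, 80 layers, d_model 8192
print(round(activation_gb(8192, 4, 80, 8192)))  # → 91  (GB per GPU at TP=8)
```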
Activation / gradient checkpointing
Don't store all activations; recompute during backward. Saves ~60–80% activation memory at the cost of ~30% extra FLOPs (one extra forward). Essentially mandatory for >7B models.
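A minimal sketch with PyTorch's torch.utils.checkpoint; the block shape is illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.ff = torch.nn.Sequential(torch.nn.Linear(d, 4 * d),
                                      torch.nn.GELU(),
                                      torch.nn.Linear(4 * d, d))

    def forward(self, x):
        # Activations inside self.ff are NOT stored; they are recomputed
        # during backward (one extra forward, ~30% more FLOPs).
        return x + checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(2, 64, requires_grad=True)
Block(64)(x).sum().backward()  # gradients flow through the recomputation
```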
Sequence parallel (Megatron)
LayerNorm and dropout activations are still replicated across TP ranks. Sequence parallel splits these along the sequence axis, recovering 2–4× activation memory within an existing TP group. Negligible extra communication.
Mixed precision
Train in bf16/fp16 for matmuls; keep an fp32 master copy of weights and optimizer states. Automatic Mixed Precision (AMP) is standard. On H100, fp8 training (with Transformer Engine) can give another ~2× matmul speedup with careful per-tensor scaling, and has been adopted in some frontier runs.
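A minimal bf16 AMP sketch, run on CPU for illustration; on GPU you would pass device_type="cuda":

```python
import torch

model = torch.nn.Linear(64, 64)                       # weights stay fp32
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)  # optimizer state fp32
x, y = torch.randn(8, 64), torch.randn(8, 64)

# Inside autocast, matmuls run in bf16; the fp32 master weights in `model`
# are untouched, so the optimizer updates full-precision state.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
```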
Checkpointing, resilience, and scale
Checkpoint size
Llama-70B + optimizer state ≈ 840 GB. At 16k GPUs, a naive per-rank save is 840 GB × 16k replicas = petabytes. Real recipes:
- Shard the checkpoint across ranks (each saves its FSDP shard).
- Async / offload to CPU+NVMe while training continues.
- Resume loads shards in parallel; ~2–5 minutes on a modern cluster.
Checkpoint every 500–2000 steps. Skipping even one costs many GPU-hours if a node fails.
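A sketch of the per-rank shard save, with write-then-rename so a crash mid-save never leaves a torn file. The filename convention and plain pickle are stand-ins for a real serializer.

```python
import os
import pickle
import tempfile

def save_shard(state_shard, rank, step, ckpt_dir):
    """Each rank writes only its own shard, so no rank ever materializes
    the full 840 GB of state. Filename convention is an assumption."""
    path = os.path.join(ckpt_dir, f"step{step:08d}-rank{rank:05d}.pt")
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:       # write to a temp file first...
        pickle.dump(state_shard, f)
    os.replace(tmp, path)            # ...then rename (atomic on POSIX)
    return path

ckpt_dir = tempfile.mkdtemp()
p = save_shard({"w": [1.0, 2.0]}, rank=3, step=500, ckpt_dir=ckpt_dir)
```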
Failure rates
Interview number: on a 16k-GPU cluster, expect 1+ hardware failure per day (GPUs fall off the bus, NVLink link errors, node reboots). Training must tolerate this. Standard pattern:
- Health-check ring (torch.distributed.elastic) detects dead workers.
- Kill and requeue the job on the remaining healthy machines.
- Resume from last checkpoint.
- Meta's Llama 3 paper reports 466 job interruptions over a 54-day snapshot of the 405B run (roughly 8–9 per day), while still achieving 90%+ "effective training time".
flowchart LR
T[Training step] --> C{step mod N == 0?}
C -->|no| T
C -->|yes| S[Snapshot shards<br/>each rank to local NVMe]
S --> U[Upload to S3/GCS async]
F[Failure detector] --> R[Restart job]
R --> L[Load latest checkpoint<br/>from S3]
L --> T
Networking, cost, anti-patterns, checklist
Networking
At scale, network is the bottleneck. Modern clusters use:
- NVLink 900 GB/s intra-node (H100 NVSwitch).
- One 400 Gb/s NIC per GPU inter-node (NDR InfiniBand or RoCE).
- Fat-tree / rail-optimized topology: each GPU has a dedicated NIC mapped to a "rail"; all-reduce across rails stays collision-free.
- NCCL library for collectives; topology-aware ring-allreduce or tree-allreduce.
Communication-to-compute ratio determines MFU. ZeRO-3 moves ~50% more data than DDP. Overlapping communication with compute (separate NCCL streams) is crucial.
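The DDP vs ZeRO-3 comparison comes down to bytes on the wire. For reference, ring all-reduce itself moves 2(N−1)/N of the buffer per rank, nearly independent of N:

```python
def ring_all_reduce_bytes(n_bytes, ranks):
    """Bytes each rank sends in a ring all-reduce:
    reduce-scatter + all-gather = 2 * (N - 1) / N of the buffer."""
    return 2 * (ranks - 1) / ranks * n_bytes

# 140 GB of fp16 gradients over 16 DP ranks
print(ring_all_reduce_bytes(140e9, 16) / 1e9)  # → 262.5  (GB per rank, per step)
```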
Cost
H100 at ~$3/hour on cloud. 16k GPUs × 2 weeks × $3 = $16M for a 70B pretraining run. Reserved/on-prem clusters drop this 2–3×. Efficient sharding is not a nicety — saving 10% MFU saves $1.6M.
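The cost arithmetic as a one-liner you can re-run with your own rates; the $3/hour is the text's assumption:

```python
def run_cost_usd(gpus, days, usd_per_gpu_hour=3.0):
    """Back-of-envelope cloud cost for a training run."""
    return gpus * days * 24 * usd_per_gpu_hour

print(run_cost_usd(16_000, 14))  # → 16128000.0  (~$16M, matching the text)
```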
OpenAI-specific
OpenAI's public training papers are sparse, but the GPT-4 technical report confirms massive-scale training with "predictable scaling" (extrapolating loss curves from small runs). Implied infra: custom networking on Azure, mixed TP/PP/DP at ~25k-GPU scale. Their RLHF pipeline runs a separate reward-model cluster in a tight loop with the policy trainer.
Anthropic-specific
Anthropic trains on AWS Trainium + Nvidia, across multiple regions. Their safety-first approach means "preparedness" evals are run on intermediate checkpoints; unsafe capabilities may halt the run. Constitutional AI fine-tuning adds another training stage after pretraining, with RLAIF (AI-generated feedback) replacing parts of RLHF.
Anti-patterns
- TP across nodes. Inter-node IB is 10× slower than NVLink. Throughput dies.
- No gradient clipping. One bad batch spikes the loss and corrupts optimizer state. Always clip (1.0 is common).
- Fully synchronous checkpointing. Stops training for minutes per save. Use async + double-buffered.
- Ignoring stragglers. One slow GPU slows the whole step. Monitor per-rank step time.
- Optimizer in fp16. Overflows and NaNs. Keep fp32 master weights.
Whiteboard checklist: model size → memory budget (weights + grad + optim + act) → pick parallelism (TP=8 intra-node, PP for layer count, FSDP/ZeRO-3 for DP) → activation checkpointing + sequence parallel → bf16/fp8 mixed precision → sharded async checkpoints → health-check + requeue → NCCL comm-compute overlap → MFU target 40%+ → cost estimate.
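A toy end-to-end pass over that checklist; every heuristic here (stage count, memory headroom factor) is an illustrative assumption, not a recipe:

```python
def plan(n_params_b, layers, nodes, gpus_per_node=8, hbm_gb=80):
    """Sketch a parallelism plan and check per-GPU state fits in HBM.
    Heuristics only; a real plan also budgets activations, pipeline
    bubble, and comm overlap."""
    tp = min(gpus_per_node, 8)              # TP never crosses a node
    pp = max(1, min(16, layers // 10))      # a handful of pipeline stages
    world = nodes * gpus_per_node
    dp = world // (tp * pp)
    # ZeRO-3/FSDP: 12 B/param of state, sharded over the whole grid
    state_gb = n_params_b * 12 / (tp * pp * dp)
    return {"tp": tp, "pp": pp, "dp": dp,
            "state_gb_per_gpu": round(state_gb, 1),
            "fits": state_gb < 0.6 * hbm_gb}  # headroom for activations

print(plan(70, 80, nodes=128))  # the text's TP8 x PP8 x DP16 on 1024 GPUs
```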