Why distributed training

A Llama-70B model in fp16 has 140 GB of weights. Gradients are another 140 GB, optimizer states (Adam: m + v in fp32) another 560 GB — total ~840 GB just for state, before activations. An H100 has 80 GB HBM. You cannot train this on a single GPU. Distributed training is not optional; it's how every frontier model is made.

Three dimensions of parallelism define the design space, and modern systems compose all three (3D parallelism):

  • Data parallel (DP) — replicate the model, split the batch across replicas, all-reduce gradients.
  • Tensor parallel (TP) — split each layer's matrices across devices within a node.
  • Pipeline parallel (PP) — split the layers themselves across stages.

Anchor numbers: training Llama-3-70B on the public recipe took ~6.4M H100-hours (Meta's published figure), on the order of 16k H100s for two-plus weeks. Throughput was ~400 TFLOPS/GPU sustained (out of ~990 bf16 theoretical). Model FLOPs Utilization (MFU) of 40–45% is the realistic target.

Source cross-reference

Key papers: Megatron-LM (Shoeybi 2019) for tensor parallelism, GPipe (Huang 2019) and PipeDream (Narayanan 2019) for pipeline parallelism, ZeRO (Rajbhandari 2020) and ZeRO-Infinity for optimizer-state sharding, and PyTorch's FSDP (Zhao 2023). DeepSpeed and NVIDIA Megatron-Core are the production implementations.

Data parallelism, ZeRO, FSDP

Classic DP

All GPUs hold a full model copy. Each processes a different micro-batch. After backward, all_reduce the gradients. Simple, but memory blows up: every GPU stores full weights + grads + optimizer state.
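A single-process toy simulation of what that all_reduce achieves (a hypothetical one-weight model; real DDP does this per tensor via NCCL):

```python
# Simulated classic data parallelism: each "rank" computes gradients on its
# own micro-batch; averaging them means every replica takes the same update.
# Toy model: one scalar weight w, loss = (w*x - y)^2, grad = 2*(w*x - y)*x.

def local_grad(w, batch):
    """Mean gradient of (w*x - y)^2 over one rank's micro-batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def all_reduce_mean(values):
    """What all_reduce with an averaging op produces: every rank gets the mean."""
    avg = sum(values) / len(values)
    return [avg] * len(values)

# Global batch split across 4 simulated DP ranks.
shards = [[(1.0, 2.0)], [(2.0, 4.0)], [(3.0, 6.0)], [(4.0, 8.0)]]
w = 0.0
grads = [local_grad(w, shard) for shard in shards]
synced = all_reduce_mean(grads)
assert len(set(synced)) == 1   # every rank now holds the identical gradient
w -= 0.01 * synced[0]          # identical SGD step on every replica
```

The point the simulation makes: after the collective, replicas are bit-identical again, which is what lets classic DP keep N full copies in sync.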

ZeRO (DeepSpeed)

Shard the state across DP ranks:

  • ZeRO-1: shard optimizer states (Adam m, v). ~4× memory reduction for Adam fp32.
  • ZeRO-2: also shard gradients. ~8× reduction.
  • ZeRO-3: also shard weights. Only the shard needed for the current layer is gathered via all_gather, freed after use. Equivalent to FSDP.
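The reduction factors above fall out of simple arithmetic. A sketch using the ZeRO paper's mixed-precision Adam accounting (2 B fp16 weights + 2 B fp16 grads + 12 B optimizer state per parameter):

```python
def zero_bytes_per_gpu(n_params, dp, stage):
    """Per-GPU state bytes under ZeRO with mixed-precision Adam.
    Optimizer state = fp32 master weights + m + v = 12 bytes/param."""
    weights, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if stage >= 1:
        optim /= dp      # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= dp      # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= dp    # ZeRO-3: also shard weights
    return weights + grads + optim

GB = 1024**3
for s in range(4):
    print(f"ZeRO-{s}: {zero_bytes_per_gpu(70e9, dp=64, stage=s) / GB:.0f} GB/GPU")
```

For a 70B model at DP=64 this drops from ~1 TB/GPU (no sharding) to tens of GB at ZeRO-3, which is why ZeRO-3/FSDP is what actually fits on an 80 GB H100.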

FSDP (PyTorch)

PyTorch's implementation of ZeRO-3. FSDP wraps modules into units; during forward, it all-gathers a unit's weights, runs it, and re-shards. During backward, it all-gathers again for grad computation, then reduce-scatters gradients. Communication volume is 1.5× DDP's all-reduce, but per-GPU memory drops to 1/N.

FSDP is the default recipe for training LLMs in 2024–2025 PyTorch. Notable variant: HSDP (Hybrid Sharded Data Parallel), which shards within a node and replicates across nodes, reducing inter-node traffic.
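The gather/compute/re-shard flow can be simulated in-process (toy flat weight lists standing in for FSDP's flattened parameters):

```python
# Simulated FSDP unit lifecycle: each rank persistently stores 1/N of a
# unit's weights; the full unit is materialized only while it runs, and
# gradients are reduce-scattered so each rank keeps only its shard's grads.

def shard(weights, n_ranks):
    """Split a flat weight list into n_ranks equal shards (assumes divisibility)."""
    k = len(weights) // n_ranks
    return [weights[i * k:(i + 1) * k] for i in range(n_ranks)]

def all_gather(shards):
    """Every rank reconstructs the full flat weight list."""
    return [w for s in shards for w in s]

def reduce_scatter(per_rank_grads):
    """Sum full-length gradients elementwise, then each rank keeps only
    its own shard of the result (what FSDP does after backward)."""
    n = len(per_rank_grads)
    total = [sum(g) for g in zip(*per_rank_grads)]
    k = len(total) // n
    return [total[i * k:(i + 1) * k] for i in range(n)]

unit = [0.1, 0.2, 0.3, 0.4]
shards_ = shard(unit, n_ranks=2)         # persistent state: 1/2 per rank
full = all_gather(shards_)               # materialized just for this unit
assert full == unit
# ...forward/backward with `full`, then it is freed (re-sharded)...
grad_shards = reduce_scatter([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]])
assert grad_shards == [[3.0, 3.0], [3.0, 3.0]]
```

Note the memory shape: full weights exist for exactly one unit at a time, which is the 1/N memory claim above.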

Tensor and pipeline parallelism

Tensor parallel (TP)

Split a matrix multiply Y = X @ W across devices. Column-wise split of W: each device holds a column slice and outputs a column slice of Y. Row-wise split of W: each device holds a row slice (and the matching column slice of X), produces a partial sum; all_reduce combines them. Megatron-LM pairs these (column-parallel first matmul, row-parallel second) so each attention or MLP block needs a single all-reduce in forward, instead of one per matmul.

TP is intra-node only in practice (NVLink at 900 GB/s between 8 GPUs). Cross-node TP (over IB at 400 Gb/s) collapses throughput. Rule: TP ≤ 8 (one node).
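A 2×2 toy check of both splits (pure-Python matmul standing in for the sharded GEMMs; all_reduce is an elementwise add):

```python
# Simulated tensor parallelism on Y = X @ W across 2 "devices".

def matmul(X, W):
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)] for row in X]

X = [[1.0, 2.0]]
W = [[1.0, 2.0], [3.0, 4.0]]           # full 2x2 weight

# Column parallel: device d holds column d of W; concatenating the
# per-device outputs reproduces Y with no communication.
Wc = [[[1.0], [3.0]], [[2.0], [4.0]]]  # column slices of W
Yc = [matmul(X, w)[0] for w in Wc]
assert [y for part in Yc for y in part] == matmul(X, W)[0]

# Row parallel: device d holds row d of W and column d of X; each makes a
# partial sum of Y, and all_reduce (elementwise add) combines them.
partials = [matmul([[X[0][d]]], [W[d]])[0] for d in range(2)]
Y = [sum(col) for col in zip(*partials)]
assert Y == matmul(X, W)[0]
```

Column-then-row chaining is why the communication can be deferred: the column-parallel output is already laid out as the row-parallel input, so only the final partial sums need the all_reduce.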

Pipeline parallel (PP)

Split the transformer into stages; GPU 0 has layers 0–7, GPU 1 has 8–15, etc. Micro-batches flow through like a factory line. Problem: the "bubble", where stages sit idle while the first and last micro-batches fill and drain the pipeline. The 1F1B (one-forward-one-backward) schedule from PipeDream-Flush, or Megatron's interleaved schedule, shrinks the bubble.

PP is typically used across nodes when TP + FSDP isn't enough. Budget: 4–16 pipeline stages.
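The bubble has a closed form for the simple synchronous schedules. A quick sketch of the standard (p-1)/(m+p-1) formula for p stages and m micro-batches:

```python
def bubble_fraction(stages, micro_batches):
    """Idle fraction of a GPipe/1F1B pipeline: (p - 1) / (m + p - 1).
    More micro-batches per step shrinks the bubble."""
    return (stages - 1) / (micro_batches + stages - 1)

# 8 stages, 32 micro-batches: ~18% of step time is bubble; pushing the
# micro-batch count up (or interleaving virtual stages) recovers it.
frac = bubble_fraction(8, 32)
assert frac < bubble_fraction(8, 8)
```

This is why PP budgets stay modest (4–16 stages): every extra stage either inflates the bubble or forces more micro-batches, which in turn inflates activation memory.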

3D parallelism

The real recipe: TP=8 (intra-node) × PP=8 (inter-node) × DP=N (with ZeRO). A 1024-GPU cluster becomes TP8 × PP8 × DP16. This is broadly the recipe behind GPT-3-class and Llama-3-class dense training runs.

flowchart TB
  subgraph Node0[Node 0 - 8 H100s, TP=8]
    G0[GPU 0 layer 0..7 slice] --- G1[GPU 1 layer 0..7 slice]
    G1 --- G2[...]
  end
  subgraph Node1[Node 1 - PP stage 2, TP=8]
    H0[GPU 0 layer 8..15] --- H1[GPU 1 layer 8..15]
  end
  Node0 -->|activations over IB| Node1
  Node1 --> Next[...]
  DP[Data-parallel replicas<br/>FSDP sharding across nodes]

Activation memory: checkpointing, sequence parallel

Even with sharded weights, activations (stored for backward) can dominate memory for long sequences. For a transformer at seq_len=8192, batch=4, layers=80, d_model=8192, fp16, activations ≈ 100+ GB per GPU.
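A hedged back-of-envelope calculator for that figure, using the per-layer fp16 estimate from the Megatron sequence-parallel paper (~34 bytes × seq × batch × d_model per layer once flash attention removes the seq² term, and assuming TP divides all activations, which is what sequence parallel achieves):

```python
def activation_gb(seq, batch, layers, d_model, tp=1):
    """Rough per-GPU activation memory in GB (fp16, flash attention,
    no checkpointing). Estimate only; real layouts vary."""
    per_layer_bytes = 34 * seq * batch * d_model
    return per_layer_bytes * layers / tp / 1024**3

# The text's configuration: seq 8192, batch 4, 80 layers, d_model 8192, TP=8.
gb = activation_gb(seq=8192, batch=4, layers=80, d_model=8192, tp=8)
assert 50 < gb < 150   # lands in the ~100 GB/GPU ballpark quoted above
```

Without checkpointing this alone exceeds what is left of 80 GB HBM after sharded state, which motivates the next two techniques.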

Activation / gradient checkpointing

Don't store all activations; recompute during backward. Saves ~60–80% activation memory at the cost of ~30% extra FLOPs (one extra forward). Essentially mandatory for >7B models.
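(In PyTorch this is torch.utils.checkpoint.) As a back-of-envelope sketch of the trade-off, assuming full recomputation within each segment:

```python
def act_checkpoint_costs(n_layers, segment):
    """Returns (stored_boundary_activations, extra_forward_layers) when the
    model is cut into segments of `segment` layers (assumes divisibility).
    Only segment-boundary activations are kept; each segment's interior is
    re-run once during backward."""
    n_segments = n_layers // segment
    stored = n_segments
    extra_forward = n_layers   # every layer's forward runs one extra time
    return stored, extra_forward

# 80 layers, checkpoint every 8: keep 10 boundary activations instead of 80
# per-layer ones, at the cost of one full extra forward (~1/3 of fwd+bwd).
stored, extra = act_checkpoint_costs(80, 8)
assert stored == 10 and extra == 80
```

The segment size is the tuning knob: larger segments store less but recompute longer stretches before each backward step.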

Sequence parallel (Megatron)

LayerNorm and dropout activations are still replicated across TP ranks. Sequence parallel splits these along the sequence axis, recovering 2–4× activation memory within an existing TP group. Negligible extra communication.

Mixed precision

Train in bf16/fp16 for matmuls; keep an fp32 master copy of weights and optimizer states. Automatic Mixed Precision (AMP) is standard. On H100, fp8 training (via NVIDIA's Transformer Engine) can give a further ~2× matmul speedup with careful per-tensor scaling, and is reportedly used in some frontier runs.
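Why the fp32 master copy matters: at low precision a small update can round away entirely. A crude emulation (rounding to a few significand bits as a stand-in for bf16):

```python
import math

def to_lowp(x, mant_bits=8):
    """Round x to mant_bits of significand: a crude bf16-like emulation."""
    if x == 0:
        return 0.0
    e = math.floor(math.log2(abs(x)))
    scale = 2.0 ** (e - mant_bits)
    return round(x / scale) * scale

w32 = 1.0
update = 1e-4

# Low-precision-only optimizer: the update is smaller than one ulp of the
# weight, so the step is rounded away and training silently stalls.
w16 = to_lowp(to_lowp(w32) + update)
assert w16 == to_lowp(w32)

# Master-weight recipe: accumulate in fp32, cast down only for matmuls.
w32 += update
assert w32 > 1.0   # fp32 retains the step
```

The same effect is why the optimizer's m and v stay in fp32 too (see the anti-patterns section: "Optimizer in fp16" means overflows and lost updates).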

Checkpointing, resilience, and scale

Checkpoint size

Llama-70B + optimizer state ≈ 840 GB. At 16k GPUs, a naive per-rank save is 840 GB × 16k replicas = petabytes. Real recipes:

  • Shard the checkpoint across ranks (each saves its FSDP shard).
  • Async / offload to CPU+NVMe while training continues.
  • Resume loads shards in parallel; ~2–5 minutes on a modern cluster.

Checkpoint every 500–2000 steps. Skipping even one costs many GPU-hours if a node fails.
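A minimal sketch of the async sharded-save pattern (hypothetical file layout and shard contents; local temp dirs stand in for NVMe and S3/GCS, and json for the tensor serializer):

```python
import json
import shutil
import tempfile
import threading
from pathlib import Path

def save_shard_async(shard_state, rank, step, local_dir, durable_dir):
    """Each rank saves only its own shard; upload happens off the hot path."""
    # 1) Blocking but fast: snapshot to local NVMe-like storage.
    local = Path(local_dir) / f"step{step}-rank{rank}.json"
    local.write_text(json.dumps(shard_state))
    # 2) Non-blocking: ship to durable storage while training continues
    #    (shutil.copy stands in for an S3/GCS upload).
    t = threading.Thread(target=shutil.copy, args=(local, durable_dir))
    t.start()
    return t   # caller joins before pruning the local copy

local_dir, durable_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
t = save_shard_async({"w": [0.1, 0.2]}, rank=0, step=500,
                     local_dir=local_dir, durable_dir=durable_dir)
t.join()
assert (Path(durable_dir) / "step500-rank0.json").exists()
```

In production this is double-buffered (never overwrite the last good checkpoint until the new one is fully durable), and resume loads all rank shards in parallel.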

Failure rates

Interview number: on a 16k-GPU cluster, expect 1+ hardware failure per day (GPUs fall off the bus, NVLink link errors, node reboots). Training must tolerate this. Standard pattern:

  • Health-check ring (torch.distributed.elastic) detects dead workers.
  • Kill and requeue the job on the remaining healthy machines.
  • Resume from the last checkpoint.
  • Meta's Llama 3 paper reports 466 job interruptions over a 54-day snapshot of the 405B run (~9 per day), while still achieving >90% effective training time.

flowchart LR
  T[Training step] --> C{step mod N == 0?}
  C -->|no| T
  C -->|yes| S[Snapshot shards<br/>each rank to local NVMe]
  S --> U[Upload to S3/GCS async]
  F[Failure detector] --> R[Restart job]
  R --> L[Load latest checkpoint from S3]
  L --> T

Networking, cost, anti-patterns, checklist

Networking

At scale, network is the bottleneck. Modern clusters use:

  • NVLink 900 GB/s intra-node (H100 NVSwitch).
  • InfiniBand 400 Gb/s (NDR) or RoCE per GPU inter-node.
  • Fat-tree / rail-optimized topology: each GPU has a dedicated NIC mapped to a "rail"; all-reduce across rails stays collision-free.
  • NCCL library for collectives; topology-aware ring-allreduce or tree-allreduce.

The communication-to-compute ratio determines MFU. ZeRO-3 moves ~50% more bytes than DDP. Overlapping communication with compute (NCCL collectives on their own CUDA streams) is crucial.

Cost

H100 at ~$3/hour on cloud. 16k GPUs × 2 weeks × $3 ≈ $16M for a 70B pretraining run. Reserved/on-prem clusters drop this 2–3×. Efficient sharding is not a nicety: a 10% gain in effective throughput is worth ~$1.6M on this run.

OpenAI-specific

OpenAI's public training papers are sparse, but their GPT-4 technical report confirms massive-scale training with "predictable scaling": extrapolating loss curves from small runs. Implied infra: custom networking on Azure, mixed TP/PP/DP at a reported ~25k-GPU scale. Their RLHF pipeline runs a separate reward-model cluster in a tight loop with the policy trainer.

Anthropic-specific

Anthropic trains on AWS Trainium + Nvidia, across multiple regions. Their safety-first approach means "preparedness" evals are run on intermediate checkpoints; unsafe capabilities may halt the run. Constitutional AI fine-tuning adds another training stage after pretraining, with RLAIF (AI-generated feedback) replacing parts of RLHF.

Anti-patterns

  • TP across nodes. Inter-node IB is 10× slower than NVLink. Throughput dies.
  • No gradient clipping. One bad batch spikes the loss and corrupts optimizer state. Always clip (1.0 is common).
  • Fully synchronous checkpointing. Stops training for minutes per save. Use async + double-buffered.
  • Ignoring stragglers. One slow GPU slows the whole step. Monitor per-rank step time.
  • Optimizer in fp16. Overflows and NaNs. Keep fp32 master weights.

Whiteboard checklist: model size → memory budget (weights + grad + optim + act) → pick parallelism (TP=8 intra-node, PP for layer count, FSDP/ZeRO-3 for DP) → activation checkpointing + sequence parallel → bf16/fp8 mixed precision → sharded async checkpoints → health-check + requeue → NCCL comm-compute overlap → MFU target 40%+ → cost estimate.
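The first few checklist steps reduce to one formula. A back-of-envelope calculator (bf16 weights/grads + fp32 Adam = 16 bytes/param of state; activations and fragmentation deliberately excluded, so treat the result as a floor):

```python
def state_gb_per_gpu(n_params, tp, pp, dp_zero3):
    """Per-GPU model + optimizer state under TP x PP x ZeRO-3 DP.
    16 bytes/param = 2 (bf16 weights) + 2 (bf16 grads) + 12 (fp32 Adam).
    All three parallel dimensions shard the state."""
    bytes_per_param = 16
    shards = tp * pp * dp_zero3
    return n_params * bytes_per_param / shards / 1024**3

# The 1024-GPU example from the 3D-parallelism section: TP=8 x PP=8 x DP=16.
gb = state_gb_per_gpu(70e9, tp=8, pp=8, dp_zero3=16)
assert gb < 80   # state fits H100 HBM, leaving headroom for activations
print(f"{gb:.1f} GB/GPU of model+optimizer state")
```

Run this first on the whiteboard; if the state alone does not fit with headroom for the activation estimate, revisit the parallelism split before anything else.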
