Distributing 140 GB of weights across 8 GPUs
In Part 1, we established that Llama-70B's weights require 140 GB — nearly double the A100's 80 GB capacity. How do we fit something that doesn't fit?
The answer is FSDP (Fully Sharded Data Parallel). But FSDP is built on several fundamental concepts that we need to understand first. In this part, we'll build up our understanding step by step:
The fundamental insight: instead of replicating all weights on every GPU (like standard Data Parallel), each GPU stores only 1/N of the weights.
Let's work through a simple example to build intuition. We'll use 4 GPUs here to keep the step-by-step traces readable — fewer GPUs means fewer steps to follow. Later, we'll scale up to 8 GPUs for the real Llama-70B numbers.
Suppose we have a weight matrix W and 4 GPUs:
Original weight matrix W (shape [8, 4]):
W = | w00 w01 w02 w03 |
| w10 w11 w12 w13 |
| w20 w21 w22 w23 |
| w30 w31 w32 w33 |
| w40 w41 w42 w43 |
| w50 w51 w52 w53 |
| w60 w61 w62 w63 |
| w70 w71 w72 w73 |
We split the matrix evenly across GPUs. In practice, FSDP flattens all parameters into a 1D buffer and divides it into equal chunks — but for visualization, think of it as splitting the rows:
GPU 0 holds W_shard0 (rows 0-1):
| w00 w01 w02 w03 |
| w10 w11 w12 w13 |
GPU 1 holds W_shard1 (rows 2-3):
| w20 w21 w22 w23 |
| w30 w31 w32 w33 |
GPU 2 holds W_shard2 (rows 4-5):
| w40 w41 w42 w43 |
| w50 w51 w52 w53 |
GPU 3 holds W_shard3 (rows 6-7):
| w60 w61 w62 w63 |
| w70 w71 w72 w73 |
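The row-wise split above can be reproduced in a few lines of Python. Plain lists stand in for tensors, and `shard` is an illustrative helper, not a real FSDP API:

```python
# Toy 8x4 weight matrix: W[r][c] holds the string "w{r}{c}".
W = [[f"w{r}{c}" for c in range(4)] for r in range(8)]

def shard(matrix, num_gpus):
    """Split a matrix row-wise into num_gpus equal shards."""
    rows_per_gpu = len(matrix) // num_gpus
    return [matrix[g * rows_per_gpu:(g + 1) * rows_per_gpu]
            for g in range(num_gpus)]

shards = shard(W, 4)
# GPU 0 holds rows 0-1, GPU 3 holds rows 6-7 -- matching the diagram.
assert shards[0][0] == ["w00", "w01", "w02", "w03"]
assert shards[3][1] == ["w70", "w71", "w72", "w73"]
```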
Applying this to Llama-70B with 8 GPUs:
Total weights: 140 GB (70B params × 2 bytes in BF16)
Number of GPUs: 8
Per-GPU shard: 140 GB ÷ 8 = 17.5 GB
This fits comfortably within the A100's 80 GB, leaving room for activations and KV cache.
Every weight matrix in the model gets sharded. For one transformer layer:
Original Wq: [8192, 8192] → 67M params × 2 bytes = 134 MB
Sharded Wq: [1024, 8192] → 8.4M params × 2 bytes = 16.8 MB per GPU
Original FFN W1: [8192, 28672] → 235M params × 2 bytes = 470 MB
Sharded FFN W1: [1024, 28672] → 29.4M params × 2 bytes = 58.8 MB per GPU
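These figures are easy to sanity-check. The script below just redoes the arithmetic (illustrative only, not profiler output):

```python
# Sanity-check the sharding arithmetic for Llama-70B in BF16.
BYTES_PER_PARAM = 2          # BF16
total_params = 70e9
num_gpus = 8

total_gb = total_params * BYTES_PER_PARAM / 1e9   # 140.0 GB of weights
shard_gb = total_gb / num_gpus                    # 17.5 GB per GPU

# Two matrices from a single transformer layer:
wq_params = 8192 * 8192      # attention projection, ~67M params
ffn_params = 8192 * 28672    # FFN up-projection, ~235M params

print(total_gb, shard_gb)                           # → 140.0 17.5
print(round(wq_params * BYTES_PER_PARAM / 1e6, 1))  # full Wq: → 134.2
```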
Now we face a problem: to compute Y = X @ W, each GPU needs the complete weight matrix W, but it holds only 1/4 of it. This is where AllGather comes in.
AllGather is a collective communication operation where each GPU shares its data with all other GPUs. After AllGather, every GPU has the complete data.
Continuing our example with 4 GPUs, each holding a shard:
BEFORE AllGather:
─────────────────
GPU 0 has: W_shard0 = | w00 w01 w02 w03 |
| w10 w11 w12 w13 |
GPU 1 has: W_shard1 = | w20 w21 w22 w23 |
| w30 w31 w32 w33 |
GPU 2 has: W_shard2 = | w40 w41 w42 w43 |
| w50 w51 w52 w53 |
GPU 3 has: W_shard3 = | w60 w61 w62 w63 |
| w70 w71 w72 w73 |
After AllGather completes:
AFTER AllGather:
────────────────
GPU 0 has: W_full = | w00 w01 w02 w03 | ← from shard0 (local)
| w10 w11 w12 w13 |
| w20 w21 w22 w23 | ← received from GPU 1
| w30 w31 w32 w33 |
| w40 w41 w42 w43 | ← received from GPU 2
| w50 w51 w52 w53 |
| w60 w61 w62 w63 | ← received from GPU 3
| w70 w71 w72 w73 |
GPU 1 has: W_full (same complete matrix)
GPU 2 has: W_full (same complete matrix)
GPU 3 has: W_full (same complete matrix)
Now every GPU can compute Y = X @ W_full.
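Semantically, AllGather is just "every rank ends up with the concatenation of all shards." A minimal single-process sketch (the `all_gather` function models the semantics; it is not the NCCL collective):

```python
def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [row for shard in shards for row in shard]  # concat row-wise
    return [list(full) for _ in shards]                # one copy per GPU

# 4 GPUs, each holding 2 rows of an 8x4 matrix; entry = 10*row + col.
shards = [[[10 * r + c for c in range(4)] for r in range(2 * g, 2 * g + 2)]
          for g in range(4)]
gathered = all_gather(shards)

# Every GPU now holds the identical full 8x4 matrix, so any GPU can
# compute Y = X @ W_full locally.
assert all(g == gathered[0] for g in gathered)
assert len(gathered[0]) == 8 and gathered[0][7] == [70, 71, 72, 73]
```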
The simplest way to implement AllGather is for each GPU to broadcast its shard to all others:
Naive AllGather communication pattern:
──────────────────────────────────────
GPU 0 sends W_shard0 → GPU 1, GPU 2, GPU 3 (3 sends)
GPU 1 sends W_shard1 → GPU 0, GPU 2, GPU 3 (3 sends)
GPU 2 sends W_shard2 → GPU 0, GPU 1, GPU 3 (3 sends)
GPU 3 sends W_shard3 → GPU 0, GPU 1, GPU 2 (3 sends)
Total messages: N × (N-1) = 4 × 3 = 12 messages
Data per GPU sent: shard_size × (N-1) = S × 3
Data per GPU received: shard_size × (N-1) = S × 3
This creates a lot of network traffic, and all GPUs try to send simultaneously, causing congestion. The network links become a bottleneck because everyone is trying to use them at once.
Ring AllGather arranges GPUs in a logical ring and passes data around the ring in steps. This is much more efficient because it avoids congestion.
Here's how Ring AllGather works with 4 GPUs:
RING ALLGATHER - Step by Step
─────────────────────────────
Initial state:
GPU 0: [S0] GPU 1: [S1] GPU 2: [S2] GPU 3: [S3]
Step 1: Each GPU sends its current data to the next GPU in the ring
GPU 0 → GPU 1: sends S0
GPU 1 → GPU 2: sends S1
GPU 2 → GPU 3: sends S2
GPU 3 → GPU 0: sends S3
After Step 1:
GPU 0: [S0, S3] GPU 1: [S1, S0] GPU 2: [S2, S1] GPU 3: [S3, S2]
Step 2: Each GPU sends what it just received
GPU 0 → GPU 1: sends S3
GPU 1 → GPU 2: sends S0
GPU 2 → GPU 3: sends S1
GPU 3 → GPU 0: sends S2
After Step 2:
GPU 0: [S0, S3, S2] GPU 1: [S1, S0, S3] GPU 2: [S2, S1, S0] GPU 3: [S3, S2, S1]
Step 3: Each GPU sends what it just received
GPU 0 → GPU 1: sends S2
GPU 1 → GPU 2: sends S3
GPU 2 → GPU 3: sends S0
GPU 3 → GPU 0: sends S1
After Step 3 (COMPLETE):
GPU 0: [S0, S1, S2, S3] ✓
GPU 1: [S0, S1, S2, S3] ✓
GPU 2: [S0, S1, S2, S3] ✓
GPU 3: [S0, S1, S2, S3] ✓
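The three-step trace above can be simulated directly. This sketch (`ring_all_gather` is an illustrative name) models each GPU forwarding the chunk it most recently received:

```python
def ring_all_gather(shards):
    """Simulate Ring AllGather: in each of N-1 steps, GPU i sends the
    chunk it most recently received to GPU (i+1) % N."""
    n = len(shards)
    have = [{i: shards[i]} for i in range(n)]   # chunk index -> data per GPU
    sending = list(range(n))                    # chunk each GPU sends next
    for _ in range(n - 1):
        # All transfers in a step happen in parallel: snapshot, then apply.
        incoming = [(sending[(i - 1) % n], shards[sending[(i - 1) % n]])
                    for i in range(n)]
        for i, (idx, data) in enumerate(incoming):
            have[i][idx] = data
        sending = [idx for idx, _ in incoming]  # forward what just arrived
    return have

result = ring_all_gather(["S0", "S1", "S2", "S3"])
# After N-1 = 3 steps, every GPU holds all four shards.
assert all(sorted(h) == [0, 1, 2, 3] for h in result)
```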
Let's compare the communication costs in detail:
NAIVE ALLGATHER:
────────────────
- Each GPU sends: shard_size × (N-1) = S × (N-1)
- Each GPU receives: shard_size × (N-1) = S × (N-1)
- All sends happen SIMULTANEOUSLY → network congestion!
The problem: suppose all N GPUs share a single network fabric with
total bandwidth B (as with a shared switch uplink). When all of them
try to send their N-1 shards at the same time, they compete for the
same bandwidth, and each sender effectively gets B/N:
Time = (S × (N-1)) / (B / N) ← bandwidth divided among N senders
= S × (N-1) × N / B
= O(S × N² / B) under contention
(With dedicated per-GPU links the picture is different — we'll revisit
this below.)
RING ALLGATHER:
───────────────
- Number of steps: N-1 (for N GPUs)
- Data sent per step per GPU: shard_size = S
- Total data sent per GPU: S × (N-1) ← same total as naive!
The key difference: In each step, each GPU sends to exactly ONE
neighbor and receives from exactly ONE neighbor. No congestion!
Step 1: GPU0→GPU1, GPU1→GPU2, GPU2→GPU3, GPU3→GPU0 (4 parallel transfers)
Step 2: GPU0→GPU1, GPU1→GPU2, GPU2→GPU3, GPU3→GPU0 (4 parallel transfers)
Step 3: GPU0→GPU1, GPU1→GPU2, GPU2→GPU3, GPU3→GPU0 (4 parallel transfers)
Each link is used by exactly one transfer at a time.
Time per step = S / bandwidth ← full bandwidth available!
Total time = (N-1) × S / bandwidth
= S × (N-1) / bandwidth
= O(S × N / bandwidth) ← linear in N, not quadratic!
You might wonder: in Ring AllGather, the steps happen sequentially (step 1, then step 2, then step 3). In naive AllGather, all transfers happen "at once." So with infinite bandwidth, wouldn't naive be faster?
The key insight is that bandwidth is always finite and shared. Think of it like a highway:
NAIVE ALLGATHER - The Traffic Jam
─────────────────────────────────
GPU 0 wants to send to GPU 1, GPU 2, GPU 3 simultaneously.
But GPU 0 has ONE outbound link with bandwidth B.
If GPU 0 sends 3 messages at once:
- Each message gets B/3 bandwidth
- Time to send one shard: S / (B/3) = 3S/B
All 4 GPUs do this simultaneously, but each GPU's outbound link
is the bottleneck for its own sends.
Total time = 3S/B (limited by each GPU's outbound bandwidth)
RING ALLGATHER - Smooth Flow
────────────────────────────
GPU 0 sends to GPU 1 only. Full bandwidth B available.
Step 1: Send S bytes at bandwidth B → time = S/B
Step 2: Send S bytes at bandwidth B → time = S/B
Step 3: Send S bytes at bandwidth B → time = S/B
Total time = 3 × S/B = 3S/B
Same total time! But Ring is more predictable and scales better.
The math works out the same here because the bottleneck is each GPU's own outbound link. Ring AllGather is still preferred because:
- Each link carries exactly one transfer per step, so there is no congestion even on shared fabrics and no incast at any receiver.
- Each GPU needs only a constant-size buffer (one shard) per step, instead of N-1 simultaneous incoming messages.
- The total cost stays bandwidth-optimal at S × (N-1) / B, regardless of how many GPUs participate.
Let's calculate the actual time for one layer:
Per-layer weight size: ~805M params × 2 bytes = 1.61 GB
Number of GPUs: 8
Shard size: 1.61 GB ÷ 8 = 201 MB
Ring AllGather:
- Steps: N-1 = 7
- Data per step: 201 MB
- Total data transferred per GPU: 201 MB × 7 = 1.41 GB
With NVLink at 600 GB/s (bidirectional, so ~300 GB/s per direction):
Time per step = 201 MB ÷ 300 GB/s = 0.67 ms
Total time = 7 steps × 0.67 ms = 4.7 ms per layer
For 80 layers (forward pass):
Total AllGather time: 4.7 ms × 80 = 376 ms
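The same calculation in script form (bandwidth and layer size are the assumed numbers above; real NCCL timings will differ):

```python
# Ring AllGather timing for one Llama-70B layer on 8 GPUs (assumed numbers).
layer_bytes = 805e6 * 2              # ~1.61 GB of weights per layer (BF16)
num_gpus = 8
shard_bytes = layer_bytes / num_gpus # ~201 MB per shard
link_bw = 300e9                      # ~300 GB/s per direction (NVLink)

step_time = shard_bytes / link_bw          # one shard moves per step
total_time = (num_gpus - 1) * step_time    # N-1 steps

print(f"{step_time * 1e3:.2f} ms/step, {total_time * 1e3:.1f} ms/layer")
# → 0.67 ms/step, 4.7 ms/layer
print(f"{total_time * 80 * 1e3:.0f} ms for 80 layers")
# → 376 ms for 80 layers
```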
AllGather solves the forward pass. But during training, after computing gradients, we face the reverse problem: each GPU has computed full gradients, and we need to (1) sum them across GPUs and (2) re-shard them for storage. This is where ReduceScatter comes in.
ReduceScatter combines two operations:
- Reduce: element-wise sum of the data across all GPUs
- Scatter: split the summed result so each GPU keeps only its 1/N shard
The result: each GPU ends up with a summed shard of the data.
In data parallel training, each GPU processes different data batches. The gradients computed on each GPU are partial gradients — they only reflect the loss on that GPU's data. To get the true gradient (reflecting all data), we must sum them:
True gradient = ∇W⁰ + ∇W¹ + ∇W² + ∇W³
= gradient from GPU0's data + GPU1's data + GPU2's data + GPU3's data
Just like AllGather, ReduceScatter can be implemented efficiently using a ring topology.
Here's how Ring ReduceScatter works:
RING REDUCESCATTER - Step by Step
─────────────────────────────────
Initial state (each GPU has full gradient, split into 4 chunks):
GPU 0: [∇W⁰₀, ∇W⁰₁, ∇W⁰₂, ∇W⁰₃] ← full gradient from GPU0's data
GPU 1: [∇W¹₀, ∇W¹₁, ∇W¹₂, ∇W¹₃] ← full gradient from GPU1's data
GPU 2: [∇W²₀, ∇W²₁, ∇W²₂, ∇W²₃] ← full gradient from GPU2's data
GPU 3: [∇W³₀, ∇W³₁, ∇W³₂, ∇W³₃] ← full gradient from GPU3's data
Step 1: Each GPU sends one chunk to the next neighbor, which receives and ADDS
GPU 0 sends chunk 3 → GPU 1 receives, computes ∇W⁰₃ + ∇W¹₃
GPU 1 sends chunk 0 → GPU 2 receives, computes ∇W¹₀ + ∇W²₀
GPU 2 sends chunk 1 → GPU 3 receives, computes ∇W²₁ + ∇W³₁
GPU 3 sends chunk 2 → GPU 0 receives, computes ∇W³₂ + ∇W⁰₂
After Step 1: Each GPU holds a 2-way partial sum for one chunk
(Note each GPU sends chunk i-1 first, so that chunk i completes its
trip around the ring exactly at GPU i.)
Step 2: Pass the partial sums onward, keep adding
GPU 0 sends chunk 2 → GPU 1 computes ∇W³₂ + ∇W⁰₂ + ∇W¹₂
GPU 1 sends chunk 3 → GPU 2 computes ∇W⁰₃ + ∇W¹₃ + ∇W²₃
GPU 2 sends chunk 0 → GPU 3 computes ∇W¹₀ + ∇W²₀ + ∇W³₀
GPU 3 sends chunk 1 → GPU 0 computes ∇W²₁ + ∇W³₁ + ∇W⁰₁
Step 3: Final hop completes the 4-way sums
GPU 0 sends chunk 1 → GPU 1 computes Σ∇W₁
GPU 1 sends chunk 2 → GPU 2 computes Σ∇W₂
GPU 2 sends chunk 3 → GPU 3 computes Σ∇W₃
GPU 3 sends chunk 0 → GPU 0 computes Σ∇W₀
After Step 3 (COMPLETE):
GPU 0: [Σ∇W₀, -, -, -] ← only keeps fully-summed chunk 0
GPU 1: [-, Σ∇W₁, -, -] ← only keeps fully-summed chunk 1
GPU 2: [-, -, Σ∇W₂, -] ← only keeps fully-summed chunk 2
GPU 3: [-, -, -, Σ∇W₃] ← only keeps fully-summed chunk 3
Where Σ∇Wᵢ = ∇W⁰ᵢ + ∇W¹ᵢ + ∇W²ᵢ + ∇W³ᵢ (sum across all GPUs)
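The full trace can be checked with a small simulation using scalars as stand-ins for gradient chunks (`ring_reduce_scatter` is an illustrative name, not a real API):

```python
def ring_reduce_scatter(grads):
    """Simulate Ring ReduceScatter. grads[i][c] is GPU i's local gradient
    for chunk c (a scalar stand-in). GPU i sends chunk (i-1) % N first,
    so after N-1 add-and-forward steps GPU i holds the fully summed
    chunk i."""
    n = len(grads)
    acc = [list(g) for g in grads]                    # running partial sums
    for step in range(1, n):
        # At step s, GPU src forwards chunk (src - s) % n. Snapshot all
        # sends first: transfers within a step are simultaneous.
        moves = [((i - 1 - step) % n, acc[(i - 1) % n][(i - 1 - step) % n])
                 for i in range(n)]
        for i, (chunk, val) in enumerate(moves):
            acc[i][chunk] += val                      # receive and ADD
    return [acc[i][i] for i in range(n)]              # GPU i keeps chunk i

# 4 GPUs; chunk c on GPU g holds the value 10*g + c.
grads = [[10 * g + c for c in range(4)] for g in range(4)]
out = ring_reduce_scatter(grads)
# Fully summed chunk c = sum over g of (10*g + c) = 60 + 4*c.
assert out == [60, 64, 68, 72]
```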
You might wonder: why not just use AllReduce (which gives every GPU the full summed gradient) and then discard the parts we don't need?
OPTION A: AllReduce then discard
────────────────────────────────
AllReduce = ReduceScatter + AllGather (two phases!)
Phase 1 (ReduceScatter): Create partial sums, distribute shards
- N-1 steps
- After this: each GPU has ONE summed shard
Phase 2 (AllGather): Collect all shards to every GPU
- N-1 steps
- After this: each GPU has ALL summed shards (full gradient)
Total steps: 2 × (N-1)
Data transferred per GPU: 2 × data_size × (N-1)/N
Then we DISCARD (N-1)/N of the data we just gathered! Wasteful!
OPTION B: ReduceScatter only (what FSDP uses)
─────────────────────────────────────────────
Just do Phase 1 and stop!
- N-1 steps
- Each GPU gets only its summed shard directly
- No wasted communication
Total steps: N-1
Data transferred per GPU: data_size × (N-1)/N
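The volume comparison in script form (illustrative formulas; real NCCL implementations add protocol overhead on top):

```python
# Per-GPU communication volume for a gradient of D bytes on N GPUs.
def reduce_scatter_bytes(D, N):
    return D * (N - 1) / N              # one shard per step, N-1 steps

def all_reduce_bytes(D, N):
    # AllReduce = ReduceScatter + AllGather, so exactly twice the volume.
    return 2 * D * (N - 1) / N

D, N = 1.61e9, 8                        # one layer's gradients, 8 GPUs
print(f"{reduce_scatter_bytes(D, N) / 1e9:.2f} GB")  # → 1.41 GB
print(f"{all_reduce_bytes(D, N) / 1e9:.2f} GB")      # → 2.82 GB
```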
Now we can see how FSDP orchestrates these operations across forward and backward passes.
The forward pass has three steps for each layer: AllGather → Compute → Free. In detail:
FORWARD PASS - Layer i
──────────────────────
STEP 1: AllGather weights
Before: Each GPU has W_shard [1024, 8192] (16.8 MB)
After: Each GPU has W_full [8192, 8192] (134 MB)
Time: ~4.7 ms for the whole layer's ~1.61 GB of weights
(the Wq matrix shown here is just one piece of the layer)
STEP 2: Compute Y = X @ W
Input X: [batch, seq_len, 8192]
Weight W_full: [8192, 8192]
Output Y: [batch, seq_len, 8192]
STEP 3: Free W_full
Discard the gathered weights, keep only W_shard
Memory returns to: 16.8 MB (shard only)
→ Proceed to layer i+1
The backward pass is more complex: AllGather → Compute Gradients → ReduceScatter. In detail:
BACKWARD PASS - Layer i
───────────────────────
STEP 1: AllGather weights (again)
We need W_full to compute gradients (we freed it after forward!)
Time: ~4.7 ms
STEP 2: Compute gradients
Given: upstream gradient ∇Y [batch, seq_len, 8192]
Gradient w.r.t. input (to pass to previous layer):
∇X = ∇Y @ Wᵀ
Gradient w.r.t. weights (to update the weights):
∇W = Xᵀ @ ∇Y
Result: [8192, 8192] gradient matrix
STEP 3: Free W_full
Memory returns to shard only
STEP 4: ReduceScatter gradients
Before: Each GPU has ∇W_local [8192, 8192] (full, local)
After: Each GPU has ∇W_shard [1024, 8192] (summed across GPUs)
Time: ~4.7 ms
→ Proceed to layer i-1
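The two gradient formulas in STEP 2 can be checked on a tiny example (pure-Python matmul stand-ins; shapes are toy-sized):

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

# Toy shapes: X is [2, 2] activations, W is [2, 2] weights,
# dY is the upstream gradient with the same shape as Y.
X  = [[1.0, 2.0], [3.0, 4.0]]
W  = [[0.5, 0.0], [0.0, 0.5]]
dY = [[1.0, 1.0], [1.0, 1.0]]

Y  = matmul(X, W)                 # forward:  Y  = X @ W
dX = matmul(dY, transpose(W))     # backward: dX = dY @ W^T
dW = matmul(transpose(X), dY)     # backward: dW = X^T @ dY

assert dW == [[4.0, 4.0], [6.0, 6.0]]   # each entry sums a column of X
```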
COMMUNICATION SUMMARY (per training step)
─────────────────────────────────────────
Forward pass (80 layers):
- AllGather per layer: ~4.7 ms
- Total: 4.7 ms × 80 = 376 ms
Backward pass (80 layers):
- AllGather per layer: ~4.7 ms
- ReduceScatter per layer: ~4.7 ms
- Total: 9.4 ms × 80 = 752 ms
───────────────────────────────────────────
Total communication: 376 ms + 752 ms ≈ 1.1 seconds
For inference (forward only): 376 ms
But wait — this assumes communication and compute happen sequentially. Can we do better?
The key to making FSDP efficient is hiding communication latency behind compute. While one layer is computing, we can prefetch the weights for the next layer.
SEQUENTIAL EXECUTION (naive):
─────────────────────────────
Time →
Layer 1: [AllGather]────[Compute]────[Free]
Layer 2: [AllGather]────[Compute]────[Free]
Layer 3: [AllGather]────[Compute]────[Free]
↑
GPU sits idle during AllGather!
Total time = Σ(AllGather_time + Compute_time) for all layers
The insight: while the GPU is busy computing layer L, the network is idle. We can use this time to prefetch weights for layer L+1!
PIPELINED EXECUTION (optimized):
────────────────────────────────
Time →
Layer 1: [AllGather₁]────[Compute₁]────[Free₁]
Layer 2: [AllGather₂]────[Compute₂]────[Free₂]
Layer 3: [AllGather₃]────[Compute₃]────[Free₃]
↑ ↑
└──────────┴── AllGather overlaps with previous Compute!
Total time ≈ AllGather₁ + Σ(Compute_time) + small overhead
≈ Compute-dominated (communication hidden!)
PREFETCH MECHANISM:
───────────────────
While GPU computes Layer L:
1. GPU compute cores: busy with matrix multiplies for layer L
2. Network interface: idle (no communication happening)
3. Opportunity: start AllGather for layer L+1 weights!
Implementation uses CUDA streams:
- Compute stream: runs matrix multiplies
- Communication stream: runs AllGather/ReduceScatter
- Both execute simultaneously on modern GPUs
Memory consideration during overlap:
- Layer L: full weights resident (~1.61 GB, being used for compute)
- Layer L+1: full weights resident (~1.61 GB, being gathered)
- Peak: 2 × 1.61 GB ≈ 3.2 GB extra on top of the 17.5 GB shard (manageable)
PIPELINING EFFECTIVENESS:
─────────────────────────
Case 1: Compute time > Communication time (IDEAL)
─────────────────────────────────────────────────
[AllGather]
[────────Compute────────]
[AllGather] ← fits entirely within Compute!
[────────Compute────────]
Communication is fully hidden!
Effective overhead: ~0
Case 2: Compute time < Communication time (LIMITED)
─────────────────────────────────────────────────
[────AllGather────]
[Compute]
[────AllGather────] ← extends beyond Compute
[Compute]
Some communication cannot be hidden.
Effective overhead: (Comm_time - Compute_time) per layer
For Llama-70B with long sequences:
Llama-70B, batch_size=1, seq_len=1M:
- Compute time per layer: ~50-200 ms (large matrix multiplies)
- AllGather time per layer: ~4.7 ms
- Ratio: Compute >> Communication (10-40×)
- Result: Communication is almost fully hidden! ✓
For smaller batches or shorter sequences:
- Compute time decreases proportionally
- Communication time stays the same
- Pipelining becomes less effective
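A simple cost model captures both regimes: with prefetching, each layer costs roughly the max of compute and communication, plus one exposed AllGather at the start. The model and its numbers below are illustrative:

```python
# Per-layer cost model for prefetching (all numbers illustrative).
def total_time_ms(layers, compute_ms, comm_ms, overlap):
    if not overlap:
        return layers * (compute_ms + comm_ms)
    # One exposed AllGather up front; afterwards each layer costs the
    # max of its compute and the (overlapped) AllGather for the next.
    return comm_ms + layers * max(compute_ms, comm_ms)

# Compute-bound regime (long sequences): communication almost hidden.
print(f"{total_time_ms(80, 100.0, 4.7, overlap=False):.1f}")  # → 8376.0
print(f"{total_time_ms(80, 100.0, 4.7, overlap=True):.1f}")   # → 8004.7
# Communication-bound regime (tiny batches): overlap helps far less.
print(f"{total_time_ms(80, 2.0, 4.7, overlap=True):.1f}")     # → 380.7
```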
Let's clarify what we need for our inference use case:
| Operation | Training | Inference |
|---|---|---|
| AllGather (forward) | ✓ Required | ✓ Required |
| Compute forward | ✓ Required | ✓ Required |
| Free weights | ✓ Required | ✓ Required |
| AllGather (backward) | ✓ Required | ✗ Not needed |
| Compute gradients | ✓ Required | ✗ Not needed |
| ReduceScatter | ✓ Required | ✗ Not needed |
| Optimizer update | ✓ Required | ✗ Not needed |
To recap the concepts introduced in this part:

| Concept | What it does | Key insight |
|---|---|---|
| Weight Sharding | Split weights 1/N per GPU | Reduces memory from 140 GB to 17.5 GB |
| Ring AllGather | Reconstruct full weights | One transfer per link per step avoids congestion |
| Ring ReduceScatter | Sum + re-shard gradients | 50% less communication than AllReduce |
| Prefetching | Overlap comm & compute | Hides communication behind compute |
And the key numbers for Llama-70B on 8 GPUs:

| Quantity | Value |
|---|---|
| Weight shard per GPU | 17.5 GB |
| AllGather time per layer | ~4.7 ms |
| ReduceScatter time per layer | ~4.7 ms |
| Peak weight memory per GPU | ~21 GB (17.5 GB shard + 2 gathered layers during prefetch) |
| Effective overhead (with pipelining) | Near zero when compute-bound |
Coming up in the rest of the series:

| Part | Topic |
|---|---|
| Part 3 | Sequence Sharding (Ulysses) — splitting the 1M token sequence across GPUs with All-to-All |
| Part 4 | Ring Attention — distributed attention without O(n²) memory |
| Part 5 | Putting It Together — the complete USP picture |
Weights are sharded. Now let's shard the sequence.