
Part 2: Weight Sharding (FSDP)

Distributing 140 GB of weights across 8 GPUs

By Darshan Fofadiya


In Part 1, we established that Llama-70B's weights require 140 GB — nearly double the A100's 80 GB capacity. How do we fit something that doesn't fit?

The answer is FSDP (Fully Sharded Data Parallel). But FSDP is built on several fundamental concepts that we need to understand first. In this part, we'll build up our understanding step by step:

  1. First, we'll understand weight sharding — the core idea of splitting weights across GPUs
  2. Then, we'll learn about AllGather — how to reconstruct full weights when we need them
  3. Next, we'll cover ReduceScatter — how to aggregate gradients efficiently
  4. Finally, we'll see how FSDP orchestrates these operations together

A note on scope: This blog series focuses on inference for long-context models. However, FSDP was originally designed for training, where it shards weights, gradients, and optimizer states. We'll explain the full FSDP picture (forward + backward pass) because understanding the complete mechanism helps clarify how weight sharding works. For inference, we only need the forward pass portion — but the concepts transfer directly.

2.1 Weight Sharding: The Core Idea

The fundamental insight: instead of replicating all weights on every GPU (like standard Data Parallel), each GPU stores only 1/N of the weights.

2.1.1 A Concrete Example

Let's work through a simple example to build intuition. We'll use 4 GPUs here to keep the step-by-step traces readable — fewer GPUs means fewer steps to follow. Later, we'll scale up to 8 GPUs for the real Llama-70B numbers.

Suppose we have a weight matrix W and 4 GPUs:

Original weight matrix W (shape [8, 4]):

W = | w00  w01  w02  w03 |
    | w10  w11  w12  w13 |
    | w20  w21  w22  w23 |
    | w30  w31  w32  w33 |
    | w40  w41  w42  w43 |
    | w50  w51  w52  w53 |
    | w60  w61  w62  w63 |
    | w70  w71  w72  w73 |

We split the matrix evenly across GPUs. In practice, FSDP flattens all parameters into a 1D buffer and divides it into equal chunks — but for visualization, think of it as splitting the rows:

GPU 0 holds W_shard0 (rows 0-1):
    | w00  w01  w02  w03 |
    | w10  w11  w12  w13 |

GPU 1 holds W_shard1 (rows 2-3):
    | w20  w21  w22  w23 |
    | w30  w31  w32  w33 |

GPU 2 holds W_shard2 (rows 4-5):
    | w40  w41  w42  w43 |
    | w50  w51  w52  w53 |

GPU 3 holds W_shard3 (rows 6-7):
    | w60  w61  w62  w63 |
    | w70  w71  w72  w73 |
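The row-split above can be sketched in a few lines of plain Python. This is an illustrative toy, not FSDP's actual implementation (which works on flattened torch tensors); `shard_params` is a hypothetical helper name:

```python
# Toy sketch of FSDP-style sharding: flatten parameters into one 1D buffer,
# zero-pad so it divides evenly, and split into equal per-GPU chunks.

def shard_params(matrices, n_gpus):
    flat = [v for m in matrices for row in m for v in row]
    pad = (-len(flat)) % n_gpus        # pad so every GPU gets the same size
    flat = flat + [0.0] * pad
    chunk = len(flat) // n_gpus
    return [flat[i * chunk:(i + 1) * chunk] for i in range(n_gpus)]

# The [8, 4] matrix from the text across 4 GPUs: 32 values -> 4 shards of 8
W = [[float(10 * r + c) for c in range(4)] for r in range(8)]
shards = shard_params([W], 4)
print(len(shards), len(shards[0]))     # → 4 8
```

Each shard here holds 8 values, i.e. two rows of W, matching the picture above.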

2.1.2 Memory Savings

Applying this to Llama-70B with 8 GPUs:

Total weights: 140 GB (70B params × 2 bytes in BF16)
Number of GPUs: 8

Per-GPU shard: 140 GB ÷ 8 = 17.5 GB

This fits comfortably within the A100's 80 GB, leaving room for activations and KV cache.

2.1.3 What Gets Sharded

Every weight matrix in the model gets sharded. For one transformer layer:

Original Wq: [8192, 8192]  → 67M params × 2 bytes = 134 MB
Sharded Wq:  [1024, 8192]  → 8.4M params × 2 bytes = 16.8 MB per GPU

Original FFN W1: [8192, 28672]  → 235M params × 2 bytes = 470 MB
Sharded FFN W1:  [1024, 28672]  → 29.4M params × 2 bytes = 58.8 MB per GPU

Result: Each GPU now holds only 17.5 GB of weights instead of 140 GB. But there's a catch — we can't compute Y = X @ W with only a shard of W. We need the full matrix. This brings us to AllGather.

2.2 AllGather: Reconstructing the Full Matrix

Now we face a problem: to compute Y = X @ W, we need the complete weight matrix W, but each GPU only has 1/4 of it. This is where AllGather comes in.

2.2.1 What is AllGather?

AllGather is a collective communication operation where each GPU shares its data with all other GPUs. After AllGather, every GPU has the complete data.

2.2.2 Step-by-Step Matrix Example

Continuing our example with 4 GPUs, each holding a shard:

BEFORE AllGather:
─────────────────
GPU 0 has: W_shard0 = | w00  w01  w02  w03 |
                      | w10  w11  w12  w13 |

GPU 1 has: W_shard1 = | w20  w21  w22  w23 |
                      | w30  w31  w32  w33 |

GPU 2 has: W_shard2 = | w40  w41  w42  w43 |
                      | w50  w51  w52  w53 |

GPU 3 has: W_shard3 = | w60  w61  w62  w63 |
                      | w70  w71  w72  w73 |

After AllGather completes:

AFTER AllGather:
────────────────
GPU 0 has: W_full = | w00  w01  w02  w03 |  ← from shard0 (local)
                    | w10  w11  w12  w13 |
                    | w20  w21  w22  w23 |  ← received from GPU 1
                    | w30  w31  w32  w33 |
                    | w40  w41  w42  w43 |  ← received from GPU 2
                    | w50  w51  w52  w53 |
                    | w60  w61  w62  w63 |  ← received from GPU 3
                    | w70  w71  w72  w73 |

GPU 1 has: W_full (same complete matrix)
GPU 2 has: W_full (same complete matrix)
GPU 3 has: W_full (same complete matrix)

Now every GPU can compute Y = X @ W_full.

2.2.3 Naive AllGather: The Communication Problem

The simplest way to implement AllGather is for each GPU to broadcast its shard to all others:

Naive AllGather communication pattern:
──────────────────────────────────────
GPU 0 sends W_shard0 → GPU 1, GPU 2, GPU 3  (3 sends)
GPU 1 sends W_shard1 → GPU 0, GPU 2, GPU 3  (3 sends)
GPU 2 sends W_shard2 → GPU 0, GPU 1, GPU 3  (3 sends)
GPU 3 sends W_shard3 → GPU 0, GPU 1, GPU 2  (3 sends)

Total messages: N × (N-1) = 4 × 3 = 12 messages
Data per GPU sent: shard_size × (N-1) = S × 3
Data per GPU received: shard_size × (N-1) = S × 3

This creates a lot of network traffic, and all GPUs try to send simultaneously, causing congestion. The network links become a bottleneck because everyone is trying to use them at once.

2.2.4 Ring AllGather: The Efficient Solution

Ring AllGather arranges GPUs in a logical ring and passes data around the ring in steps. This is much more efficient because it avoids congestion.

Here's how Ring AllGather works with 4 GPUs:

RING ALLGATHER - Step by Step
─────────────────────────────
Initial state:
  GPU 0: [S0]      GPU 1: [S1]      GPU 2: [S2]      GPU 3: [S3]

Step 1: Each GPU sends its current data to the next GPU in the ring
  GPU 0 → GPU 1: sends S0
  GPU 1 → GPU 2: sends S1
  GPU 2 → GPU 3: sends S2
  GPU 3 → GPU 0: sends S3

After Step 1:
  GPU 0: [S0, S3]  GPU 1: [S1, S0]  GPU 2: [S2, S1]  GPU 3: [S3, S2]

Step 2: Each GPU sends what it just received
  GPU 0 → GPU 1: sends S3
  GPU 1 → GPU 2: sends S0
  GPU 2 → GPU 3: sends S1
  GPU 3 → GPU 0: sends S2

After Step 2:
  GPU 0: [S0, S3, S2]  GPU 1: [S1, S0, S3]  GPU 2: [S2, S1, S0]  GPU 3: [S3, S2, S1]

Step 3: Each GPU sends what it just received
  GPU 0 → GPU 1: sends S2
  GPU 1 → GPU 2: sends S3
  GPU 2 → GPU 3: sends S0
  GPU 3 → GPU 0: sends S1

After Step 3 (COMPLETE):
  GPU 0: [S0, S1, S2, S3]  ✓
  GPU 1: [S0, S1, S2, S3]  ✓
  GPU 2: [S0, S1, S2, S3]  ✓
  GPU 3: [S0, S1, S2, S3]  ✓
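The three-step trace above can be reproduced with a small single-process simulation (an illustrative sketch, not NCCL; `ring_allgather` is a made-up name):

```python
# Single-process simulation of Ring AllGather. Each "GPU" starts with one
# shard; after N-1 steps of neighbor-to-neighbor passing, every GPU has all N.

def ring_allgather(shards):
    n = len(shards)
    have = [{g: shards[g]} for g in range(n)]   # shards each GPU holds, by owner
    sending = list(range(n))                    # shard index each GPU forwards
    for _ in range(n - 1):
        received = [sending[(g - 1) % n] for g in range(n)]  # from ring neighbor
        for g in range(n):
            have[g][received[g]] = shards[received[g]]
        sending = received                      # next step: forward what arrived
    # reassemble in shard order on every GPU
    return [[have[g][i] for i in range(n)] for g in range(n)]

result = ring_allgather([["S0"], ["S1"], ["S2"], ["S3"]])
# every GPU ends with [["S0"], ["S1"], ["S2"], ["S3"]], matching the trace above
```

Stepping through it by hand reproduces the intermediate states exactly: after step 1, GPU 0 holds {S0, S3}, after step 2 {S0, S3, S2}, and so on.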

2.2.5 Why Ring AllGather is Efficient

Let's compare the communication costs in detail:

NAIVE ALLGATHER:
────────────────
- Each GPU sends: shard_size × (N-1) = S × (N-1)
- Each GPU receives: shard_size × (N-1) = S × (N-1)
- All sends happen SIMULTANEOUSLY → network congestion!

The problem: If each GPU has a 100 MB/s link, and all 4 GPUs try to 
send 3 messages each at the same time, they're all competing for 
the same network bandwidth. The effective throughput drops.

Time = (S × (N-1)) / (bandwidth / N)  ← bandwidth divided among N senders
     = S × (N-1) × N / bandwidth
     = O(S × N²) with congestion

RING ALLGATHER:
───────────────
- Number of steps: N-1 (for N GPUs)
- Data sent per step per GPU: shard_size = S
- Total data sent per GPU: S × (N-1)  ← same total as naive!

The key difference: In each step, each GPU sends to exactly ONE 
neighbor and receives from exactly ONE neighbor. No congestion!

Step 1: GPU0→GPU1, GPU1→GPU2, GPU2→GPU3, GPU3→GPU0 (4 parallel transfers)
Step 2: GPU0→GPU1, GPU1→GPU2, GPU2→GPU3, GPU3→GPU0 (4 parallel transfers)
Step 3: GPU0→GPU1, GPU1→GPU2, GPU2→GPU3, GPU3→GPU0 (4 parallel transfers)

Each link is used by exactly one transfer at a time.

Time per step = S / bandwidth  ← full bandwidth available!
Total time = (N-1) × S / bandwidth
           = S × (N-1) / bandwidth
           = O(S × N / bandwidth)  ← linear in N, not quadratic!

Key insight: Ring AllGather transfers the same total amount of data as naive AllGather, but it pipelines the communication so there's no congestion. Each GPU is always sending to exactly one neighbor and receiving from exactly one neighbor, so every link operates at full bandwidth.

2.2.6 But Wait — Wouldn't Naive Be Faster with Infinite Bandwidth?

You might wonder: in Ring AllGather, the steps happen sequentially (step 1, then step 2, then step 3). In naive AllGather, all transfers happen "at once." So with infinite bandwidth, wouldn't naive be faster?

The key insight is that bandwidth is always finite and shared. Think of it like a highway:

NAIVE ALLGATHER - The Traffic Jam
─────────────────────────────────
GPU 0 wants to send to GPU 1, GPU 2, GPU 3 simultaneously.
But GPU 0 has ONE outbound link with bandwidth B.

If GPU 0 sends 3 messages at once:
  - Each message gets B/3 bandwidth
  - Time to send one shard: S / (B/3) = 3S/B

All 4 GPUs do this simultaneously, but each GPU's outbound link
is the bottleneck for its own sends.

Total time = 3S/B (limited by each GPU's outbound bandwidth)

RING ALLGATHER - Smooth Flow
────────────────────────────
GPU 0 sends to GPU 1 only. Full bandwidth B available.

Step 1: Send S bytes at bandwidth B → time = S/B
Step 2: Send S bytes at bandwidth B → time = S/B  
Step 3: Send S bytes at bandwidth B → time = S/B

Total time = 3 × S/B = 3S/B

Same total time! But Ring is more predictable and scales better.

The math works out the same because the bottleneck is always the physical link bandwidth. Ring AllGather is still preferred in practice: each link carries exactly one transfer at a time, so performance is predictable, there is no contention on shared switches or fabric, and the pattern scales cleanly as the GPU count grows.

2.2.7 Communication Cost for Llama-70B

Let's calculate the actual time for one layer:

Per-layer weight size: ~805M params × 2 bytes = 1.61 GB
Number of GPUs: 8
Shard size: 1.61 GB ÷ 8 = 201 MB

Ring AllGather:
  - Steps: N-1 = 7
  - Data per step: 201 MB
  - Total data transferred per GPU: 201 MB × 7 = 1.41 GB

With NVLink at 600 GB/s (bidirectional, so ~300 GB/s per direction):
  Time per step = 201 MB ÷ 300 GB/s = 0.67 ms
  Total time = 7 steps × 0.67 ms = 4.7 ms per layer

For 80 layers (forward pass):
  Total AllGather time: 4.7 ms × 80 = 376 ms
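These numbers are easy to sanity-check in code, under the same rough assumptions as above (BF16 weights, 8 GPUs, ~300 GB/s per direction; `ring_allgather_time` is a made-up helper):

```python
# Back-of-envelope Ring AllGather cost: (N-1) steps, one shard per step,
# each step at full per-direction link bandwidth.

def ring_allgather_time(total_bytes, n_gpus, bw_bytes_per_s):
    shard = total_bytes / n_gpus
    return (n_gpus - 1) * shard / bw_bytes_per_s

GB = 1e9
layer_bytes = 1.61 * GB               # ~805M params x 2 bytes (BF16)
t_layer = ring_allgather_time(layer_bytes, 8, 300 * GB)
t_forward = t_layer * 80              # 80 transformer layers
print(f"{t_layer * 1e3:.1f} ms per layer, {t_forward * 1e3:.0f} ms forward")
# → 4.7 ms per layer, 376 ms forward
```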

2.3 ReduceScatter: Aggregating and Re-sharding

AllGather solves the forward pass. But during training, after computing gradients, we face the reverse problem: each GPU has computed full gradients, and we need to (1) sum them across GPUs and (2) re-shard them for storage. This is where ReduceScatter comes in.

2.3.1 What is ReduceScatter?

ReduceScatter combines two operations:

  1. Reduce: sum the data element-wise across all GPUs
  2. Scatter: distribute the summed result so each GPU keeps only one shard

The result: each GPU ends up with a summed shard of the data.

2.3.2 Why Do We Need to Sum Gradients?

In data parallel training, each GPU processes different data batches. The gradients computed on each GPU are partial gradients — they only reflect the loss on that GPU's data. To get the true gradient (reflecting all data), we must sum them:

True gradient = ∇W⁰ + ∇W¹ + ∇W² + ∇W³
              = gradient from GPU0's data + GPU1's data + GPU2's data + GPU3's data

2.3.3 Ring ReduceScatter

Just like AllGather, ReduceScatter can be implemented efficiently using a ring topology.

Here's how Ring ReduceScatter works:

RING REDUCESCATTER - Step by Step
─────────────────────────────────
Initial state (each GPU has full gradient, split into 4 chunks):
  GPU 0: [∇W⁰₀, ∇W⁰₁, ∇W⁰₂, ∇W⁰₃]  ← full gradient from GPU0's data
  GPU 1: [∇W¹₀, ∇W¹₁, ∇W¹₂, ∇W¹₃]  ← full gradient from GPU1's data
  GPU 2: [∇W²₀, ∇W²₁, ∇W²₂, ∇W²₃]  ← full gradient from GPU2's data
  GPU 3: [∇W³₀, ∇W³₁, ∇W³₂, ∇W³₃]  ← full gradient from GPU3's data

Step 1: Each GPU sends one chunk to the next neighbor, which receives and ADDS
  GPU 0 sends chunk 3 → GPU 1 receives, computes ∇W⁰₃ + ∇W¹₃
  GPU 1 sends chunk 0 → GPU 2 receives, computes ∇W¹₀ + ∇W²₀
  GPU 2 sends chunk 1 → GPU 3 receives, computes ∇W²₁ + ∇W³₁
  GPU 3 sends chunk 2 → GPU 0 receives, computes ∇W³₂ + ∇W⁰₂

After Step 1: Each GPU has a 2-way partial sum for one chunk

Step 2: Pass the partial sums to the next neighbor, keep adding
  GPU 0 sends its 2-way sum of chunk 2 → GPU 1 adds ∇W¹₂ (3-way sum)
  GPU 1 sends its 2-way sum of chunk 3 → GPU 2 adds ∇W²₃ (3-way sum)
  GPU 2 sends its 2-way sum of chunk 0 → GPU 3 adds ∇W³₀ (3-way sum)
  GPU 3 sends its 2-way sum of chunk 1 → GPU 0 adds ∇W⁰₁ (3-way sum)

Step 3: One last hop completes the 4-way sums
  GPU 0 sends its 3-way sum of chunk 1 → GPU 1 adds ∇W¹₁ (4-way sum) ✓
  GPU 1 sends its 3-way sum of chunk 2 → GPU 2 adds ∇W²₂ (4-way sum) ✓
  GPU 2 sends its 3-way sum of chunk 3 → GPU 3 adds ∇W³₃ (4-way sum) ✓
  GPU 3 sends its 3-way sum of chunk 0 → GPU 0 adds ∇W⁰₀ (4-way sum) ✓

After Step 3 (COMPLETE):
  GPU 0: [Σ∇W₀, -, -, -]  ← only keeps fully-summed chunk 0
  GPU 1: [-, Σ∇W₁, -, -]  ← only keeps fully-summed chunk 1
  GPU 2: [-, -, Σ∇W₂, -]  ← only keeps fully-summed chunk 2
  GPU 3: [-, -, -, Σ∇W₃]  ← only keeps fully-summed chunk 3

Where Σ∇Wᵢ = ∇W⁰ᵢ + ∇W¹ᵢ + ∇W²ᵢ + ∇W³ᵢ (sum across all GPUs)
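Here's a single-process simulation of that trace, with plain numbers standing in for gradient chunks (illustrative only; `ring_reduce_scatter` is a made-up helper). At step s, GPU g forwards its accumulator for chunk (g − 1 − s) mod N, so each chunk k finishes its trip around the ring exactly at GPU k:

```python
# Single-process simulation of Ring ReduceScatter. grads[g][c] is GPU g's
# local gradient for chunk c (a number here, a tensor chunk in reality).

def ring_reduce_scatter(grads):
    n = len(grads)
    acc = [list(row) for row in grads]            # working accumulators per GPU
    for step in range(n - 1):
        # snapshot sends first so same-step receives don't interfere
        sends = [(g, (g - 1 - step) % n, acc[g][(g - 1 - step) % n])
                 for g in range(n)]
        for g, c, val in sends:
            acc[(g + 1) % n][c] += val            # neighbor receives and ADDS
    return [acc[g][g] for g in range(n)]          # GPU g keeps summed chunk g

grads = [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]
print(ring_reduce_scatter(grads))   # → [10, 10, 10, 10]
```

Every chunk sums to 1 + 2 + 3 + 4 = 10, and each GPU ends with exactly one fully summed chunk, as in the diagram.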

2.3.4 Why Not AllReduce + Discard?

You might wonder: why not just use AllReduce (which gives every GPU the full summed gradient) and then discard the parts we don't need?

OPTION A: AllReduce then discard
────────────────────────────────
AllReduce = ReduceScatter + AllGather (two phases!)

Phase 1 (ReduceScatter): Create partial sums, distribute shards
  - N-1 steps
  - After this: each GPU has ONE summed shard

Phase 2 (AllGather): Collect all shards to every GPU
  - N-1 steps  
  - After this: each GPU has ALL summed shards (full gradient)

Total steps: 2 × (N-1)
Data transferred per GPU: 2 × data_size × (N-1)/N

Then we DISCARD (N-1)/N of the data we just gathered! Wasteful!
OPTION B: ReduceScatter only (what FSDP uses)
─────────────────────────────────────────────
Just do Phase 1 and stop!

  - N-1 steps
  - Each GPU gets only its summed shard directly
  - No wasted communication

Total steps: N-1
Data transferred per GPU: data_size × (N-1)/N

ReduceScatter uses 50% less communication than AllReduce! AllReduce = ReduceScatter + AllGather. If we only need sharded results (which FSDP does for gradient storage), we skip the AllGather phase entirely, cutting communication in half.
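The decomposition is easy to verify on toy data. The sketch below uses naive single-process versions of the two phases (not ring-based, and not a real collectives API) purely to show what each phase produces:

```python
# AllReduce decomposed as ReduceScatter followed by AllGather.

def reduce_scatter(data):
    """data[g][c]: GPU g's local value for chunk c.
    GPU g ends up with the cross-GPU sum of chunk g only."""
    n = len(data)
    return [sum(data[src][g] for src in range(n)) for g in range(n)]

def allgather(kept):
    """Every GPU receives every GPU's kept chunk."""
    return [list(kept) for _ in range(len(kept))]

data = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [0, 0, 0, 0]]
shards = reduce_scatter(data)    # one summed chunk per GPU: [111, 222, 333, 444]
full = allgather(shards)         # AllReduce result: the full sum on every GPU
# FSDP stops after reduce_scatter: the allgather half is exactly the
# communication AllReduce would waste when only sharded gradients are needed.
```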

2.4 Putting It Together: The FSDP Timeline

Now we can see how FSDP orchestrates these operations across forward and backward passes.

2.4.1 Forward Pass (Per Layer)

The forward pass has three steps per layer: AllGather → Compute → Free. Here's the detailed flow:

FORWARD PASS - Layer i
──────────────────────

STEP 1: AllGather weights
  Before: Each GPU has W_shard [1024, 8192] (16.8 MB)
  After:  Each GPU has W_full [8192, 8192] (134 MB)
  Time: ~4.7 ms (Ring AllGather with 8 GPUs)

STEP 2: Compute Y = X @ W
  Input X: [batch, seq_len, 8192]
  Weight W_full: [8192, 8192]
  Output Y: [batch, seq_len, 8192]

STEP 3: Free W_full
  Discard the gathered weights, keep only W_shard
  Memory returns to: 16.8 MB (shard only)

→ Proceed to layer i+1
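The gather → compute → free loop can be sketched as toy Python, with list-of-rows matrices and simple concatenation standing in for AllGather (`fsdp_forward` and `matmul` are hypothetical helpers, not FSDP's API):

```python
# Toy single-process sketch of the FSDP forward loop: gather -> compute -> free.

def matmul(X, W):
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)]
            for row in X]

def fsdp_forward(x, layer_shards):
    """layer_shards[i][g] = GPU g's row-shard of layer i's weight matrix."""
    for shards in layer_shards:
        W_full = [row for shard in shards for row in shard]  # STEP 1: AllGather
        x = matmul(x, W_full)                                # STEP 2: Compute
        del W_full                                           # STEP 3: Free
    return x

# One [2, 2] "layer" split across 2 GPUs (one row each): the identity matrix
layers = [[[[1.0, 0.0]], [[0.0, 1.0]]]]
print(fsdp_forward([[3.0, 4.0]], layers))   # → [[3.0, 4.0]]
```

The important property is that only one layer's full weight matrix exists at a time; between layers, each "GPU" is back to holding just its shard.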

2.4.2 Backward Pass (Per Layer)

The backward pass is more complex: AllGather → Compute Gradients → ReduceScatter. Here's the detailed flow:

BACKWARD PASS - Layer i
───────────────────────

STEP 1: AllGather weights (again)
  We need W_full to compute gradients (we freed it after forward!)
  Time: ~4.7 ms

STEP 2: Compute gradients
  Given: upstream gradient ∇Y [batch, seq_len, 8192]
  
  Gradient w.r.t. input (to pass to previous layer):
    ∇X = ∇Y @ Wᵀ
  
  Gradient w.r.t. weights (to update the weights):
    ∇W = Xᵀ @ ∇Y
    Result: [8192, 8192] gradient matrix

STEP 3: Free W_full
  Memory returns to shard only

STEP 4: ReduceScatter gradients
  Before: Each GPU has ∇W_local [8192, 8192] (full, local)
  After:  Each GPU has ∇W_shard [1024, 8192] (summed across GPUs)
  Time: ~4.7 ms

→ Proceed to layer i-1
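The two gradient formulas in STEP 2 can be checked numerically. The toy below uses loss = sum(Y) for Y = X @ W, so ∇Y is all ones, and verifies the analytic ∇W against a finite-difference probe (pure-Python sketch; `matmul`/`transpose` are minimal helpers written for this check):

```python
# Numerical check of the backward formulas:
#   ∇W = Xᵀ @ ∇Y   and   ∇X = ∇Y @ Wᵀ

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

X = [[1.0, 2.0], [3.0, 4.0]]
W = [[5.0, 6.0], [7.0, 8.0]]
dY = [[1.0, 1.0], [1.0, 1.0]]          # gradient of sum(Y) w.r.t. Y

dW = matmul(transpose(X), dY)          # [[4, 4], [6, 6]]
dX = matmul(dY, transpose(W))          # [[11, 15], [11, 15]]

# Finite-difference probe of one weight entry confirms dW[0][0]:
eps = 1e-6
W2 = [row[:] for row in W]
W2[0][0] += eps
num = (sum(map(sum, matmul(X, W2))) - sum(map(sum, matmul(X, W)))) / eps
assert abs(num - dW[0][0]) < 1e-3      # analytic and numeric gradients agree
```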

2.4.3 Total Communication Overhead

COMMUNICATION SUMMARY (per training step)
─────────────────────────────────────────

Forward pass (80 layers):
  - AllGather per layer: ~4.7 ms
  - Total: 4.7 ms × 80 = 376 ms

Backward pass (80 layers):
  - AllGather per layer: ~4.7 ms
  - ReduceScatter per layer: ~4.7 ms
  - Total: 9.4 ms × 80 = 752 ms

───────────────────────────────────────────
Total communication: 376 ms + 752 ms ≈ 1.1 seconds

For inference (forward only): 376 ms

But wait — this assumes communication and compute happen sequentially. Can we do better?


2.5 Overlapping Communication and Compute

The key to making FSDP efficient is hiding communication latency behind compute. While one layer is computing, we can prefetch the weights for the next layer.

2.5.1 The Problem with Sequential Execution

SEQUENTIAL EXECUTION (naive):
─────────────────────────────
Time →

Layer 1: [AllGather]────[Compute]────[Free]
Layer 2:                            [AllGather]────[Compute]────[Free]
Layer 3:                                                       [AllGather]────[Compute]────[Free]
                        ↑
                        GPU sits idle during AllGather!

Total time = Σ(AllGather_time + Compute_time) for all layers

2.5.2 Pipelined Execution with Prefetching

The insight: while the GPU is busy computing layer L, the network is idle. We can use this time to prefetch weights for layer L+1!

PIPELINED EXECUTION (optimized):
────────────────────────────────
Time →

Layer 1: [AllGather₁]────[Compute₁]────[Free₁]
Layer 2:            [AllGather₂]────[Compute₂]────[Free₂]
Layer 3:                       [AllGather₃]────[Compute₃]────[Free₃]
                    ↑          ↑
                    └──────────┴── AllGather overlaps with previous Compute!

Total time ≈ AllGather₁ + Σ(Compute_time) + small overhead
           ≈ Compute-dominated (communication hidden!)
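A toy timing model makes the payoff concrete. The assumed inputs (4.7 ms AllGather, 50 ms compute, 80 layers) come from the sections above; the pipelined model charges only the first AllGather plus any per-layer spill-over when communication exceeds compute:

```python
# Sequential vs pipelined execution time for an L-layer forward pass (times in ms).

def sequential_time(n_layers, t_comm, t_compute):
    return n_layers * (t_comm + t_compute)

def pipelined_time(n_layers, t_comm, t_compute):
    # layer i+1's AllGather overlaps layer i's compute; only the first
    # AllGather plus max(comm - compute, 0) per overlapped layer is exposed
    exposed = max(t_comm - t_compute, 0.0)
    return t_comm + n_layers * t_compute + (n_layers - 1) * exposed

print(f"{sequential_time(80, 4.7, 50.0):.1f}")   # → 4376.0
print(f"{pipelined_time(80, 4.7, 50.0):.1f}")    # → 4004.7 (comm nearly hidden)
```

When compute dominates (Case 1 above), the pipelined time collapses to one AllGather plus pure compute; when communication dominates (Case 2), the spill-over term grows and the benefit shrinks.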

2.5.3 How Prefetching Works

PREFETCH MECHANISM:
───────────────────

While GPU computes Layer L:
  1. GPU compute cores: busy with matrix multiplies for layer L
  2. Network interface: idle (no communication happening)
  3. Opportunity: start AllGather for layer L+1 weights!

Implementation uses CUDA streams:
  - Compute stream: runs matrix multiplies
  - Communication stream: runs AllGather/ReduceScatter
  - Both execute simultaneously on modern GPUs

Memory consideration during overlap:
  - Layer L: W_full (being used for compute)
  - Layer L+1: W_full (being gathered)
  - Peak memory: 2 × 134 MB = 268 MB extra (manageable)

2.5.4 When Pipelining Helps (and When It Doesn't)

PIPELINING EFFECTIVENESS:
─────────────────────────

Case 1: Compute time > Communication time (IDEAL)
  ─────────────────────────────────────────────────
  [AllGather]
            [────────Compute────────]
                    [AllGather]  ← fits entirely within Compute!
                              [────────Compute────────]
  
  Communication is fully hidden!
  Effective overhead: ~0

Case 2: Compute time < Communication time (LIMITED)
  ─────────────────────────────────────────────────
  [────AllGather────]
            [Compute]
                    [────AllGather────]  ← extends beyond Compute
                              [Compute]
  
  Some communication cannot be hidden.
  Effective overhead: (Comm_time - Compute_time) per layer

For Llama-70B with long sequences:

Llama-70B, batch_size=1, seq_len=1M:
  - Compute time per layer: ~50-200 ms (large matrix multiplies)
  - AllGather time per layer: ~4.7 ms
  - Ratio: Compute >> Communication (10-40×)
  - Result: Communication is almost fully hidden! ✓

For smaller batches or shorter sequences:
  - Compute time decreases proportionally
  - Communication time stays the same
  - Pipelining becomes less effective

Key insight: FSDP works best when compute time dominates communication time. For large batch sizes and long sequences (exactly our use case!), the communication overhead is almost completely hidden behind compute.

2.6 FSDP for Inference vs Training

Let's clarify what we need for our inference use case:

Operation             Training      Inference
─────────────────────────────────────────────
AllGather (forward)   ✓ Required    ✓ Required
Compute forward       ✓ Required    ✓ Required
Free weights          ✓ Required    ✓ Required
AllGather (backward)  ✓ Required    ✗ Not needed
Compute gradients     ✓ Required    ✗ Not needed
ReduceScatter         ✓ Required    ✗ Not needed
Optimizer update      ✓ Required    ✗ Not needed

For inference: We only need the forward pass portion of FSDP. This means AllGather before each layer's compute, then free the gathered weights. No gradient computation, no ReduceScatter. Communication overhead is cut in half compared to training!

2.7 Summary

Key Concepts

Concept              What it does                 Key insight
──────────────────────────────────────────────────────────────────────────────────
Weight Sharding      Split weights 1/N per GPU    Reduces memory from 140 GB to 17.5 GB
Ring AllGather       Reconstruct full weights     Avoids congestion via pipelining
Ring ReduceScatter   Sum + re-shard gradients     50% less communication than AllReduce
Prefetching          Overlap comm & compute       Hides communication behind compute

Key Numbers (Llama-70B, 8 GPUs)

Weight shard per GPU                   17.5 GB
AllGather time per layer               ~4.7 ms
ReduceScatter time per layer           ~4.7 ms
Peak weight memory per GPU             ~18 GB (shard + 1 full layer)
Effective overhead (with pipelining)   Near zero when compute-bound

FSDP solves Problem #1: 140 GB weights now fit across 8 GPUs at 17.5 GB each. The communication overhead is largely hidden behind compute through pipelining. But we still have 20.5 GB of Q/K/V activations per layer, and 328 GB of KV cache for 1M tokens. That's where sequence sharding comes in.

What's Next

Part 3: Sequence Sharding (Ulysses) — splitting the 1M token sequence across GPUs with All-to-All
Part 4: Ring Attention — distributed attention without O(n²) memory
Part 5: Putting It Together — the complete USP picture

Weights are sharded. Now let's shard the sequence.