How to fit a 70B model with 1M context on GPUs
You want to run Llama-70B with a 1 million token context window. Sounds straightforward — just load the model and go, right?
Not quite. Before we can talk about solutions, we need to understand the problem. In this part, we'll do the actual math to see exactly why this is hard.
By the end, you'll understand exactly why we need parallelism — and which specific bottlenecks each parallelism strategy addresses.
The A100 is NVIDIA's data center GPU for AI workloads. Let's understand each spec and why it matters:
This is the GPU's main memory — High Bandwidth Memory (HBM). Everything lives here:
80 GB is our hard constraint. If our data doesn't fit, we can't run.
This is how fast we can move data between HBM and the compute units (tensor cores). Let's put this in perspective:
Memory bandwidth: 2.0 TB/s = 2,000 GB/s
To read 80 GB (entire memory):
Time = 80 GB ÷ 2,000 GB/s = 0.04 seconds = 40 ms
To read 1 GB:
Time = 1 GB ÷ 2,000 GB/s = 0.5 ms
This sounds fast, but for large models, we're constantly streaming weights from memory to compute. Memory bandwidth often becomes the bottleneck — not compute.
TFLOPS = Trillion Floating Point Operations Per Second. The A100 can do 312 trillion BF16 operations per second.
312 TFLOPS = 312 × 10¹² operations/second
For a matrix multiply of [M, K] × [K, N]:
Operations = 2 × M × K × N (multiply-add)
Example: [8192, 8192] × [8192, 8192]
Operations = 2 × 8192 × 8192 × 8192 = 1.1 × 10¹² = 1.1 TFLOP
Time (compute-bound) = 1.1 TFLOP ÷ 312 TFLOPS = 3.5 ms
We also need to read the input matrices from memory:
Data to read: 2 matrices × 8192 × 8192 × 2 bytes = 268 MB
Time (memory-bound) = 268 MB ÷ 2,000 GB/s = 0.134 ms
Compute time: 3.5 ms
Memory time: 0.134 ms ← 26× faster
So a big square matmul is compute-bound. But during autoregressive decoding, the same weight matrix multiplies a single token vector ([1, 8192] × [8192, 8192]): the compute takes ~0.4 μs, while reading the 134 MB weight matrix takes ~67 μs. That step is memory-bound by roughly 150×, which is why token-by-token generation is limited by memory bandwidth, not compute.
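These two regimes can be checked with a few lines of arithmetic. A minimal sketch, assuming the A100's peak numbers and counting only reads of the two input matrices, as above:

```python
PEAK_FLOPS = 312e12   # A100 BF16 tensor-core peak, ops/s
PEAK_BW = 2.0e12      # A100 HBM bandwidth, bytes/s

def matmul_times_s(M, K, N, bytes_per_el=2):
    """(compute_time, memory_time) in seconds for [M,K] @ [K,N], input reads only."""
    compute = 2 * M * K * N / PEAK_FLOPS
    memory = (M * K + K * N) * bytes_per_el / PEAK_BW
    return compute, memory

# Large square matmul (prefill-style): compute dominates.
tc_sq, tm_sq = matmul_times_s(8192, 8192, 8192)

# Batch-1 decode step (one token vector times the weight matrix): memory dominates.
tc_vec, tm_vec = matmul_times_s(1, 8192, 8192)
```

The crossover between the two regimes is the GPU's arithmetic intensity ridge: 312 TFLOPS ÷ 2 TB/s ≈ 156 FLOPs per byte moved.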
When we distribute work across multiple GPUs, they need to communicate. There are two main interconnects:
PCIe Gen4 x16:
- Bandwidth: ~32 GB/s (bidirectional)
- Latency: ~1-2 μs
- Connects: GPU ↔ CPU, GPU ↔ GPU (different nodes)
- Available on: All GPUs
NVLink (A100):
- Bandwidth: 600 GB/s (bidirectional, 12 links × 50 GB/s)
- Latency: ~0.5 μs
- Connects: GPU ↔ GPU (same node only)
- Available on: Data center GPUs (A100, H100)
Let's see what this means for transferring 1 GB of data:
Transfer 1 GB over PCIe:
Time = 1 GB ÷ 32 GB/s = 31.25 ms
Transfer 1 GB over NVLink:
Time = 1 GB ÷ 600 GB/s = 1.67 ms
NVLink is ~19× faster than PCIe!
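A tiny helper reproduces these transfer times. It uses the peak-bandwidth figures above and ignores protocol overhead:

```python
def transfer_ms(gigabytes: float, bandwidth_gb_s: float) -> float:
    """Milliseconds to move `gigabytes` at `bandwidth_gb_s` GB/s (peak, no overhead)."""
    return gigabytes / bandwidth_gb_s * 1000

pcie_ms = transfer_ms(1, 32)      # ~31.25 ms
nvlink_ms = transfer_ms(1, 600)   # ~1.67 ms
```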
This is why multi-GPU training and inference strongly prefer NVLink-connected GPUs within a single node. Cross-node communication (which must use PCIe/InfiniBand) is much slower.
Understanding where data lives and how fast we can access it:
MEMORY HIERARCHY (A100):
────────────────────────
Level Size Bandwidth Latency
─────────────────────────────────────────────────────
Registers ~20 MB ~20 TB/s ~1 cycle
L2 Cache 40 MB ~5 TB/s ~30 cycles
HBM (main) 80 GB 2 TB/s ~300 cycles
NVLink N/A 600 GB/s ~500 cycles
PCIe N/A 32 GB/s ~1000 cycles
CPU RAM ~1 TB ~100 GB/s ~10000 cycles
The key takeaway: HBM is our working memory, and 80 GB is all we have. If data doesn't fit in HBM, we have to go to much slower storage.
Now let's understand what we're trying to fit into that 80 GB. Llama-3-70B is a decoder-only transformer with these parameters:
| Parameter | Value | What it means |
|---|---|---|
| d_model | 8192 | Each token is represented as a vector of 8192 numbers |
| n_heads | 64 | Query attention heads |
| n_kv_heads | 8 | Key/Value heads (GQA — 8× fewer than query heads) |
| d_head | 128 | Each head works with 128-dimensional vectors |
| n_layers | 80 | 80 transformer blocks stacked sequentially |
| d_ff | 28672 | FFN hidden dimension (3.5× d_model) |
| vocab_size | 128,256 | Number of unique tokens the model knows |
Llama-3 uses GQA, not standard multi-head attention. Instead of 64 separate K and V projections (one per head), it uses only 8 — each shared by 8 query heads. This reduces parameters and KV cache size:
Standard MHA: 64 query heads, 64 key heads, 64 value heads
Llama-3 GQA: 64 query heads, 8 key heads, 8 value heads
Each group of 8 query heads shares 1 key head and 1 value head.
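The grouping can be illustrated at the shape level. A NumPy sketch with a toy sequence length (head counts from the table above):

```python
import numpy as np

n_heads, n_kv_heads, d_head, seq = 64, 8, 128, 4   # tiny seq for illustration
group_size = n_heads // n_kv_heads                  # 8 query heads per KV head

K = np.zeros((seq, n_kv_heads, d_head))
# Each of the 8 KV heads is repeated for its group of 8 query heads
# before the attention matmul:
K_expanded = np.repeat(K, group_size, axis=1)       # [seq, 64, 128]
```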
Each of the 80 layers contains these weight matrices:
ATTENTION BLOCK (with GQA):
───────────────────────────
Wq (Query projection): [8192, 8192] = 67M params (64 heads × 128)
Wk (Key projection): [8192, 1024] = 8.4M params (8 heads × 128)
Wv (Value projection): [8192, 1024] = 8.4M params (8 heads × 128)
Wo (Output projection): [8192, 8192] = 67M params
Attention total: 67 + 8.4 + 8.4 + 67 = 150.8M params per layer
FFN BLOCK (SwiGLU — 3 matrices, not 2):
───────────────────────────────────────
W_gate: [8192, 28672] = 235M params (gate projection)
W_up: [8192, 28672] = 235M params (up projection)
W_down: [28672, 8192] = 235M params (down projection)
FFN total: 235 × 3 = 705M params per layer
LAYER NORMS: ~16K params (negligible)
TOTAL PER LAYER: 150.8M + 705M ≈ 856M params
Now let's add it all up:
TRANSFORMER LAYERS:
───────────────────
856M params × 80 layers = 68.5B params
EMBEDDINGS:
───────────
Token embeddings: [128256, 8192] = 1.05B params
Output projection: [8192, 128256] = 1.05B params
(These are sometimes tied, but Llama-3 keeps them separate)
TOTAL: 68.5B + 1.05B + 1.05B ≈ 70.6B parameters ✓
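The tally above can be reproduced in a few lines (config values from the table earlier; layer norms omitted as negligible):

```python
# Llama-3-70B configuration
d_model, d_ff, n_layers = 8192, 28672, 80
n_kv_heads, d_head = 8, 128
vocab = 128_256

# Per-layer weights: Wq, Wo are [8192, 8192]; Wk, Wv are [8192, 1024] under GQA
attn = 2 * d_model * d_model + 2 * d_model * (n_kv_heads * d_head)
ffn = 3 * d_model * d_ff                    # SwiGLU: gate, up, down
per_layer = attn + ffn                      # ~856M

# All layers plus untied input/output embeddings
total = per_layer * n_layers + 2 * vocab * d_model
print(f"{total / 1e9:.1f}B")                # → 70.6B
```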
Now let's calculate the actual memory needed. We'll use BF16 (bfloat16), which is standard for inference:
BF16 = 16 bits = 2 bytes per parameter
Total parameters: 70 × 10⁹
Memory = 70 × 10⁹ params × 2 bytes/param
= 140 × 10⁹ bytes
= 140 GB
You might think: "Just use INT8 or INT4 quantization!" Let's see:
INT8 (8-bit): 70B × 1 byte = 70 GB ← Fits! But quality degrades
INT4 (4-bit): 70B × 0.5 bytes = 35 GB ← Fits easily! More quality loss
But wait — we still need memory for:
- Activations (Q, K, V tensors)
- KV cache
- Intermediate buffers
Even with INT4 weights, we'll run out of memory for long sequences.
Quantization helps, but it doesn't solve the fundamental problem for long-context inference. We'll still need parallelism.
Now let's see what happens when we actually run inference. We have a 1 million token input — think of a massive document, an entire codebase, or a long conversation history. Let's trace through the memory requirements step by step.
Our input starts as a sequence of token IDs, which get embedded into vectors:
Input tokens: [1, 1,000,000] (batch_size=1, seq_len=1M)
After embedding lookup:
X: [batch_size, seq_len, d_model]
X: [1, 1,000,000, 8192]
Memory for X:
Elements: 1 × 1,000,000 × 8192 = 8.192 × 10⁹
Bytes (BF16): 8.192 × 10⁹ × 2 = 16.38 GB
Just the embedded input takes 16 GB. But this is just the beginning.
In each transformer layer, we project the input into Query, Key, and Value tensors. Remember, with GQA, K and V use only 8 heads (not 64):
Q = X @ Wq → [1, 1M, 8192] @ [8192, 8192] = [1, 1M, 8192] (64 query heads × 128)
K = X @ Wk → [1, 1M, 8192] @ [8192, 1024] = [1, 1M, 1024] (8 KV heads × 128, GQA)
V = X @ Wv → [1, 1M, 8192] @ [8192, 1024] = [1, 1M, 1024] (8 KV heads × 128, GQA)
Q is large (64 heads), but K and V are 8× smaller thanks to GQA. Let's calculate the memory:
Q: 1 × 1,000,000 × 8192 × 2 bytes = 16.38 GB (64 heads)
K: 1 × 1,000,000 × 1024 × 2 bytes = 2.05 GB (8 KV heads)
V: 1 × 1,000,000 × 1024 × 2 bytes = 2.05 GB (8 KV heads)
─────────────────────────────────────────────
Total Q+K+V: 20.48 GB (for ONE layer)
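These numbers fall out of the shapes directly. A quick check, assuming BF16 (2 bytes per element):

```python
seq, d_model = 1_000_000, 8192
n_kv_heads, d_head = 8, 128
bytes_bf16 = 2

q_gb = seq * d_model * bytes_bf16 / 1e9                 # 64 query heads: 16.38 GB
kv_gb = seq * n_kv_heads * d_head * bytes_bf16 / 1e9    # 8 KV heads: 2.05 GB each
total_gb = q_gb + 2 * kv_gb                             # ~20.5 GB for one layer
```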
You might wonder: if each layer needs 20.5 GB for Q, K, V, don't we need 20.5 × 80 = 1.6 TB total?
No — and here's why. During inference, we process one layer at a time:
Layer 1: Compute Q, K, V → Compute attention → Get output → FREE Q, K, V
Layer 2: Compute Q, K, V → Compute attention → Get output → FREE Q, K, V
...
Layer 80: Compute Q, K, V → Compute attention → Get output → FREE Q, K, V
We reuse the same memory buffer for each layer's activations. So 20.5 GB is the peak activation memory, not cumulative. But 20.5 GB is still a significant chunk of our 80 GB budget — especially when combined with weights and KV cache.
During autoregressive generation (generating tokens one at a time), we don't want to recompute K and V for all previous tokens. Instead, we cache them:
KV Cache structure per layer (with GQA — 8 KV heads, not 64):
K cache: [batch, seq_len, n_kv_heads, d_head] = [1, 1M, 8, 128]
V cache: [batch, seq_len, n_kv_heads, d_head] = [1, 1M, 8, 128]
Memory per layer:
K: 1 × 1,000,000 × 8 × 128 × 2 bytes = 2.05 GB
V: 1 × 1,000,000 × 8 × 128 × 2 bytes = 2.05 GB
Total: 4.1 GB per layer
All 80 layers:
4.1 GB × 80 = 328 GB
This is why long-context inference is so challenging. The KV cache grows linearly with sequence length, and we need to keep all of it in memory.
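The linear growth is easy to see as a function. A sketch with this model's dimensions as defaults:

```python
def kv_cache_gb(seq_len: int, n_layers: int = 80, n_kv_heads: int = 8,
                d_head: int = 128, bytes_per_el: int = 2) -> float:
    """Total K and V cache across all layers, in GB; grows linearly with seq_len."""
    per_layer = 2 * seq_len * n_kv_heads * d_head * bytes_per_el  # K and V
    return per_layer * n_layers / 1e9

kv_cache_gb(1_000_000)   # ~328 GB
```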
We've saved the worst for last. The attention mechanism computes:
Attention(Q, K, V) = softmax(Q @ Kᵀ / √d_head) @ V
The critical operation is Q @ Kᵀ — this produces the attention scores matrix, where every token attends to every other token.
Let's trace through the shapes carefully. First, we reshape Q and K for multi-head attention:
Original shapes:
Q: [batch, seq_len, d_model] = [1, 1M, 8192]
K: [batch, seq_len, d_model] = [1, 1M, 8192]
Reshape for multi-head (split d_model into n_heads × d_head):
Q: [batch, seq_len, n_heads, d_head] = [1, 1M, 64, 128]
K: [batch, seq_len, n_kv_heads, d_head] = [1, 1M, 8, 128] (GQA)
For attention, each KV head is broadcast to 8 query heads:
K_expanded: [1, 1M, 64, 128] (each KV head repeated 8×)
Transpose for batched matmul:
Q: [batch, n_heads, seq_len, d_head] = [1, 64, 1M, 128]
K: [batch, n_heads, seq_len, d_head] = [1, 64, 1M, 128] (after broadcast)
Kᵀ: [batch, n_heads, d_head, seq_len] = [1, 64, 128, 1M]
Now the attention scores computation:
Attention Scores = Q @ Kᵀ
[1, 64, 1M, 128] @ [1, 64, 128, 1M]
= [1, 64, 1M, 1M]
This is a [seq_len × seq_len] matrix for each head!
Let's compute the memory for this attention scores matrix:
Shape: [1, 64, 1,000,000, 1,000,000]
Elements: 1 × 64 × 1,000,000 × 1,000,000
= 64 × 10¹²
= 64 trillion elements
Memory (BF16): 64 × 10¹² × 2 bytes
= 128 × 10¹² bytes
= 128 TB
Let's see how attention memory scales with sequence length:
Sequence Length Attention Memory (64 heads, BF16)
─────────────── ─────────────────────────────────
1,000 128 MB
10,000 12.8 GB
100,000 1.28 TB
1,000,000 128 TB
10,000,000 12.8 PB (petabytes!)
Every 10× increase in sequence length causes a 100× increase in attention memory. This is why long-context models are so challenging.
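The table above can be reproduced with a one-line formula (64 heads, BF16, full matrix materialized):

```python
def attn_scores_bytes(seq_len: int, n_heads: int = 64, bytes_per_el: int = 2) -> int:
    """Size of the full [n_heads, seq_len, seq_len] score matrix, if materialized."""
    return n_heads * seq_len ** 2 * bytes_per_el

# Quadratic scaling: 10x more tokens -> 100x more attention memory
for n in (1_000, 10_000, 100_000, 1_000_000):
    print(f"{n:>9,} tokens: {attn_scores_bytes(n) / 1e12:g} TB")
```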
FlashAttention is a clever algorithm that computes attention without ever materializing the full attention matrix. The key insight: we don't need to store all attention scores — we can compute them in tiles and immediately use them.
Let's compare standard attention vs FlashAttention:
Standard Attention (naive):
1. Compute full S = Q @ Kᵀ [1, 64, 1M, 1M] → 128 TB ← Store this!
2. Apply softmax to S [1, 64, 1M, 1M] → 128 TB ← Store this!
3. Compute Output = S @ V [1, 64, 1M, 128] → 16 GB
Peak memory: 128 TB (the attention matrix)
FlashAttention processes Q in blocks (tiles), and for each Q block, iterates through all K,V blocks:
FlashAttention (tiled):
For each Q_block (e.g., rows 0-1023 of Q):
Initialize: output_block = 0, running_max = -∞, running_sum = 0
For each K_block, V_block:
1. Compute tile: S_tile = Q_block @ K_blockᵀ [1024, 1024] → 2 MB (BF16, per head)
2. Compute local softmax statistics
3. Update running softmax (online algorithm)
4. Accumulate: output_block += exp(S_tile - running_max) @ V_block (unnormalized; old accumulator rescaled as the running max changes)
5. DISCARD S_tile immediately!
Store output_block / running_sum (normalized result for these Q rows)
Peak memory: O(tile_size²) ≈ 2 MB per tile per head
But wait — there's a problem. Softmax requires knowing the maximum value across the entire row:
softmax(x_i) = exp(x_i - max(x)) / Σ exp(x_j - max(x))
For row i of the attention matrix:
- We need max across ALL 1M columns
- We need sum of exp() across ALL 1M columns
But we're only looking at 1024 columns at a time!
FlashAttention solves this with the "online softmax" trick — maintaining running statistics that get corrected as we see more tiles:
Online Softmax Algorithm:
─────────────────────────
After processing K_block 1:
m₁ = max of tile 1
l₁ = sum of exp(scores - m₁) for tile 1
o₁ = exp(tile 1 - m₁) @ V_block_1 (unnormalized)
After processing K_block 2:
m₂ = max(m₁, max of tile 2) ← Update global max
Correction factor for old results:
α = exp(m₁ - m₂) ← Scale down old values
l₂ = α × l₁ + sum of exp(tile 2 - m₂) ← Update running sum
o₂ = α × o₁ + exp(tile 2 - m₂) @ V_block_2 ← Rescale old output, accumulate new
After ALL K_blocks:
Final output = o_final / l_final ← Normalize
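The update rules above can be checked end-to-end in NumPy. This is a shape-level sketch with toy sizes, not the fused GPU kernel, and it verifies that tiled online softmax matches naive attention exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d, tile = 64, 16, 16
Q = rng.standard_normal((seq, d))
K = rng.standard_normal((seq, d))
V = rng.standard_normal((seq, d))
scale = 1.0 / np.sqrt(d)

# Reference: naive attention, materializing the full [seq, seq] score matrix.
S = (Q @ K.T) * scale
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V

# Tiled attention: only one [tile, tile] score block exists at a time.
out = np.empty_like(Q)
for qs in range(0, seq, tile):
    q_blk = Q[qs:qs + tile]
    o = np.zeros((tile, d))           # unnormalized output accumulator
    m = np.full(tile, -np.inf)        # running row maxima
    l = np.zeros(tile)                # running sums of exp
    for ks in range(0, seq, tile):
        s = (q_blk @ K[ks:ks + tile].T) * scale   # score tile, discarded after use
        m_new = np.maximum(m, s.max(axis=1))
        alpha = np.exp(m - m_new)                 # rescales the stale statistics
        p = np.exp(s - m_new[:, None])
        l = alpha * l + p.sum(axis=1)
        o = alpha[:, None] * o + p @ V[ks:ks + tile]
        m = m_new
    out[qs:qs + tile] = o / l[:, None]            # normalize once at the end

assert np.allclose(out, ref)
```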
Let's quantify the memory reduction:
Standard Attention:
Attention matrix: [1, 64, 1M, 1M] × 2 bytes = 128 TB
FlashAttention (tile_size = 1024):
One tile: [1, 64, 1024, 1024] × 2 bytes = 128 MB
Running statistics per Q_block: ~few KB
Memory reduction: 128 TB → 128 MB = 1,000,000× less!
But we still need:
Q: 16.38 GB (64 query heads)
K: 2.05 GB (8 KV heads, GQA)
V: 2.05 GB (8 KV heads, GQA)
Output: 16.38 GB
─────────────────
Total: ~37 GB (fits in 80 GB for one layer!)
However, FlashAttention doesn't solve everything. We still need to store the Q, K, V tensors (20.5 GB per layer), and for inference the KV cache still grows linearly with sequence length: for 1M tokens that's 328 GB across all layers, even after GQA's 8× reduction.
Unlike the KV cache, attention scores are computed and immediately discarded:
Layer 1: Compute attention scores → Use them → DISCARD
Layer 2: Compute attention scores → Use them → DISCARD
...
Layer 80: Compute attention scores → Use them → DISCARD
So 128 TB (or with FlashAttention, much less) is the peak memory for one layer's attention computation, not cumulative across all 80 layers. This is a crucial distinction from the KV cache, which must persist.
Let's put it all together. Here's what we're trying to fit into an 80 GB GPU:
| Component | Size | vs A100 (80 GB) | Persists? |
|---|---|---|---|
| Model Weights (BF16) | 140 GB | 1.75× capacity ❌ | Yes — always in memory |
| Q + K + V (1 layer) | 20.5 GB | 26% of capacity ⚠️ | No — recomputed per layer |
| Attention Scores (1 layer, naive) | 128 TB | 1,600× capacity ❌ | No — discarded after use |
| KV Cache (all 80 layers, GQA) | 328 GB | 4× capacity ❌ | Yes — grows with sequence |
Even with FlashAttention eliminating the 128 TB attention matrix, we still have:
Minimum memory needed:
Weights: 140 GB
Activations: 20.5 GB (peak, one layer at a time)
KV Cache: 328 GB (with GQA)
─────────────────────
Total: ~489 GB
Available on one A100: 80 GB
We need at least 7 GPUs just for memory capacity!
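A rough capacity check, using the numbers from the tally above (this is a floor on GPU count; real deployments need extra headroom for buffers and fragmentation):

```python
import math

weights_gb, activations_gb, kv_gb = 140, 20.5, 328
total_gb = weights_gb + activations_gb + kv_gb   # ~489 GB
min_gpus = math.ceil(total_gb / 80)              # A100s needed for capacity alone
```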
We've identified three distinct memory bottlenecks. Each requires a different parallelism strategy:
| Bottleneck | Size | Solution | How It Works |
|---|---|---|---|
| Weights (140 GB) | 1.75× GPU | Weight Sharding (FSDP) | Split weight matrices across GPUs. Each GPU stores 1/N of the weights and gathers them when needed. |
| Activations (20.5 GB/layer) | 26% GPU | Sequence Sharding (Ulysses) | Split the sequence across GPUs. Each GPU processes 1/N of the tokens, using all-to-all communication for attention. |
| Attention O(n²) | 1,600× GPU | Ring Attention | Compute attention in chunks, passing KV blocks around a ring of GPUs. Never materialize the full attention matrix. |
You might wonder: can't we just use one strategy? Let's see:
FSDP alone (8 GPUs):
Weights: 140 GB ÷ 8 = 17.5 GB per GPU ✓
Activations: Still 20.5 GB per GPU ❌
KV Cache: Still 328 GB total ❌
Sequence Parallelism alone (8 GPUs):
Weights: Still 140 GB per GPU ❌
Activations: 20.5 GB ÷ 8 = 2.6 GB per GPU ✓
KV Cache: 328 GB ÷ 8 = 41 GB per GPU ❌
Ring Attention alone (8 GPUs):
Weights: Still 140 GB per GPU ❌
Activations: Distributed ✓
KV Cache: Distributed ✓
But: Requires weights to fit on each GPU ❌
No single strategy solves all three problems. We need to combine them.
USP combines all three strategies:
USP with 8 GPUs:
FSDP: Weights 140 GB ÷ 8 = 17.5 GB per GPU ✓
Ulysses: Activations 20.5 GB ÷ 8 = 2.6 GB per GPU ✓
Ring Attention: KV Cache distributed across ring ✓
Per GPU: 17.5 + 2.6 + KV cache (distributed) — fits in 80 GB!
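The per-GPU budget under the combined strategies, using the same rough numbers:

```python
n_gpus = 8
weights_gb = 140 / n_gpus     # FSDP: each GPU holds 1/8 of the weights
acts_gb = 20.5 / n_gpus       # Ulysses: each GPU holds 1/8 of the sequence
kv_gb = 328 / n_gpus          # Ring Attention: KV cache distributed around the ring
per_gpu_gb = weights_gb + acts_gb + kv_gb   # ~61 GB, under the 80 GB budget
```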
Key numbers from this part:

| Quantity | Value |
|---|---|
| A100 HBM Memory | 80 GB |
| A100 Memory Bandwidth | 2.0 TB/s |
| A100 NVLink Bandwidth | 600 GB/s |
| Llama-70B Weights (BF16) | 140 GB |
| Q+K+V for 1M tokens (1 layer) | 20.5 GB |
| Attention Scores for 1M tokens (naive) | 128 TB |
| KV Cache for 1M tokens (80 layers, GQA) | 328 GB |

In the following parts, we'll dive deep into each strategy:

| Part | What it covers |
|---|---|
| Part 2: Weight Sharding (FSDP) | How to distribute 140 GB of weights across GPUs using AllGather and ReduceScatter |
| Part 3: Sequence Sharding (Ulysses) | Splitting the 1M token sequence with all-to-all communication |
| Part 4: Ring Attention | Distributed attention without materializing O(n²) |
| Part 5: Putting It Together | The complete USP picture |
The math doesn't lie. We need parallelism — and now we know exactly why.