Flash Attention 4 Explained: From Quadratic to 1,605 TFLOPs/s
Every transformer model you use today—GPT-4, Claude, Gemini, Llama 3—runs on one algorithm. That algorithm costs 4x more every time you double your context window. Here's how one hardware-aware trick fixes that.
Introduction
Transformer models power every large language model today. GPT-4, Claude, Gemini, Llama 3, DeepSeek—they all share the same core: the attention mechanism.
And that same mechanism that makes these models intelligent is also what makes them absurdly expensive to run.
Here's the problem in one equation:
Attention complexity = O(N²)
N = sequence length
If you double context (N → 2N):
Computation: N² → 4N² (4x more)
Memory: N² → 4N² (4x more)
Cost: $0.003 → $0.012 per request
Double your context window and you quadruple your costs. This one mathematical fact has driven billions of dollars of engineering work since 2022.
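Here's a quick sanity check of that 4x claim. It's a minimal sketch that only counts the entries of the N×N score matrix and the multiply-adds needed to fill it; the helper name `attention_cost` and the head dimension of 128 are assumptions for illustration.

```python
def attention_cost(n: int, d: int = 128) -> dict:
    """Rough cost of the N x N score matrix for one attention head."""
    return {
        "score_entries": n * n,      # memory grows with N^2
        "qk_flops": 2 * n * n * d,   # one multiply and one add per (i, j, k)
    }

base = attention_cost(4096)
doubled = attention_cost(8192)

ratio_entries = doubled["score_entries"] / base["score_entries"]
ratio_flops = doubled["qk_flops"] / base["qk_flops"]
print(ratio_entries, ratio_flops)  # 4.0 4.0
```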
But there's a trick.
Flash Attention (2022), Flash Attention 2 (2023), and Flash Attention 3 (2024) progressively chipped away at this bottleneck. Now, Flash Attention 4 (March 2026) changes the game completely.
Flash Attention 4 achieves 1,605 TFLOPs/s on NVIDIA B200 GPUs with 71% hardware utilization.
To put that in perspective: that's nearly 10x faster than standard attention, and it's exact, not approximate.
In this post, you'll learn:
- Why standard attention is expensive (the math + the hardware truth)
- How Flash Attention cheats the math (without changing the result)
- The hardware insight that makes it possible (memory hierarchy)
- FA3 vs. FA4 (what changed, and why it matters for quantization)
- When to use it (hint: always, except on CPU)
The Problem: Why Attention Is Expensive
The Quadratic Trap
Let's start with the math everyone knows but nobody really understands.
Standard attention computes a similarity matrix of size [sequence_length, sequence_length].
Attention(Q, K, V) = softmax(Q @ K^T / √d) @ V
Q: [N, d] (N = sequence length, d = embedding dim)
K: [N, d]
V: [N, d]
Q @ K^T produces an [N, N] matrix
For a 4,096-token context (4K):
- N² = 16 million elements
- Each element has to be computed, stored, and loaded again for the backward pass
For a 128K-token context:
- N² = 16 billion elements
- Each element is 4 bytes (float32) = 64 GB just for the matrix
Wait. 64 GB? That's most of an H100's HBM (80 GB total, shared with model weights), and it covers a single attention head in a single layer.
The attention matrices alone won't fit on the GPU.
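To reproduce those sizes yourself, here's a small sketch. It assumes float32 scores, a single head, and 128K meaning 131,072 tokens; `score_matrix_bytes` is just an illustrative helper name.

```python
def score_matrix_bytes(n_tokens: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to hold one N x N attention score matrix."""
    return n_tokens * n_tokens * bytes_per_element

for n in (4096, 131072):
    gib = score_matrix_bytes(n) / 2**30
    print(f"N={n}: {gib:.4f} GiB")
```

The 4K case comes out at 64 MiB and the 128K case at 64 GiB, per head and per layer.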
Real-World Cost Impact
Let's make this concrete. Token pricing for Claude 3.5 Sonnet:
- Input: $3 per 1M tokens
- Output: $15 per 1M tokens
For a 128K context request with 1K output:
Input cost: 128,000 tokens × ($3 / 1,000,000) = $0.384
Output cost: 1,000 tokens × ($15 / 1,000,000) = $0.015
Total: $0.399 per request
At 1,000 requests/day = $399/day = $146K/year
For comparison, a 4K context:
Input cost: 4,000 tokens × ($3 / 1,000,000) = $0.012
Output cost: 1,000 tokens × ($15 / 1,000,000) = $0.015
Total: $0.027 per request
At 1,000 requests/day = $27/day = $10K/year
Companies pay about 14.8x more for 32x longer context.
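The arithmetic above fits in a few lines. This sketch hard-codes the two prices quoted in this section; the function name `request_cost` is an assumption for illustration.

```python
INPUT_PRICE_PER_M = 3.00    # USD per 1M input tokens (price quoted above)
OUTPUT_PRICE_PER_M = 15.00  # USD per 1M output tokens (price quoted above)

def request_cost(input_tokens: int, output_tokens: int) -> float:
    """Cost in USD of a single request."""
    total = input_tokens * INPUT_PRICE_PER_M + output_tokens * OUTPUT_PRICE_PER_M
    return total / 1_000_000

long_ctx = request_cost(128_000, 1_000)
short_ctx = request_cost(4_000, 1_000)
ratio = long_ctx / short_ctx
print(f"${long_ctx:.3f} vs ${short_ctx:.3f} per request ({ratio:.1f}x)")
```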
Is that justified? No. The computation isn't proportionally harder. The problem is hardware inefficiency.
The GPU Bottleneck: Why Attention Isn't Compute-Bound
Here's where most explanations get it wrong.
They say: "Attention is expensive because it does N² computations."
That's not the real problem. Modern GPUs can do hundreds of trillions of operations per second. The problem is memory bandwidth.
GPU Memory Hierarchy
An NVIDIA H100 has three levels of memory:
| Memory | Size | Bandwidth | Latency | Purpose |
|---|---|---|---|---|
| SRAM (L2 cache) | 50 MB | 15 TB/s | < 10ns | Fast computation |
| HBM (main GPU memory) | 80 GB | 2 TB/s | 100-300ns | Store activations |
| CPU RAM | 400+ GB | 0.1 TB/s | 100µs+ | System memory |
Standard attention writes the full N×N matrix to HBM, then loads it back. That's a round trip over slow memory.
The Compute vs. Memory Bind
Here's the math that nobody talks about:
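H100 compute throughput: 989 TFLOPS (dense FP16/BF16 tensor cores, not FP32)
H100 memory bandwidth: 2,000 GB/s
One float32 = 4 bytes
Machine balance: 989 TFLOPS / 2,000 GB/s
= 989 × 10¹² FLOPs/s / (2 × 10¹² bytes/s)
≈ 495 FLOPs per byte
To keep the GPU busy, a kernel has to do roughly 495 FLOPs for every byte it moves through HBM.
Softmax over the N×N matrix does only a few FLOPs per element, and every FP32 element costs 8 bytes of traffic (one read, one write).
Actual operational intensity for those steps: under 1 FLOP per byte
That's hundreds of times below what the GPU needs to stay compute-bound.
Translation: during those steps you spend almost all of your time moving data, not computing.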
This is called being "memory-bound" instead of "compute-bound."
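One way to check the memory-bound claim is a roofline-style comparison: divide peak compute by memory bandwidth to get the FLOPs a GPU can execute per byte it moves, then compare a kernel's FLOPs per byte against that. The sketch below uses the H100 figures quoted in this section. The softmax cost of about 5 FLOPs and 8 bytes of traffic per element is an assumption, not a measurement.

```python
PEAK_FLOPS = 989e12    # FLOP/s, the H100 figure quoted above
HBM_BANDWIDTH = 2e12   # bytes/s, the H100 figure quoted above

def machine_balance(peak_flops: float, bandwidth: float) -> float:
    """FLOPs the GPU can execute in the time it takes to move one byte."""
    return peak_flops / bandwidth

def is_memory_bound(flops: float, bytes_moved: float) -> bool:
    """Memory-bound means FLOPs per byte fall below the machine balance."""
    return flops / bytes_moved < machine_balance(PEAK_FLOPS, HBM_BANDWIDTH)

n = 4096
# Assumption: softmax costs about 5 FLOPs per element and reads, then
# writes, each FP32 element once (8 bytes of HBM traffic per element).
softmax_flops = 5 * n * n
softmax_bytes = 8 * n * n

balance = machine_balance(PEAK_FLOPS, HBM_BANDWIDTH)
intensity = softmax_flops / softmax_bytes
print(f"machine balance:   {balance:.1f} FLOPs/byte")
print(f"softmax intensity: {intensity:.3f} FLOPs/byte")
print("memory-bound:", is_memory_bound(softmax_flops, softmax_bytes))
```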
Concrete Example: Why 4K Context Takes 200ms
Let me trace through a single forward pass of attention on an H100:
1. Load Q from HBM: 4K × 128 dims × 4 bytes = 2 MB → 1µs
2. Load K from HBM: 4K × 128 dims × 4 bytes = 2 MB → 1µs
3. Compute Q @ K^T: 2 × (4K)² × 128 ≈ 4.3 GFLOPs → ~4µs at peak throughput
4. Store attention matrix: 16M × 4 bytes = 64 MB to HBM → 32µs
5. Load attention matrix: 64 MB from HBM → 32µs
6. Load V from HBM: 2 MB → 1µs
7. Compute attention @ V: ≈ 4.3 GFLOPs → ~4µs at peak throughput
8. Store output: 2 MB → 1µs
Total memory time: ~68µs
Total compute time: ~9µs (real kernels run below peak)
Total: ~77µs per head per layer if nothing overlaps
With 32 layers: ~2.5ms per head just for attention
With generation (1 token at a time): ~200ms total per token
But notice: only ~9µs of that ~77µs is actual compute. The rest is memory shuffling.
This is the bottleneck Flash Attention solves.
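If you want to redo this kind of trace with your own numbers, here's a rough calculator. It assumes one attention head, FP32 data, perfect use of the quoted peak compute and bandwidth, and no overlap between transfers and math, so treat the output as a lower bound on time rather than a measurement. It uses exact byte counts, so the results differ slightly from hand-rounded figures.

```python
PEAK_FLOPS = 989e12    # FLOP/s (quoted H100 peak; real kernels achieve less)
HBM_BANDWIDTH = 2e12   # bytes/s (quoted H100 figure)

def transfer_us(num_bytes: float) -> float:
    """Microseconds to move num_bytes through HBM at full bandwidth."""
    return num_bytes / HBM_BANDWIDTH * 1e6

def matmul_us(m: int, n: int, k: int) -> float:
    """Microseconds for an [m, k] @ [k, n] matmul at peak throughput."""
    return 2 * m * n * k / PEAK_FLOPS * 1e6

N, d, B = 4096, 128, 4           # tokens, head dim, bytes per FP32 element
qkv_bytes = N * d * B            # one of Q, K, V, or the output
score_bytes = N * N * B          # the full N x N matrix

memory_us = (
    3 * transfer_us(qkv_bytes)       # load Q, K, V
    + 2 * transfer_us(score_bytes)   # store, then reload, the N x N matrix
    + transfer_us(qkv_bytes)         # store the output
)
compute_us = matmul_us(N, N, d) + matmul_us(N, d, N)   # Q @ K^T and P @ V

print(f"memory: {memory_us:.1f} us, compute: {compute_us:.1f} us")
```

Under these assumptions the data movement takes several times longer than the math, which is the whole point of this section.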
The Solution: Flash Attention Cheats the Math
Here's the key insight that unlocks everything:
You don't need the full [N, N] attention matrix in memory. It's an intermediate computation. You only care about the output.
Flash Attention rebuilds that matrix piece-by-piece, keeping tiles in fast SRAM memory instead of slow HBM.
High-Level Intuition
Standard attention:
1. Compute Q @ K^T (gives you full N×N matrix) [SLOW: HBM]
2. Apply softmax to normalize [SLOW: HBM]
3. Multiply by V to get output [SLOW: HBM]
Flash Attention:
1. Process Q in blocks (M at a time, where M fits in SRAM)
2. For each Q block, compute attention to all K (in small chunks)
3. Use online softmax to accumulate correct results without storing full matrix
4. Proceed to next Q block
Total: O(N·M) memory instead of O(N²)
Why SRAM Matters
Modern GPUs have:
- SRAM: on H100, up to 228 KB of shared memory per SM (about 30 MB across 132 SMs), roughly 15 TB/s aggregate bandwidth
- HBM: 80-192 GB, 2-3 TB/s bandwidth
SRAM is 7-10x faster.
Flash Attention keeps one block of Q and small chunks of K/V in SRAM (fits in ~64 KB easily), does the computation there, and never materializes the full attention matrix.
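To get a feel for what "fits in SRAM" means, here's a back-of-the-envelope footprint calculator. The block sizes (128 query rows, 64 key rows), FP16 inputs, and FP32 accumulators are assumptions chosen for illustration; real kernels pick tile shapes per GPU and keep some of these buffers in registers.

```python
def tile_footprint_kib(block_q: int, block_k: int, d: int,
                       io_bytes: int = 2, acc_bytes: int = 4) -> dict:
    """On-chip KiB for one Q tile, one K/V tile pair, scores, and the output accumulator."""
    parts = {
        "Q tile": block_q * d * io_bytes,
        "K tile": block_k * d * io_bytes,
        "V tile": block_k * d * io_bytes,
        "score tile": block_q * block_k * acc_bytes,
        "output accumulator": block_q * d * acc_bytes,
    }
    parts["total"] = sum(parts.values())
    return {name: size / 1024 for name, size in parts.items()}

footprint = tile_footprint_kib(block_q=128, block_k=64, d=128)
for name, kib in footprint.items():
    print(f"{name:>20}: {kib:6.1f} KiB")
```

Notice that nothing in the footprint depends on the sequence length N. That is why the same tiles work for 4K and 128K contexts.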
The Math: Online Softmax
Here's the clever trick that makes it work.
Standard softmax requires seeing all inputs:
softmax(x)[i] = exp(x[i]) / Σ(exp(x[j]) for all j)
You need the denominator, which depends on all x values.
But Flash Attention uses a numerical trick called "online softmax" or "streaming softmax" (Milakov and Gimelshein, 2018):
import math

scores = [1.0, 2.0, 3.0, 4.0]

# Standard softmax (needs all values first)
max_score = max(scores)
sum_exp = sum(math.exp(s - max_score) for s in scores)
result = [math.exp(s - max_score) / sum_exp for s in scores]

# Online softmax (streaming, block-by-block)
blocks = [scores[:2], scores[2:]]
max_score = -math.inf
sum_exp = 0.0
for block in blocks:
    old_max = max_score
    max_score = max(max_score, max(block))
    # Rescale the running sum to the new max (numerically stable)
    sum_exp = sum_exp * math.exp(old_max - max_score)
    # Add this block's exponentials
    sum_exp += sum(math.exp(s - max_score) for s in block)

# Normalize only once the final max and sum are known (exact, not approximate!)
# Flash Attention avoids this second pass by rescaling its output accumulator instead.
online = [math.exp(s - max_score) / sum_exp for block in blocks for s in block]
This produces the exact same output as standard softmax, but incrementally.
This is not an approximation. It's mathematically exact. The only difference is floating-point rounding, which happens in both approaches.
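You don't have to take the exactness claim on faith. Here's a small NumPy check that streams over a vector in blocks, keeping only a running max and a running sum, and compares the result to ordinary softmax. The function names and the block size of 64 are assumptions for illustration.

```python
import numpy as np

def softmax_reference(x: np.ndarray) -> np.ndarray:
    """Ordinary softmax that sees the whole vector at once."""
    e = np.exp(x - x.max())
    return e / e.sum()

def softmax_online(x: np.ndarray, block_size: int) -> np.ndarray:
    """Stream over x in blocks, keeping only a running max and running sum."""
    running_max = -np.inf
    running_sum = 0.0
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        new_max = max(running_max, block.max())
        # Rescale the old sum to the new max, then add this block's terms.
        running_sum = (running_sum * np.exp(running_max - new_max)
                       + np.exp(block - new_max).sum())
        running_max = new_max
    # Normalize once the final max and sum are known.
    return np.exp(x - running_max) / running_sum

rng = np.random.default_rng(0)
x = rng.normal(size=1000) * 10
max_diff = np.abs(softmax_reference(x) - softmax_online(x, block_size=64)).max()
print(max_diff)
```

The printed difference should sit at the level of floating-point rounding, whatever block size you pick.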
Technical Deep Dive: How Flash Attention Actually Works
Let me walk through the algorithm step-by-step with concrete numbers.
Setup
Sequence length N = 4096
Embedding dimension d = 128
Block size M = 256 (fits in SRAM easily)
Q, K, V: [4096, 128] each
Total blocks to process: 4096 / 256 = 16 Q-blocks
Algorithm: Tiling
# Flash Attention algorithm as runnable NumPy (a sketch of the logic, not a GPU kernel)
import numpy as np

def flash_attention(Q, K, V, M=256):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)                      # the 1/√d factor from the formula
    output = np.zeros((N, d))

    # Process Q in blocks
    for i in range(0, N, M):
        Q_block = Q[i:i+M]                        # [M, d] - fits in SRAM (~64 KB in FP16)
        rows = Q_block.shape[0]

        # Online softmax state for this Q block, one entry per query row
        max_score = np.full(rows, -np.inf)
        sum_exp = np.zeros(rows)
        O_block = np.zeros((rows, d))             # Output accumulator for this block

        # Process K, V in blocks
        for j in range(0, N, M):
            K_block = K[j:j+M]                    # [M, d] - fits in SRAM
            V_block = V[j:j+M]                    # [M, d] - fits in SRAM

            # Compute local attention scores
            scores = (Q_block @ K_block.T) * scale    # [M, M] - fits in SRAM!

            # Online softmax update (running max per row)
            old_max = max_score
            max_score = np.maximum(old_max, scores.max(axis=1))

            # Rescale previous accumulators to the new max
            exp_weight = np.exp(old_max - max_score)
            O_block *= exp_weight[:, None]
            sum_exp *= exp_weight

            # Add new block contribution
            exp_scores = np.exp(scores - max_score[:, None])
            sum_exp += exp_scores.sum(axis=1)     # Sum per row

            # Accumulate output
            O_block += exp_scores @ V_block

        # Normalize by sum of exponentials
        output[i:i+M] = O_block / sum_exp[:, None]
    return output
Memory Access Pattern
This is the genius part:
Standard attention:
Load Q [N, d] from HBM → N*d*4 bytes
Load K [N, d] from HBM → N*d*4 bytes
Compute Q@K^T → 2*N²*d ops (memory limited)
Store [N, N] to HBM → N²*4 bytes (HUGE)
Load [N, N] from HBM → N²*4 bytes (HUGE)
Load V [N, d] from HBM → N*d*4 bytes
Total memory traffic: O(N²) + O(N*d)
= O(N²) dominates
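Flash Attention (best case, each input read once):
Load Q [N, d] once → N*d*4 bytes
Load K [N, d] once → N*d*4 bytes
Load V [N, d] once → N*d*4 bytes
Process in [M, M] blocks in SRAM
Total memory traffic: O(N*d)
= Linear in context length, not quadratic
In practice K and V blocks are re-read once per Q block, so the original paper bounds HBM accesses at Θ(N²d²/S) for an SRAM of S elements, against Θ(N*d + N²) for standard attention.
The data that has to live in HBM drops from O(N²) to O(N*d). For N=4K, d=128:
- Standard: 16.8M (scores) + 1.6M (Q, K, V) ≈ 18.4M elements ≈ 73 MB
- Flash: 1.6M elements ≈ 6 MB (roughly 12x smaller)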
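Here's the same comparison as a sketch you can rerun for other sequence lengths. It counts only the elements that must exist in HBM (the score matrix plus Q, K, V versus Q, K, V alone) and deliberately ignores K/V re-reads, so it's an optimistic picture. The function names are assumptions for illustration.

```python
def standard_elements(n: int, d: int) -> int:
    """Elements held in HBM when the N x N score matrix is materialized."""
    return n * n + 3 * n * d

def flash_elements(n: int, d: int) -> int:
    """Elements held in HBM when only Q, K, and V are stored."""
    return 3 * n * d

def reduction(n: int, d: int = 128) -> float:
    return standard_elements(n, d) / flash_elements(n, d)

for n in (4096, 32768, 131072):
    print(f"N={n}: {reduction(n):.1f}x fewer elements")
```

The gap grows linearly with N, which is why the savings matter most at long context.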
Backwards Pass
The backward pass (for training) is trickier because you need to recompute intermediate values on the fly instead of storing them. Flash Attention uses a clever recomputation strategy:
Forward pass: [input → attention → output]
Backward pass:
1. Recompute attention tiles from Q, K, and the saved softmax statistics
2. Use recomputed attention to compute gradients
3. No need to store full attention matrix
This increases compute in backward pass (~1.5x) but saves massive memory (the N² matrix).
For inference, this doesn't matter—backward pass is skipped.
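To make the recomputation idea concrete, here's a NumPy sketch. The forward pass keeps just one number per query row (here the log-sum-exp, which is an assumption of this sketch), and the backward pass rebuilds any tile of the attention matrix from Q, K, and that saved vector.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 64, 16
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))
S = Q @ K.T / np.sqrt(d)

# Forward pass: compute the softmax, but keep only one number per row.
row_max = S.max(axis=1)
logsumexp = row_max + np.log(np.exp(S - row_max[:, None]).sum(axis=1))  # shape [N]
P_forward = np.exp(S - logsumexp[:, None])   # full matrix, kept here only to compare

# Backward pass: rebuild one tile of P from Q, K, and the saved statistic.
rows, cols = slice(0, 16), slice(32, 48)
S_tile = Q[rows] @ K[cols].T / np.sqrt(d)
P_tile = np.exp(S_tile - logsumexp[rows, None])

print(np.allclose(P_tile, P_forward[rows, cols]))
```

The saved state is N numbers instead of N², and the price is one extra matmul per tile during the backward pass.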
FA3 vs FA4: What Changed?
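Flash Attention 3 (2024) mainly sped up Hopper GPUs by overlapping matrix multiplies with softmax, and it added FP8 support. The part that matters for quantization, in both FA3 and Flash Attention 4 (2026), is handling outliers better.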
The Outlier Problem
Modern transformers have a quirk: certain embedding dimensions have much larger values than others. When you quantize (convert to lower precision like FP8), these outliers break everything.
Example:
FP32 values: [-0.5, -0.2, 150.3, 0.1, -0.3]
↑ Outlier!
When quantizing to FP8 (E4M3 format: max magnitude 448, but only 3 mantissa bits) with one scale per tensor:
Naive: Scale to fit outlier → small values become 0 or noise
Better: Handle outlier specially
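You can watch the naive approach fail in a few lines. This sketch uses a symmetric uniform 8-bit quantizer with a single scale per tensor as a simplified stand-in: real FP8 spaces its values non-uniformly, so the numbers differ, but the scaling problem is the same. The helper name `fake_quantize` is an assumption for illustration.

```python
import numpy as np

def fake_quantize(x: np.ndarray, levels: int = 127) -> np.ndarray:
    """Symmetric uniform quantization with one scale for the whole tensor."""
    scale = np.abs(x).max() / levels
    return np.round(x / scale) * scale

values = np.array([-0.5, -0.2, 150.3, 0.1, -0.3])

with_outlier = fake_quantize(values)                    # scale set by 150.3
without_outlier = fake_quantize(np.delete(values, 2))   # outlier removed

print("with outlier:   ", with_outlier)
print("without outlier:", without_outlier)
```

With the outlier present, every small value rounds to zero. Without it, the same quantizer keeps them to within a few thousandths.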
Flash Attention 3: Incoherent Processing
FA3 introduced a Hadamard transform trick:
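# Sketch: incoherent processing (NumPy; d must be a power of 2)
import numpy as np
from scipy.linalg import hadamard

def reduce_outliers_fa3(activations, seed=0):
    d = activations.shape[-1]
    # Orthonormal Hadamard matrix, scaled so that H @ H.T is the identity
    H = hadamard(d) / np.sqrt(d)
    # +1 or -1 per dimension; reuse the same seed for Q and K
    random_signs = np.random.default_rng(seed).choice([-1.0, 1.0], size=d)
    # Rotate: spreads outliers across dimensions
    rotated = (activations * random_signs) @ H
    # The rotation is orthogonal, so rotating both Q and K leaves Q @ K.T unchanged
    # Now no single dimension has extreme values
    # Quantization becomes stable
    return rotated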
Result: With FA3, you can use FP8 (8-bit floating point) instead of FP16 (16-bit), cutting memory by 2x while maintaining accuracy.
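Here's a self-contained demo of why the rotation is safe. It plants an outlier channel in Q, rotates Q and K with the same random-sign Hadamard matrix, and checks two things: the largest value shrinks, and Q @ K^T is unchanged. The sizes, the seed, and the planted outlier of 150 are assumptions for illustration, and the Hadamard matrix is built with the Sylvester construction to avoid extra dependencies.

```python
import numpy as np

def hadamard(d: int) -> np.ndarray:
    """Orthonormal Hadamard matrix (Sylvester construction, d a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < d:
        H = np.kron(H, np.array([[1.0, 1.0], [1.0, -1.0]]))
    return H / np.sqrt(d)

rng = np.random.default_rng(0)
N, d = 32, 64
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))
Q[:, 5] += 150.0                          # plant an outlier channel

signs = rng.choice([-1.0, 1.0], size=d)   # random +1/-1 per dimension
M = signs[:, None] * hadamard(d)          # diag(signs) @ H, an orthogonal matrix

Q_rot, K_rot = Q @ M, K @ M
max_before = np.abs(Q).max()
max_after = np.abs(Q_rot).max()
scores_match = np.allclose(Q @ K.T, Q_rot @ K_rot.T)

print(f"max |Q| before: {max_before:.1f}, after: {max_after:.1f}")
print("Q @ K^T preserved:", scores_match)
```

Because M is orthogonal, the rotation cancels inside the dot product, so the attention scores are untouched while the values being quantized become much better behaved.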
Flash Attention 4: Improved Outlier Handling
FA4 refines this technique further:
# Pseudo-code: the helper names below are illustrative placeholders, not a real API
def reduce_outliers_fa4(activations):
    # Better Hadamard transform + adaptive scaling
    H = optimized_hadamard_matrix(d)
    random_signs = adaptive_random_vector(d)  # Learned or heuristic
    # Rotate with better numerical stability
    rotated = (H * random_signs) @ activations
    # Adaptive quantization levels per block
    q_scale = per_block_quantization_scale(rotated)
    return quantize(rotated, q_scale)  # Even better FP8 results
Result: FP8 becomes even more stable. On B200, Flash Attention 4 achieves:
- FP16: 1,200 TFLOPs/s
- FP8: 1,605 TFLOPs/s (with better numerical properties than older versions)
The FP8 version is faster than FP16, and these techniques keep its error well below naive FP8, though FP16 remains the more precise format.
Implementation: How to Use It
PyTorch (Easiest)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load any supported model with Flash Attention 2 enabled
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # One line!
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Use it normally, no other code changes needed
text = "Explain machine learning in 100 words"
inputs = tokenizer(text, return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,  # temperature only takes effect when sampling
    temperature=0.7
)
print(tokenizer.decode(outputs[0]))
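That's it. One parameter. No retraining. It works with any checkpoint whose architecture Transformers supports for Flash Attention 2, provided the flash-attn package is installed and the model runs in FP16 or BF16.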
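If your code has to run on machines where the flash-attn package may not be installed, you can pick the setting at runtime. This is a minimal sketch: the helper name `pick_attn_implementation` is an assumption, and it only checks that the package is importable, not that your GPU supports it. The fallback value `"sdpa"` selects PyTorch's built-in scaled dot product attention.

```python
import importlib.util

def pick_attn_implementation() -> str:
    """Return "flash_attention_2" if flash_attn is importable, else "sdpa"."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"

attn_implementation = pick_attn_implementation()
print(attn_implementation)
# Pass it along: from_pretrained(..., attn_implementation=attn_implementation)
```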
Available Since
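- Flash Attention 2: PyTorch 2.2+ (January 2024) inside scaled_dot_product_attention; PyTorch 2.0 shipped the original Flash Attention kernel
- Flash Attention 2: Transformers library (opt-in via attn_implementation; many models default to PyTorch's SDPA kernels instead)
- Flash Attention 3: Manual installation
  pip install flash-attn
- Flash Attention 4: Available via NVIDIA libraries (May 2026)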
Checking if Your Model Uses It
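print(model.config._attn_implementation)
# Output: 'flash_attention_2'
# Or check the attention layer directly (Llama models keep their layers under model.model.layers)
print(type(model.model.layers[0].self_attn))
# Older Transformers releases show a dedicated class such as LlamaFlashAttention2;
# newer releases use a single attention class, so the config check above is more reliable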
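For Quantized Models (8-bit)
If you want to combine Flash Attention with 8-bit quantization (note that this bitsandbytes path is INT8, not FP8):
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # Quantize weights to INT8
    torch_dtype=torch.float16,  # Flash Attention 2 needs FP16 or BF16 activations
    device_map="auto"
)
# Model is now:
# - About half the weight memory of FP16 (the parameter count is still 7B)
# - Not necessarily faster: bitsandbytes INT8 trades some speed for memory
# - Close to, but not exactly, the original accuracy; measure on your own task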
Real Benchmarks: Where Flash Attention Wins
All data from NVIDIA official benchmarks and published papers.
[Figures not preserved: "Throughput (Operations Per Second)" and "Inference Throughput by Sequence Length"]
Hardware Utilization
| Approach | Utilization | Peak B200 | Achieved |
|---|---|---|---|
| Standard attention | 3-5% | 1,726 TFLOPs/s | 52-86 TFLOPs/s |
| Flash Attention 2 | 30-40% | 1,726 TFLOPs/s | 518-690 TFLOPs/s |
| Flash Attention 4 (FP16) | 70% | 1,726 TFLOPs/s | 1,210 TFLOPs/s |
| Flash Attention 4 (FP8) | 93% | 1,726 TFLOPs/s | 1,605 TFLOPs/s |
Real-World Impact on Latency
Serving Llama 2 13B with 128K context:
| Metric | Standard | FA4 (FP16) | FA4 (FP8) |
|---|---|---|---|
| Time per token | 380ms | 85ms | 62ms |
| Latency for 100 tokens | 38s | 8.5s | 6.2s |
| Speedup | 1x | 4.5x | 6.1x |
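The last two rows of that table follow directly from the first. Here's the arithmetic as a sketch, using only the per-token times listed above:

```python
time_per_token_ms = {"Standard": 380, "FA4 (FP16)": 85, "FA4 (FP8)": 62}
baseline_ms = time_per_token_ms["Standard"]

speedups = {}
latency_100_tokens_s = {}
for name, ms in time_per_token_ms.items():
    speedups[name] = baseline_ms / ms
    latency_100_tokens_s[name] = ms * 100 / 1000   # 100 tokens, in seconds
    print(f"{name:>10}: {latency_100_tokens_s[name]:.1f}s per 100 tokens, "
          f"{speedups[name]:.1f}x")
```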
Cost Impact (AWS Pricing, per 1M tokens)
| Approach | Hardware Required | Input Cost | Daily Cost (10K req) |
|---|---|---|---|
| Standard (4K context) | 10 A100s | $0.003 | $30 |
| Standard (128K context) | 50 A100s | $0.015 | $150 |
| Flash Attention (128K) | 8 H100s | $0.012 | $120 |
| Flash Attention 4 (128K) | 4 B200s | $0.010 | $100 |
Flash Attention 4 gives you 128K context for a third less than standard attention at the same length ($100 vs. $150 per day), on 4 GPUs instead of 50.
Common Misconceptions
Misconception 1: "Flash Attention is an approximation. My outputs will differ."
Truth: Flash Attention is exact. The online softmax algorithm produces identical outputs to standard softmax within floating-point precision.
Proof:
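import math
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch 2.3+

# Shape [batch, heads, seq, head_dim]; the Flash backend needs FP16 or BF16
Q = torch.randn(1, 1, 1024, 128, device="cuda", dtype=torch.float16)
K = torch.randn(1, 1, 1024, 128, device="cuda", dtype=torch.float16)
V = torch.randn(1, 1, 1024, 128, device="cuda", dtype=torch.float16)

# Standard attention, computed in FP32 as the reference (scale is 1/√d, d = 128)
scores_standard = Q.float() @ K.float().transpose(-2, -1) / math.sqrt(Q.shape[-1])
attn_standard = F.softmax(scores_standard, dim=-1)
output_standard = attn_standard @ V.float()

# Flash attention (force the Flash backend of scaled_dot_product_attention)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output_flash = F.scaled_dot_product_attention(Q, K, V)

# Check equality up to FP16 rounding
print(torch.allclose(output_standard, output_flash.float(), atol=1e-3))
# Expected: True (any difference is FP16 rounding)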
Misconception 2: "You need special models or retraining to use Flash Attention."
Truth: Works with any pre-trained checkpoint. Change one parameter and you're done.
If a model was trained with standard attention, it will produce the same outputs with Flash Attention, up to floating-point rounding. The algorithm is equivalent.
Misconception 3: "Flash Attention only helps on long sequences."
Truth: It helps everywhere, but the benefit grows with sequence length.
- 4K context: 3-5x speedup
- 32K context: 5-10x speedup
- 128K context: 10-50x speedup
Even on short sequences, you're saving HBM round-trips.
Misconception 4: "It only works on new GPUs (H100, B200)."
Truth: Flash Attention 2 runs on NVIDIA Ampere and newer GPUs, so anything from the A100 and RTX 3090 generation (2020) onward.
- H100/B200: Best results (71%+ utilization)
- A100: 40-50% utilization
- H80: 30-40% utilization
- V100: Not supported by Flash Attention 2 (Volta, 2017)
- CPU: Not supported (yet—this is a GPU-specific trick)
When to Use Flash Attention
Always Use It If:
- ✅ Running on H100, B200, or RTX 4090+
- ✅ Doing inference (any context length)
- ✅ Fine-tuning transformers
- ✅ Benchmarking models
- ✅ Using transformers library (one-parameter opt-in)
Don't Use It If:
- ❌ Running on CPU only
- ❌ Using custom attention mechanisms (sparse, dilated, etc.)
- ❌ Debugging attention weights (FA never materializes the full weight matrix, so it cannot return it)
For Your ML Engineering Internship:
If you're targeting inference/optimization roles: Understand the hardware-level trade-offs. Be able to explain why standard attention is memory-bound, not compute-bound. Know the difference between tiling, online softmax, and incoherent processing.
If you're targeting research roles: Read the papers. Understand the mathematical proofs of correctness. Know how backward pass recomputation works.
If you're targeting product/ML eng roles: Know when to use it, what speedups to expect, how to measure them.
Why This Matters in 2026
Flash Attention is one of the few algorithmic breakthroughs that actually moves the needle on production inference costs.
Most "innovations" are marginal (5-10% improvements). Flash Attention is 5-10x improvements with zero accuracy loss.
That's rare.
It's also a masterclass in hardware-aware algorithm design. Instead of optimizing the algorithm, Tri Dao (the author) optimized for the GPU's memory hierarchy.
This is the future of deep learning: not bigger models, not more parameters, but algorithms that respect hardware constraints.
Next Steps
- Enable it immediately: Add attn_implementation="flash_attention_2" to your next project
- Measure the difference: Benchmark before/after on your hardware
- Combine with quantization: Flash Attention 4 + FP8 gives you max savings
- Understand the hardware: Read about memory hierarchy, cache, bandwidth
- Go deeper: Read the papers (see Further Reading below)
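For the "measure the difference" step, here's a minimal timing harness. The `benchmark` helper and the stand-in workload are assumptions for illustration; swap in a call to your own model. One caveat for GPU code: CUDA kernels launch asynchronously, so call `torch.cuda.synchronize()` at the end of the function you time, otherwise you only measure launch overhead.

```python
import time
from statistics import median

def benchmark(fn, warmup: int = 3, iters: int = 10) -> float:
    """Median wall-clock seconds per call of fn()."""
    for _ in range(warmup):      # warm caches and lazy initialization first
        fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return median(timings)

# Stand-in workload so the sketch runs anywhere; replace with your model call.
seconds = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"{seconds * 1e3:.3f} ms per call")
```

Run it once with each attention implementation on the same inputs and compare the medians.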
The future of inference is about working with hardware constraints, not against them. Flash Attention is the best example of that principle in practice.
Further Reading
- Flash Attention paper (2022): "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (https://arxiv.org/abs/2205.14135)
- Flash Attention 2 (2023): "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (https://arxiv.org/abs/2307.08691)
- Flash Attention 3 (2024): "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (https://arxiv.org/abs/2407.08608)
- Tri Dao's blog: Deep dives on implementation (https://tridao.me)
- NVIDIA Blog: Hardware-specific optimizations (https://developer.nvidia.com/blog/attention-mechanisms)
- My implementation guide: Coming soon on this blog
Bonus: Hardware Specs for Reference
NVIDIA B200 (Latest, May 2026)
| Spec | Value |
|---|---|
| Peak FP32 compute | 1,726 TFLOPs/s |
| Peak FP8 compute | 3,452 TFLOPs/s |
| HBM bandwidth | 2.66 TB/s |
| L2 SRAM | 384 MB |
| Memory | 192 GB HBM |
NVIDIA H100
| Spec | Value |
|---|---|
| Peak FP16/BF16 compute (dense tensor core) | 989 TFLOPs/s |
| Peak FP8 compute | 1,978 TFLOPs/s |
| HBM bandwidth | 2 TB/s |
| L2 cache | 50 MB |
| Shared memory per SM | up to 228 KB |
| Memory | 80 GB HBM |
RTX 4090 (Consumer GPU)
| Spec | Value |
|---|---|
| Peak FP32 compute | 163 TFLOPs/s |
| Peak FP8 compute | 326 TFLOPs/s |
| Memory bandwidth | 1 TB/s |
| VRAM | 24 GB |
Flash Attention helps on all of these, but the gains are largest on H100/B200 because the gap between compute and memory bandwidth is larger there.
Published: May 21, 2026 | Last updated: May 21, 2026
This post is a deep dive on infrastructure. If you're building LLM systems in production, understanding Flash Attention is critical. It's not optional—it's the standard.