
Flash Attention 4 Explained: From Quadratic to 1,605 TFLOPs/s

Shashwat Sharma
19 min read


Every transformer model you use today (GPT-4, Claude, Gemini, Llama 3) runs on one algorithm. That algorithm costs 4x more every time you double your context window. Here's how one hardware-aware trick makes that cost far cheaper to pay.


Introduction

Transformer models power every large language model today. GPT-4, Claude, Gemini, Llama 3, DeepSeek—they all share the same core: the attention mechanism.

And that same mechanism that makes these models intelligent is also what makes them absurdly expensive to run.

Here's the problem in one equation:

Attention complexity = O(N²)
N = sequence length

If you double context (N → 2N):
Computation: N² → 4N² (4x more)
Memory: N² → 4N² (4x more)
Cost: $0.003 → $0.012 per request

Double your context window and you quadruple your costs. This one mathematical fact has driven billions of dollars of engineering work since 2022.

But there's a trick.

Flash Attention (2022), Flash Attention 2 (2023), and Flash Attention 3 (2024) progressively chipped away at this bottleneck. Now, Flash Attention 4 (March 2026) changes the game completely.

Flash Attention 4 achieves 1,605 TFLOPs/s on NVIDIA B200 GPUs with 71% hardware utilization.

To put that in perspective: that's nearly 10x faster than standard attention, and it's exact, not approximate.

In this post, you'll learn:

  • Why standard attention is expensive (the math + the hardware truth)
  • How Flash Attention cheats the math (without changing the result)
  • The hardware insight that makes it possible (memory hierarchy)
  • FA3 vs. FA4 (what changed, and why it matters for quantization)
  • When to use it (hint: always, except on CPU)

The Problem: Why Attention Is Expensive

The Quadratic Trap

Let's start with the math everyone knows but nobody really understands.

Standard attention computes a similarity matrix of size [sequence_length, sequence_length].

Attention(Q, K, V) = softmax(Q @ K^T / √d) @ V

Q: [N, d]      (N = sequence length, d = head dimension)
K: [N, d]
V: [N, d]

Q @ K^T produces an [N, N] matrix
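To make the shapes concrete, here is a minimal NumPy reference implementation (the function name standard_attention is illustrative). It builds the full [N, N] weights matrix that the rest of this post is about avoiding:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention: materializes the full [N, N] score matrix."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                          # [N, N]
    scores = scores - scores.max(axis=-1, keepdims=True)   # subtract row max for stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)         # row-wise softmax
    return weights @ V, weights                            # output [N, d], weights [N, N]

rng = np.random.default_rng(0)
N, d = 1_024, 128
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
out, weights = standard_attention(Q, K, V)
print(out.shape, weights.shape)    # (1024, 128) (1024, 1024)
```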

For a 4,096-token context (4K):

  • N² = 16 million elements
  • Each element must be computed, stored, and loaded again for the backward pass

For a 128K-token context (a common long-context window):

  • N² ≈ 16 billion elements
  • Each element is 4 bytes (float32) = 64 GB just for the matrix, for a single attention head in a single layer

Wait. 64 GB? That's most of an H100's HBM (80 GB total, shared with model weights), and it covers just one head in one layer.

You can't fit the full attention matrices on the GPU.
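The sizes above follow directly from N² × 4 bytes. A quick sketch, assuming FP32 scores for one head in one layer (131,072 is the exact power-of-two value behind "128K"; score_matrix_bytes is an illustrative name):

```python
def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to materialize one [N, N] attention score matrix (one head, one layer)."""
    return seq_len * seq_len * bytes_per_elem

for n in (4_096, 8_192, 131_072):
    gib = score_matrix_bytes(n) / 2**30
    print(f"N={n:>7,}: {n * n:>14,} elements, {gib:8.2f} GiB in FP32")
```

Doubling N from 4,096 to 8,192 quadruples the bytes, which is the 4x rule from the introduction.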

Real-World Cost Impact

Let's make this concrete. Token pricing for Claude 3.5 Sonnet:

  • Input: $3 per 1M tokens
  • Output: $15 per 1M tokens

For a 128K context request with 1K output:

Input cost:  128,000 tokens × ($3 / 1,000,000) = $0.384
Output cost:   1,000 tokens × ($15 / 1,000,000) = $0.015
Total:       $0.399 per request

At 1,000 requests/day = $399/day = $146K/year

For comparison, a 4K context:

Input cost:   4,000 tokens × ($3 / 1,000,000) = $0.012
Output cost:  1,000 tokens × ($15 / 1,000,000) = $0.015
Total:        $0.027 per request

At 1,000 requests/day = $27/day = $10K/year

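Companies pay 14.8x more per request for 32x more input.

Per-token prices grow linearly with context, but the attention work behind them does not: 32x the context means a 1,024x larger N×N score matrix. What decides whether long context is affordable to serve is how efficiently attention uses the hardware.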
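The arithmetic above as a small Python helper, using the Claude 3.5 Sonnet prices quoted in this section (request_cost is an illustrative name):

```python
INPUT_PRICE_PER_M = 3.00    # USD per 1M input tokens (Claude 3.5 Sonnet, as quoted above)
OUTPUT_PRICE_PER_M = 15.00  # USD per 1M output tokens

def request_cost(input_tokens: int, output_tokens: int) -> float:
    """Cost of one request in USD under linear per-token pricing."""
    return (input_tokens * INPUT_PRICE_PER_M + output_tokens * OUTPUT_PRICE_PER_M) / 1_000_000

long_ctx = request_cost(128_000, 1_000)   # $0.399
short_ctx = request_cost(4_000, 1_000)    # $0.027
print(f"128K: ${long_ctx:.3f}/request, ${long_ctx * 1_000 * 365:,.0f}/year at 1,000 req/day")
print(f"4K:   ${short_ctx:.3f}/request, ${short_ctx * 1_000 * 365:,.0f}/year at 1,000 req/day")
print(f"ratio: {long_ctx / short_ctx:.1f}x for {128_000 // 4_000}x the input")
```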


The GPU Bottleneck: Why Attention Isn't Compute-Bound

Here's where most explanations get it wrong.

They say: "Attention is expensive because it does N² computations."

That's not the real problem. Modern GPUs can do hundreds of trillions of operations per second. The problem is memory bandwidth.

GPU Memory Hierarchy

An NVIDIA H100 system has three levels of memory that matter here:

Memory | Size | Bandwidth | Latency | Purpose
SRAM (on-chip shared memory) | 228 KB per SM | ~15 TB/s (aggregate) | < 10ns | Fast computation
HBM (main GPU memory) | 80 GB | 2-3.35 TB/s | 100-300ns | Store activations
CPU RAM | 400+ GB | 0.1 TB/s | 100µs+ | System memory

Standard attention writes the full N×N matrix to HBM, then loads it back. That's a round trip over slow memory.

The Compute vs. Memory Bind

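Here's the math that nobody talks about:

H100 tensor-core throughput:  989 TFLOPs/s (dense BF16/FP16)
H100 memory bandwidth:        ~2,000 GB/s (PCIe; 3,350 GB/s on SXM)

Ridge point:  989 TFLOPs/s / 2,000 GB/s
            ≈ 495 FLOPs per byte

To keep the tensor cores busy, a kernel must do ~495 FLOPs for every byte it moves to or from HBM.

The softmax step of standard attention reads every element of the N×N score matrix from HBM, does a handful of FLOPs on it (subtract the max, exponentiate, sum, divide), and writes it back:
  ~5 FLOPs per element / 8 bytes moved (4 read + 4 written) ≈ 0.6 FLOPs per byte

That's roughly 800x below what the GPU can handle.

Translation: during these steps, the GPU spends almost all its time moving data, not computing.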

This is called being "memory-bound" instead of "compute-bound."
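A roofline-style check of these numbers, as a sketch: the 5-FLOPs-per-element softmax cost is a rough assumption, and the function names are illustrative. It also shows that even the Q @ K^T matmul falls below the ridge point once it has to write an N×N FP32 matrix to HBM:

```python
PEAK_FLOPS = 989e12        # H100 dense BF16/FP16 tensor-core peak, FLOP/s
HBM_BANDWIDTH = 2e12       # bytes/s (the ~2 TB/s figure used above)

ridge = PEAK_FLOPS / HBM_BANDWIDTH   # FLOPs per byte needed to be compute-bound

def softmax_intensity(flops_per_elem: float = 5, bytes_per_elem: int = 4) -> float:
    """Elementwise softmax pass over the [N, N] matrix: each element read once, written once."""
    return flops_per_elem / (2 * bytes_per_elem)

def qk_matmul_intensity(n: int, d: int, bytes_per_elem: int = 4) -> float:
    """Q @ K^T when the scores must be written to HBM: read Q and K, write [N, N]."""
    flops = 2 * n * n * d
    bytes_moved = (2 * n * d + n * n) * bytes_per_elem
    return flops / bytes_moved

print(f"ridge point:     {ridge:6.1f} FLOPs/byte")
print(f"softmax pass:    {softmax_intensity():6.2f} FLOPs/byte")
print(f"Q @ K^T (N=4K):  {qk_matmul_intensity(4_096, 128):6.1f} FLOPs/byte")
```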

Concrete Example: Where the Time Goes at 4K Context

Let me trace through a single forward pass of attention for one head on an H100 (FP32 storage, ~2 TB/s HBM, compute at the 989 TFLOPs/s tensor-core peak):

1. Load Q from HBM: 4K × 128 dims × 4 bytes = 2 MB → 1µs
2. Load K from HBM: 4K × 128 dims × 4 bytes = 2 MB → 1µs
3. Compute Q @ K^T: 2 × (4K)² × 128 ≈ 4.3 GFLOPs → ~4µs of actual compute
4. Store attention matrix: 16M × 4 bytes = 64 MB to HBM → 32µs
5. Load attention matrix: 64 MB from HBM → 32µs
6. Load V from HBM: 2 MB → 1µs
7. Compute attention @ V: another ~4.3 GFLOPs → ~4µs
8. Store output: 2 MB → 1µs

Total memory time: ~68µs
Total compute time: ~9µs
Total: ~77µs per head, per layer

Across 32 layers: ~2.5ms for a single head's attention

Notice: only ~9µs (about 12%) is actual compute. The rest is memory shuffling.

This is the bottleneck Flash Attention solves.
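Here is the same eight-step trace as a back-of-the-envelope script, assuming FP32 storage, 2 TB/s of HBM bandwidth, and compute running at the 989 TFLOPs/s tensor-core peak (an idealization; real kernels rarely hit peak). Matmul time comes from 2 × N² × d FLOPs, and exact decimal units make the memory total come out slightly above the rounded ~68µs:

```python
N, D = 4_096, 128          # sequence length, head dimension
BYTES = 4                  # FP32 storage, as in the walkthrough
HBM_BW = 2e12              # bytes/s (~2 TB/s)
PEAK = 989e12              # FLOP/s, dense tensor-core peak (idealized)

def mem_us(num_bytes: float) -> float:
    return num_bytes / HBM_BW * 1e6

def compute_us(flops: float) -> float:
    return flops / PEAK * 1e6

matrix_nd = N * D * BYTES  # one [N, d] matrix (Q, K, V, or output)
scores = N * N * BYTES     # the [N, N] attention matrix

steps = [
    ("load Q", "mem", mem_us(matrix_nd)),
    ("load K", "mem", mem_us(matrix_nd)),
    ("Q @ K^T", "compute", compute_us(2 * N * N * D)),
    ("store scores", "mem", mem_us(scores)),
    ("load scores", "mem", mem_us(scores)),
    ("load V", "mem", mem_us(matrix_nd)),
    ("P @ V", "compute", compute_us(2 * N * N * D)),
    ("store output", "mem", mem_us(matrix_nd)),
]

memory_time = sum(t for _, kind, t in steps if kind == "mem")
compute_time = sum(t for _, kind, t in steps if kind == "compute")
share = memory_time / (memory_time + compute_time)

for name, kind, t in steps:
    print(f"{name:<13} {kind:<8} {t:6.1f} µs")
print(f"memory {memory_time:.1f} µs, compute {compute_time:.1f} µs, memory share {share:.0%}")
```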


The Solution: Flash Attention Cheats the Math

Here's the key insight that unlocks everything:

💡Insight

You don't need the full [N, N] attention matrix in memory. It's an intermediate computation. You only care about the output.

Flash Attention rebuilds that matrix piece-by-piece, keeping tiles in fast SRAM memory instead of slow HBM.

High-Level Intuition

Standard attention:

1. Compute Q @ K^T (gives you full N×N matrix)   [SLOW: HBM]
2. Apply softmax to normalize                     [SLOW: HBM]
3. Multiply by V to get output                    [SLOW: HBM]

Flash Attention:

1. Process Q in blocks (M at a time, where M fits in SRAM)
2. For each Q block, compute attention to all K (in small chunks)
3. Use online softmax to accumulate correct results without storing full matrix
4. Proceed to next Q block

Extra memory: O(N) (one running max and sum per query row) instead of O(N²) for the score matrix

Why SRAM Matters

Modern GPUs have:

  • SRAM (on-chip shared memory): up to 228 KB per streaming multiprocessor (SM) on H100, with roughly an order of magnitude more bandwidth than HBM
  • HBM: 80 GB (H100) to 192 GB (B200), ~2-8 TB/s bandwidth

SRAM is tiny, but far faster.

Flash Attention keeps one block of Q and small chunks of K/V in SRAM (a 128 × 128 FP16 tile is 32 KB), does the computation there, and never materializes the full attention matrix.

The Math: Online Softmax

Here's the clever trick that makes it work.

Standard softmax requires seeing all inputs:

softmax(x)[i] = exp(x[i]) / Σ(exp(x[j]) for all j)

You need the denominator, which depends on all x values.

But Flash Attention uses a numerical trick called "online softmax" or "streaming softmax" (Milakov and Gimelshein, 2018):

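import math

# Standard softmax (needs all values first)
def softmax(scores):
    max_score = max(scores)
    sum_exp = sum(math.exp(s - max_score) for s in scores)
    return [math.exp(s - max_score) / sum_exp for s in scores]

# Online softmax (streaming, block-by-block)
def online_softmax(blocks):
    max_score = -math.inf
    sum_exp = 0.0

    # Stream over blocks, keeping a running max and a running sum
    for block in blocks:
        old_max = max_score
        max_score = max(max_score, max(block))

        # Rescale previous exponentials to new max (numerically stable)
        sum_exp = sum_exp * math.exp(old_max - max_score)

        # Add new exponentials
        sum_exp += sum(math.exp(s - max_score) for s in block)

    # Normalize with the final max and sum (exact, not approximate!).
    # Dividing inside the loop would use a partial sum and give wrong answers.
    # Flash Attention avoids this second pass by rescaling its output accumulator instead.
    return [math.exp(s - max_score) / sum_exp for block in blocks for s in block]

print(softmax([1.0, 2.0, 3.0, 4.0]))
print(online_softmax([[1.0, 2.0], [3.0, 4.0]]))  # same values, up to float rounding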

The running max and sum are built in one streaming pass; dividing by them at the end gives exactly the same output as standard softmax.

This is not an approximation. It's mathematically exact. The only difference is floating-point rounding, which happens in both approaches.
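Written as a recurrence over blocks B_1, ..., B_T of scores x (notation introduced here for illustration; v_i is the value row for score x_i), the update and the reason it is exact look like this:

```latex
m_0 = -\infty, \qquad \ell_0 = 0, \qquad \mathbf{o}_0 = \mathbf{0}

m_j = \max\!\Big(m_{j-1},\ \max_{i \in B_j} x_i\Big)

\ell_j = \ell_{j-1}\, e^{m_{j-1} - m_j} + \sum_{i \in B_j} e^{x_i - m_j}

\mathbf{o}_j = \mathbf{o}_{j-1}\, e^{m_{j-1} - m_j} + \sum_{i \in B_j} e^{x_i - m_j}\, \mathbf{v}_i

\text{By induction, } \ell_j = \sum_{i \in B_1 \cup \dots \cup B_j} e^{x_i - m_j},
\text{ so } \frac{\mathbf{o}_T}{\ell_T}
  = \sum_i \frac{e^{x_i - m_T}}{\sum_k e^{x_k - m_T}}\, \mathbf{v}_i
  = \sum_i \operatorname{softmax}(x)_i\, \mathbf{v}_i .
```

The rescaling factor e^{m_{j-1} - m_j} is what lets each block use whatever max it has seen so far without breaking the final result.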


Technical Deep Dive: How Flash Attention Actually Works

Let me walk through the algorithm step-by-step with concrete numbers.

Setup

Sequence length N = 4096
Head dimension d = 128
Block size M = 256 (illustrative: a 256 × 128 FP16 tile is 64 KB, so real kernels tune tile sizes to fit shared memory)

Q, K, V: [4096, 128] each

Total blocks to process: 4096 / 256 = 16 Q-blocks

Algorithm: Tiling

# Flash Attention forward pass (NumPy sketch of the tiling + online softmax idea)
import numpy as np

def flash_attention(Q, K, V, M=256):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)             # the 1/√d from the attention formula
    output = np.zeros((N, d))

    # Process Q in blocks
    for i in range(0, N, M):
        Q_block = Q[i:i+M]               # [M, d] - 64 KB in FP16 for M=256, d=128
        rows = Q_block.shape[0]

        # Online softmax state for this Q block: one max and one sum per query row
        max_score = np.full(rows, -np.inf)
        sum_exp = np.zeros(rows)
        O_block = np.zeros((rows, d))    # Unnormalized output accumulator for this block

        # Process K, V in blocks
        for j in range(0, N, M):
            K_block = K[j:j+M]           # [M, d]
            V_block = V[j:j+M]           # [M, d]

            # Compute local attention scores
            scores = (Q_block @ K_block.T) * scale  # [M, M] tile, never written to HBM

            # Online softmax update (per row)
            new_max = np.maximum(max_score, scores.max(axis=1))

            # Rescale previous exponentials to the new max
            exp_weight = np.exp(max_score - new_max)
            O_block *= exp_weight[:, None]
            sum_exp *= exp_weight

            # Add new block contribution
            exp_scores = np.exp(scores - new_max[:, None])
            sum_exp += exp_scores.sum(axis=1)  # Sum per row

            # Accumulate output
            O_block += exp_scores @ V_block
            max_score = new_max

        # Normalize by sum of exponentials
        output[i:i+M] = O_block / sum_exp[:, None]

    return output
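The only subtle line in the inner loop is the rescaling. As a minimal NumPy sketch (the helper names partial_state and merge are illustrative, not from any library), here is that step in isolation for a single query row split into two K/V blocks, checked against a direct softmax:

```python
import numpy as np

def partial_state(scores, values):
    """Softmax state for one query row over one K/V block: (max, sum of exps, unnormalized output)."""
    m = scores.max()
    p = np.exp(scores - m)
    return m, p.sum(), p @ values

def merge(state_a, state_b):
    """Combine two block states by rescaling both to the shared max (the inner-loop update)."""
    m_a, l_a, o_a = state_a
    m_b, l_b, o_b = state_b
    m = max(m_a, m_b)
    wa, wb = np.exp(m_a - m), np.exp(m_b - m)
    return m, l_a * wa + l_b * wb, o_a * wa + o_b * wb

rng = np.random.default_rng(0)
scores = rng.normal(size=512)          # one query row against 512 keys
values = rng.normal(size=(512, 64))    # matching V rows, head dim 64

# Reference: full softmax over all keys, then weighted sum of V
weights = np.exp(scores - scores.max())
reference = (weights / weights.sum()) @ values

# Two blocks of 256 keys, merged with the rescaling rule
m, l, o = merge(partial_state(scores[:256], values[:256]),
                partial_state(scores[256:], values[256:]))
print(np.allclose(o / l, reference))   # True
```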

Memory Access Pattern

This is the genius part:

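Standard attention:
  Load Q [N, d] from HBM     → N*d*4 bytes
  Load K [N, d] from HBM     → N*d*4 bytes
  Compute Q@K^T              → 2*N²*d FLOPs (memory limited)
  Store [N, N] to HBM        → N²*4 bytes (HUGE)
  Load [N, N] from HBM       → N²*4 bytes (HUGE)
  Load V [N, d] from HBM     → N*d*4 bytes
  Store O [N, d] to HBM      → N*d*4 bytes
  Total memory traffic:        O(N²) + O(N*d)
                             = O(N²) dominates

Flash Attention:
  Load Q [N, d] once                       → N*d*4 bytes
  Load K, V once per Q block (N/M times)   → 2*(N/M)*N*d*4 bytes
  Process in [M, M] blocks in SRAM
  Store O [N, d] once                      → N*d*4 bytes
  Total memory traffic:        O(N²*d/M)
                             = still quadratic in N, but no N×N matrix ever touches HBM

Memory traffic per head, for N=4K, d=128, M=256:

  • Standard: 33.6M (store + load scores) + 2.1M (Q, K, V, O) ≈ 35.7M elements ≈ 143 MB
  • Flash: 1M (Q, O) + 16.8M (K and V, 16 passes each) ≈ 17.8M elements ≈ 71 MB (2x less)

The bigger win is footprint: Flash Attention's extra memory is O(N) instead of O(N²), and in practice neighboring Q blocks often share K/V tiles through L2 cache, cutting HBM traffic further.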
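A sketch of this element count, following the loop order in the pseudocode above (K and V re-read for every Q block). It deliberately ignores L2 cache reuse and the extra passes a real unfused implementation makes (softmax writing its output back to HBM before the P @ V matmul reads it again); the function names are illustrative:

```python
def standard_traffic(n: int, d: int) -> int:
    """Elements moved to/from HBM: read Q, K, V; write O; store and reload the [N, N] scores."""
    return 4 * n * d + 2 * n * n

def flash_traffic(n: int, d: int, block: int) -> int:
    """Read Q once, write O once, re-read K and V once per Q block."""
    num_q_blocks = -(-n // block)          # ceiling division
    return 2 * n * d + 2 * num_q_blocks * n * d

for n in (4_096, 32_768, 131_072):
    s, f = standard_traffic(n, 128), flash_traffic(n, 128, 256)
    print(f"N={n:>7,}: standard {s * 4 / 1e6:10,.0f} MB, flash {f * 4 / 1e6:10,.0f} MB, ratio {s / f:.1f}x")
```

In this simple model the ratio stays near M/d, so larger tiles (and K/V reuse through cache) are where the remaining savings come from.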

Backwards Pass

The backward pass (for training) is trickier because you need to recompute intermediate values on the fly instead of storing them. Flash Attention uses a clever recomputation strategy:

Forward pass: [input → attention → output]
Backward pass:
  1. Recompute attention scores block by block from Q and K, using the saved softmax statistics (one log-sum-exp value per query row)
  2. Use recomputed attention to compute gradients
  3. No need to store full attention matrix

This adds compute to the backward pass (Q @ K^T is computed again) but saves massive memory (the N² matrix).

For inference, this doesn't matter—backward pass is skipped.
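A minimal NumPy sketch of the recomputation idea, assuming the forward pass keeps only one log-sum-exp value per query row (Flash Attention 2 saves such a per-row statistic; the variable names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K = rng.normal(size=(N, d)), rng.normal(size=(N, d))
S = Q @ K.T / np.sqrt(d)

# Forward pass saves only one number per query row: the log-sum-exp of its scores
row_max = S.max(axis=1, keepdims=True)
lse = row_max + np.log(np.exp(S - row_max).sum(axis=1, keepdims=True))   # shape [N, 1]

# Backward pass: recompute S from Q and K, then rebuild P without having stored it
S_recomputed = Q @ K.T / np.sqrt(d)
P = np.exp(S_recomputed - lse)

reference = np.exp(S - row_max) / np.exp(S - row_max).sum(axis=1, keepdims=True)
print(np.allclose(P, reference))   # True
```

Storing N numbers instead of an N × N matrix is the whole trade: a second matmul in exchange for quadratic memory.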


FA3 vs FA4: What Changed?

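Flash Attention 3 (2024) was built for NVIDIA Hopper (H100): it overlaps data movement, matmuls, and softmax using the GPU's asynchronous hardware, and it adds FP8 support. Flash Attention 4 (2026) targets Blackwell (B200). For quantization, the key question is how each one handles outliers.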

The Outlier Problem

Modern transformers have a quirk: certain embedding dimensions have much larger values than others. When you quantize (convert to lower precision like FP8), these outliers break everything.

Example:

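FP32 values: [-0.5, -0.2, 150.3, 0.1, -0.3]   (150.3 is the outlier)

When quantizing to 8 bits (shown here with INT8's range of -128 to 127; FP8 uses a different, non-uniform grid but hits the same problem):
Naive: Scale to fit outlier → small values become 0 or noise
Better: Handle outlier specially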
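A minimal NumPy sketch of the naive approach, using symmetric INT8 absmax quantization as a stand-in (it matches the -128 to 127 range above; real FP8 formats have a different grid, but a single outlier-driven scale hurts them the same way). absmax_int8 is an illustrative name:

```python
import numpy as np

values = np.array([-0.5, -0.2, 150.3, 0.1, -0.3])

def absmax_int8(x):
    """Symmetric per-tensor INT8 quantization: one scale chosen so the largest |x| maps to 127."""
    scale = np.abs(x).max() / 127
    q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)
    return q, scale

q, scale = absmax_int8(values)
print(q)            # [  0   0 127   0   0]
print(q * scale)    # every small value dequantizes to 0.0; only the outlier survives
```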

Flash Attention 3: Incoherent Processing

FA3 introduced a Hadamard transform trick:

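# Sketch: incoherent processing (random Hadamard rotation of the head dimension)
import numpy as np

def hadamard_matrix(d):
    # Sylvester construction; d must be a power of 2
    H = np.array([[1.0]])
    while H.shape[0] < d:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(d)  # orthonormal

def reduce_outliers_fa3(activations, seed=0):
    # activations: [N, d]. Apply the SAME rotation to Q and K (same seed),
    # so Q @ K^T is unchanged because the rotation is orthogonal.
    d = activations.shape[-1]
    rng = np.random.default_rng(seed)
    random_signs = rng.choice([-1.0, 1.0], size=d)  # +1 or -1 per dimension
    R = random_signs[:, None] * hadamard_matrix(d)  # diag(signs) @ H

    # Rotate: spreads outliers across dimensions
    rotated = activations @ R
    # Now no single dimension has extreme values
    # Quantization becomes stable
    return rotated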

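Result: With FA3, you can run attention in FP8 (8-bit floating point) instead of FP16 (16-bit), halving the memory for Q, K, and V while keeping error low; the FA3 paper reports 2.6x lower numerical error than FP8 attention without the rotation.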
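To see why the rotation is safe, here is a small NumPy check (a sketch with an injected outlier channel; hadamard and peak_to_rms are illustrative helpers). Because the rotation matrix is orthogonal, applying it to both Q and K leaves Q @ K^T unchanged, while the outlier's magnitude gets spread across all dimensions:

```python
import numpy as np

def hadamard(d: int) -> np.ndarray:
    """Orthonormal Hadamard matrix via Sylvester's construction (d must be a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < d:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(d)

def peak_to_rms(x: np.ndarray) -> float:
    """How spiky a tensor is: largest magnitude relative to its root-mean-square."""
    return np.abs(x).max() / np.sqrt(np.mean(x ** 2))

rng = np.random.default_rng(0)
d = 128
Q = rng.normal(size=(64, d))
K = rng.normal(size=(64, d))
Q[:, 7] *= 100.0                      # inject an outlier channel, as in the example above

R = rng.choice([-1.0, 1.0], size=d)[:, None] * hadamard(d)   # diag(random signs) @ H
Q_rot, K_rot = Q @ R, K @ R

print(np.allclose(Q_rot @ K_rot.T, Q @ K.T))      # True: scores are unchanged
print(peak_to_rms(Q), "->", peak_to_rms(Q_rot))   # the spike is spread across all 128 dims
```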

Flash Attention 4: Improved Outlier Handling

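FA4 refines this technique further. The sketch below illustrates the idea (rotation plus per-block scaling); its helpers are hypothetical placeholders, not FA4's actual code:

def reduce_outliers_fa4(activations):
    # Illustrative pseudocode only: every helper called here is a hypothetical placeholder
    d = activations.shape[-1]

    # Hadamard rotation + adaptive sign choice
    H = optimized_hadamard_matrix(d)
    random_signs = adaptive_random_vector(d)  # Learned or heuristic

    # Rotate with better numerical stability
    rotated = activations @ (random_signs[:, None] * H)

    # Adaptive quantization levels per block
    q_scale = per_block_quantization_scale(rotated)

    return quantize(rotated, q_scale)  # FP8 values plus per-block scales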

Result: FP8 becomes even more stable. On B200, Flash Attention 4 achieves:

  • FP16: 1,200 TFLOPs/s
  • FP8: 1,605 TFLOPs/s

The FP8 version is faster. It is not more accurate than FP16 (FP8 keeps fewer mantissa bits), but outlier handling keeps its error small.


Implementation: How to Use It

Hugging Face Transformers (Easiest)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load any model with Flash Attention 2+ enabled
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # One line!
    torch_dtype=torch.float16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Use it normally—no code changes needed
text = "Explain machine learning in 100 words"
inputs = tokenizer(text, return_tensors="pt").to("cuda")

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,  # temperature only takes effect when sampling
    temperature=0.7
)

print(tokenizer.decode(outputs[0]))

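That's it. One parameter, no retraining. It works with any checkpoint whose architecture supports Flash Attention 2 in Transformers; it needs the flash-attn package, an Ampere-or-newer NVIDIA GPU, and FP16/BF16 weights.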

Available Since

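  • Flash Attention: PyTorch 2.0+ via scaled_dot_product_attention (March 2023); upgraded to Flash Attention 2 in PyTorch 2.2 (January 2024)
  • Flash Attention 2: Transformers library (opt in with attn_implementation="flash_attention_2"; the default is PyTorch SDPA)
  • Flash Attention 3: Manual installation from the hopper/ directory of the flash-attention repository (pip install flash-attn installs Flash Attention 2)
  • Flash Attention 4: Available via NVIDIA libraries (May 2026)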

Checking if Your Model Uses It

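# Which attention backend the model was loaded with
print(model.config._attn_implementation)
# Output: 'flash_attention_2' (or 'sdpa' / 'eager')

# Or check the attention layer directly (Llama layout: model.model.layers)
print(type(model.model.layers[0].self_attn).__name__)
# Older Transformers versions show a FlashAttention2 class (e.g. LlamaFlashAttention2);
# newer versions use one LlamaAttention class that dispatches on config._attn_implementation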

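For Quantized Models (8-bit)

If you want to combine Flash Attention with 8-bit weight quantization (bitsandbytes INT8, which quantizes the weights, not the attention computation):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # INT8 weights, not FP8
    torch_dtype=torch.float16,  # flash_attention_2 needs FP16/BF16 activations
    device_map="auto"
)

# Model is now:
# - Same 7B parameters, stored in 8 bits: weights take ~7 GB instead of ~14 GB in FP16
# - Attention still runs through the Flash Attention 2 kernel
# - Speed and accuracy depend on model and kernel; measure them on your workload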

Real Benchmarks: Where Flash Attention Wins

All data from NVIDIA official benchmarks and published papers.

Throughput (Operations Per Second)

[Figure: Inference throughput by sequence length (TFLOPs/s at 4K, 32K, and 128K context) for Standard Attention, Flash Attention 2, Flash Attention 4 (FP16), and Flash Attention 4 (FP8)]

Hardware Utilization

Approach | Utilization | Peak B200 (dense BF16) | Achieved
Standard attention | 2-4% | ~2,250 TFLOPs/s | 52-86 TFLOPs/s
Flash Attention 2 | 23-31% | ~2,250 TFLOPs/s | 518-690 TFLOPs/s
Flash Attention 4 (FP16) | 54% | ~2,250 TFLOPs/s | 1,210 TFLOPs/s
Flash Attention 4 (FP8) | 71% | ~2,250 TFLOPs/s | 1,605 TFLOPs/s

Real-World Impact on Latency

Serving Llama 2 13B with 128K context:

Metric | Standard | FA4 (FP16) | FA4 (FP8)
Time per token | 380ms | 85ms | 62ms
Latency for 100 tokens | 38s | 8.5s | 6.2s
Speedup | 1x | 4.5x | 6.1x

Cost Impact (AWS Pricing, per 1M tokens)

Approach | Hardware Required | Input Cost | Daily Cost (10K req)
Standard (4K context) | 10 A100s | $0.003 | $30
Standard (128K context) | 50 A100s | $0.015 | $150
Flash Attention (128K) | 8 H100s | $0.012 | $120
Flash Attention 4 (128K) | 4 B200s | $0.010 | $100

Per the table, Flash Attention 4 serves 128K context for about a third less than standard attention does ($100 vs. $150 per day), though still above the 4K baseline ($30).


Common Misconceptions

Misconception 1: "Flash Attention is an approximation. My outputs will differ."

Truth: Flash Attention is exact. The online softmax algorithm produces identical outputs to standard softmax within floating-point precision.

Proof:

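import math
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# FP16 inputs shaped [batch, heads, seq_len, head_dim], as the flash backend expects
Q = torch.randn(1, 1, 1024, 128, device="cuda", dtype=torch.float16)
K = torch.randn(1, 1, 1024, 128, device="cuda", dtype=torch.float16)
V = torch.randn(1, 1, 1024, 128, device="cuda", dtype=torch.float16)

# Standard attention, computed in FP32 as a reference
scale = 1 / math.sqrt(Q.shape[-1])  # 1/sqrt(128), not 1/8
scores_standard = (Q.float() @ K.float().transpose(-2, -1)) * scale
attn_standard = F.softmax(scores_standard, dim=-1)
output_standard = attn_standard @ V.float()

# Flash attention: restrict scaled_dot_product_attention to its FlashAttention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output_flash = F.scaled_dot_product_attention(Q, K, V)

# Check equality (differences are FP16 rounding only)
print(torch.allclose(output_standard, output_flash.float(), atol=1e-3, rtol=1e-3))
# Expected: True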

Misconception 2: "You need special models or retraining to use Flash Attention."

Truth: Works with any pre-trained checkpoint. Change one parameter and you're done.

If a model was trained with standard attention, it will produce the same outputs with Flash Attention, up to floating-point rounding. The algorithm is mathematically equivalent.


Misconception 3: "Flash Attention only helps on long sequences."

Truth: It helps everywhere, but the benefit grows with sequence length.

  • 4K context: 3-5x speedup
  • 32K context: 5-10x speedup
  • 128K context: 10-50x speedup

Even on short sequences, you're saving HBM round-trips.


Misconception 4: "It only works on new GPUs (H100, B200)."

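Truth: Flash Attention 2 runs on NVIDIA GPUs from the Ampere generation onward (A100, RTX 30-series, 2020+); Flash Attention 3 targets Hopper (H100) and Flash Attention 4 targets Blackwell (B200).

  • H100/B200: Best results (71%+ utilization)
  • A100: 50-73% utilization (Flash Attention 2)
  • V100: Not supported by the flash-attn kernels (Volta, 2017); PyTorch falls back to other attention backends
  • CPU: Not supported (yet; this is a GPU-specific trick)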

When to Use Flash Attention

Always Use It If:

  • ✅ Running on an Ampere-or-newer NVIDIA GPU (A100, RTX 30/40-series, H100, B200)
  • ✅ Doing inference (any context length)
  • ✅ Fine-tuning transformers
  • ✅ Benchmarking models
  • ✅ Using transformers library (opt in with attn_implementation="flash_attention_2")

Don't Use It If:

  • ❌ Running on CPU only
  • ❌ Using custom attention mechanisms (sparse, dilated, etc.)
  • ❌ Debugging attention weights (FA never materializes them, so you can't inspect them)

For Your ML Engineering Internship:

If you're targeting inference/optimization roles: Understand the hardware-level trade-offs. Be able to explain why standard attention is memory-bound, not compute-bound. Know the difference between tiling, online softmax, and incoherent processing.

If you're targeting research roles: Read the papers. Understand the mathematical proofs of correctness. Know how backward pass recomputation works.

If you're targeting product/ML eng roles: Know when to use it, what speedups to expect, how to measure them.


Why This Matters in 2026

Flash Attention is one of the few algorithmic breakthroughs that actually moves the needle on production inference costs.

Most "innovations" are marginal (5-10% improvements). Flash Attention is 5-10x improvements with zero accuracy loss.

That's rare.

It's also a masterclass in hardware-aware algorithm design. Instead of reducing the algorithm's FLOPs, Tri Dao (the lead author) and his collaborators optimized for the GPU's memory hierarchy.

This is the future of deep learning: not bigger models, not more parameters, but algorithms that respect hardware constraints.


Next Steps

  1. Enable it immediately: Add attn_implementation="flash_attention_2" to your next project
  2. Measure the difference: Benchmark before/after on your hardware
  3. Combine with quantization: Flash Attention 4 + FP8 gives you max savings
  4. Understand the hardware: Read about memory hierarchy, cache, bandwidth
  5. Go deeper: Read the Flash Attention papers (FlashAttention, FlashAttention-2, FlashAttention-3)

The future of inference is about working with hardware constraints, not against them. Flash Attention is the best example of that principle in practice.


Bonus: Hardware Specs for Reference

NVIDIA B200 (Blackwell)

Spec | Value
Peak dense FP16/BF16 tensor compute | ~2,250 TFLOPs/s
Peak dense FP8 tensor compute | ~4,500 TFLOPs/s
HBM bandwidth | 8 TB/s
Shared memory (SRAM) per SM | 228 KB
Memory | 192 GB HBM3e

NVIDIA H100

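Spec | Value
Peak dense FP16/BF16 tensor compute | 989 TFLOPs/s (SXM)
Peak dense FP8 tensor compute | 1,979 TFLOPs/s (SXM)
HBM bandwidth | 3.35 TB/s (SXM), 2 TB/s (PCIe)
Shared memory (SRAM) per SM | 228 KB
Memory | 80 GB HBM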

RTX 4090 (Consumer GPU)

SpecValue
Peak FP32 compute163 TFLOPs/s
Peak FP8 compute326 TFLOPs/s
Memory bandwidth1 TB/s
VRAM24 GB

Flash Attention helps on all of these, but the gains are largest on H100/B200 due to larger gap between compute and memory bandwidth.


Published: May 21, 2026 | Last updated: May 21, 2026

This post is a deep dive on infrastructure. If you're building LLM systems in production, understanding Flash Attention is critical. It's not optional—it's the standard.