Flash Attention
Much progress in AI over the past few years has been fueled by the transformer architecture. Transformers are the closest thing we have right now to machine learnable programs. They can be trained to generate images, text, videos, audio, video games, or even raw byte sequences, you name it.
Two key operations power the transformer behind many of these applications, and together they account for about 99% of the FLOPs: attention and feed-forward layers. Both are conceptually very simple, although very compute-intensive. In this post, my goal is to expand on the attention operation and how to implement it efficiently using an algorithm called Flash Attention.
Attention layers are responsible for mixing information between words (or tokens). Without attention, each token would know nothing about the other tokens in the sequence. The input to an attention layer is a set of queries $Q \in \mathbb{R}^{N \times d}$, keys $K \in \mathbb{R}^{N \times d}$, and values $V \in \mathbb{R}^{N \times d}$, where $N$ is the number of tokens in the sequence and $d$ is the dimension of each token. The output is a set of attended values $O \in \mathbb{R}^{N \times d}$.
The attention operation is defined as $O = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right) V$, with the softmax applied to each row. Each output row is a convex combination of the rows of $V$, with weights given by the softmax of the scaled dot products between queries and keys, also called the attention scores.
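As a concrete reference point, here is a minimal NumPy sketch of $\mathrm{softmax}(QK^\top/\sqrt{d})\,V$ for a single head with 2-D inputs of shape `[N, d]`. The function name `naive_attention` is ours, not from any library. It materialises the full weight matrix, and the checks confirm the convex-combination property: every row of weights is non-negative and sums to 1, so each output coordinate lies between the smallest and largest value in that column of $V$.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Direct translation of softmax(Q K^T / sqrt(d)) V for 2-D inputs [N, d]."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                  # attention scores, [N, N]
    S = S - S.max(axis=-1, keepdims=True)     # shift each row for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)     # attention weights, each row sums to 1
    return P @ V, P


rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, P = naive_attention(Q, K, V)

# Convex combination: weights are non-negative and each row sums to 1 ...
print(np.allclose(P.sum(axis=-1), 1.0), (P >= 0).all())
# ... so each output coordinate stays within the range of its value column.
tol = 1e-12
print(((O >= V.min(axis=0) - tol) & (O <= V.max(axis=0) + tol)).all())
```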
The memory bottleneck
While the naive implementation of this formula is straightforward, it is very memory inefficient. To understand why, we need to know a bit about GPU memory.
Modern GPUs have two levels of memory that matter here:
- HBM (High Bandwidth Memory): Large (e.g. 80 GB on an A100) but relatively slow to access.
- SRAM (on-chip memory): Tiny (e.g. 20 MB) but extremely fast — about 10x higher bandwidth than HBM.
A naive implementation of attention involves three separate operations:
- Matrix multiply: Compute the attention scores $S = QK^\top / \sqrt{d}$
- Softmax: Compute the attention weights $P = \mathrm{softmax}(S)$
- Matrix multiply: Compute the output $O = PV$
Each of these operations reads its inputs from HBM, computes in SRAM, and writes the result back to HBM. This means the intermediate matrices $S$ and $P$, both of size $N \times N$, get written to HBM and read back for the next step. For long sequences, this memory traffic, not the actual computation, becomes the bottleneck.
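To get a feel for the sizes involved, this back-of-the-envelope sketch computes how many bytes one $N \times N$ intermediate (such as $S$ or $P$) occupies for a sequence of $N$ tokens and a single attention head, assuming 4-byte (fp32) entries and a head dimension of 64. The helper names `score_matrix_bytes` and `qkv_bytes` are hypothetical. Doubling $N$ quadruples the intermediate, while $Q$, $K$ and $V$ only grow linearly.

```python
def score_matrix_bytes(n_tokens: int, bytes_per_element: int = 4) -> int:
    """Bytes for one N x N matrix (e.g. S or P) for one head of one sequence."""
    return n_tokens * n_tokens * bytes_per_element


def qkv_bytes(n_tokens: int, head_dim: int, bytes_per_element: int = 4) -> int:
    """Bytes for one N x d matrix (Q, K, V or O) for one head."""
    return n_tokens * head_dim * bytes_per_element


for n in (1024, 4096, 16384):
    s_mib = score_matrix_bytes(n) / 2**20
    q_mib = qkv_bytes(n, head_dim=64) / 2**20
    print(f"N={n:>6}: S is {s_mib:8.1f} MiB, Q is {q_mib:5.2f} MiB")
```

For $N = 4096$ this gives 64 MiB for $S$ alone versus 1 MiB for $Q$, already far more than fits in on-chip SRAM, so $S$ and $P$ have to make the round trip through HBM.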
The Flash Attention idea
Flash Attention solves this by fusing all three operations into a single kernel that processes small blocks at a time. Instead of computing the full $N \times N$ matrix $S$ and writing it to HBM, we:
- Split $Q$ into blocks of $B_q$ rows ($Q_i$) and $K$, $V$ into blocks of $B_{kv}$ rows ($K_j$, $V_j$).
- For each query block $Q_i$, iterate over all key-value blocks $K_j$, $V_j$.
- Compute the attention scores, softmax, and value-weighted sum for just that block pair — entirely in SRAM.
- Accumulate partial results and write only the final output back to HBM.
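The matrix-multiply part of this plan tiles without any extra bookkeeping: each tile of scores only needs one block of query rows and one block of key rows. The sketch below (2-D inputs; `blocked_scores` is a hypothetical helper) computes the scaled scores one tile of `blk_q` query rows by `blk_kv` key rows at a time and checks that the reassembled tiles equal the full matrix. A real kernel would consume each tile immediately instead of storing it; the softmax is the step that couples tiles along a row.

```python
import numpy as np


def blocked_scores(Q, K, blk_q=64, blk_kv=64):
    """Compute Q K^T / sqrt(d) tile by tile (tiles are stored here only to check them)."""
    n_q, d = Q.shape
    n_kv = K.shape[0]
    S = np.empty((n_q, n_kv))
    for i in range(0, n_q, blk_q):
        Q_i = Q[i : i + blk_q] / np.sqrt(d)
        for j in range(0, n_kv, blk_kv):
            K_j = K[j : j + blk_kv]
            # One tile depends only on the rows in Q_i and K_j.
            S[i : i + blk_q, j : j + blk_kv] = Q_i @ K_j.T
    return S


rng = np.random.default_rng(0)
Q = rng.standard_normal((200, 32))   # 200 and 150 are deliberately not multiples of 64
K = rng.standard_normal((150, 32))
S_full = Q @ K.T / np.sqrt(32)
print(np.allclose(blocked_scores(Q, K), S_full))
```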
The tricky part is the softmax. Unlike the matrix multiplications, a numerically stable softmax requires knowing both the maximum and the sum across an entire row of the score matrix $S$. Flash Attention handles this with an online softmax algorithm that incrementally updates these statistics as we process each block.
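Stripped of the matrix multiplies, the online softmax idea fits in a few lines. This sketch (`online_softmax_stats` is our own helper, not a library function) streams one row of scores in chunks, keeping the maximum over everything seen so far and the sum of exponentials relative to it; whenever the maximum grows, the old sum is rescaled by $e^{m_\text{old} - m_\text{new}}$. The final statistics match those computed from the whole row at once.

```python
import numpy as np


def online_softmax_stats(row, chunk_size):
    """Return (max, sum of exp(row - max)) computed one chunk at a time."""
    m = -np.inf   # maximum seen so far
    l = 0.0       # sum of exp(x - m) over everything seen so far
    for start in range(0, len(row), chunk_size):
        chunk = row[start : start + chunk_size]
        m_new = max(m, chunk.max())
        # Rescale the old sum to the new reference point, then add this chunk.
        l = l * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    return m, l


rng = np.random.default_rng(0)
row = rng.standard_normal(1000) * 5
m, l = online_softmax_stats(row, chunk_size=64)
m_ref = row.max()
l_ref = np.exp(row - m_ref).sum()
print(np.isclose(m, m_ref), np.isclose(l, l_ref))

# With both statistics in hand, softmax(row) = exp(row - m) / l.
probs = np.exp(row - m) / l
print(np.isclose(probs.sum(), 1.0))
```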
Let me walk through a simplified Python implementation step by step.
Step 1: The outer structure
```python
import numpy as np


def flash_attention(Q, K, V):
    # Q: [batch, seq_out, head_dim]
    # K: [batch, seq_in, head_dim]
    # V: [batch, seq_in, head_dim]
    O = np.zeros(Q.shape[:-2] + (Q.shape[-2], V.shape[-1]))
    batch_size, query_seq_len, query_dim = Q.shape
    softmax_scale = np.sqrt(query_dim)
    blk_q = 64
    blk_kv = 64
```
We allocate the output and define block sizes. On a real GPU, these block sizes would be chosen to maximally fill the available SRAM. Here we use 64 for both query and key-value blocks.
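To see what "fill the available SRAM" means in practice, the rough sketch below adds up the on-chip working set of one inner-loop step: the tiles `Q_i`, `K_j`, `V_j`, the score tile `S_ij` (reused for `P_ij`), the output accumulator `O_i`, and the per-row statistics `l` and `m`. It is a simplified model (hypothetical helper `block_working_set_bytes`, fp32 elements, no double buffering or register-level detail), not how a real kernel budgets its memory. On a GPU, each thread block works out of the SRAM of a single streaming multiprocessor, which is only a small slice of the ~20 MB total, so budgets like this one limit the block sizes.

```python
def block_working_set_bytes(blk_q, blk_kv, head_dim, bytes_per_element=4):
    """Rough bytes needed on-chip for one (Q_i, K_j, V_j) step of the inner loop."""
    q_tile = blk_q * head_dim           # Q_i
    kv_tiles = 2 * blk_kv * head_dim    # K_j and V_j
    scores = blk_q * blk_kv             # S_ij, overwritten in place by P_ij
    out_acc = blk_q * head_dim          # O_i accumulator
    stats = 2 * blk_q                   # l and m, one value per query row
    return (q_tile + kv_tiles + scores + out_acc + stats) * bytes_per_element


for blk in (32, 64, 128):
    kib = block_working_set_bytes(blk, blk, head_dim=64) / 1024
    print(f"blk_q = blk_kv = {blk:>3}: ~{kib:.1f} KiB")
```

Because the score tile grows with `blk_q * blk_kv` while the other tiles grow linearly, doubling both block sizes from 64 to 128 (with a head dimension of 64) takes this estimate from about 80.5 KiB to about 193 KiB.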
Step 2: Iterating over query blocks
```python
    for i in range(0, query_seq_len, blk_q):
        Q_i = Q[:, i : i + blk_q, :] / softmax_scale
        # Size the accumulators from Q_i: the last block may have fewer than blk_q rows
        O_i = np.zeros((batch_size, Q_i.shape[1], V.shape[-1]))
        l = np.zeros((batch_size, Q_i.shape[1], 1))
        m = None
```
The outer loop iterates over query blocks. For each query block $Q_i$, we initialise three accumulators:
- `O_i`: the partial output for this query block (accumulates weighted values)
- `l`: the running sum of exponentiated scores (the softmax denominator)
- `m`: the maximum score used as the reference point for numerical stability (in this implementation, the maximum of the most recently processed key-value block)

We also load $Q_i$ from HBM and pre-divide it by $\sqrt{d}$ so we don't have to repeat this scaling in the inner loop.
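Pre-dividing works because, with head dimension $d$, the scaling commutes with the matrix product: $(Q_i / \sqrt{d}) K_j^\top = (Q_i K_j^\top) / \sqrt{d}$. The quick check below confirms this numerically and counts divisions under a simple model: scaling `Q_i` once costs `blk_q * d` divisions, whereas scaling every score tile would cost `blk_q * blk_kv` divisions per key-value block, or `blk_q * N` per query block. The sizes are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
blk_q, blk_kv, d, n_kv = 64, 64, 32, 4096
Q_i = rng.standard_normal((blk_q, d))
K = rng.standard_normal((n_kv, d))

# Scaling before the matmul (as in the walkthrough) vs. after it.
pre_scaled = (Q_i / np.sqrt(d)) @ K.T
post_scaled = (Q_i @ K.T) / np.sqrt(d)
print(np.allclose(pre_scaled, post_scaled))

# Divisions per query block under each approach.
divs_pre = blk_q * d                            # scale Q_i once
divs_post = (n_kv // blk_kv) * blk_q * blk_kv   # scale every S_ij tile
print(divs_pre, divs_post)
```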
Step 3: Iterating over key-value blocks
```python
        for j in range(0, K.shape[-2], blk_kv):
            K_j = K[:, j : j + blk_kv, :]
            V_j = V[:, j : j + blk_kv, :]
            S_ij = Q_i @ np.swapaxes(K_j, -1, -2)
```
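The inner loop iterates over key-value blocks. For each block $j$, we load $K_j$ and $V_j$ from HBM into SRAM and compute the block of attention scores $S_{ij} = Q_i K_j^\top$ (the $1/\sqrt{d}$ scaling is already folded into $Q_i$).
This is a small matrix, $64 \times 64$ in our case, that fits comfortably in SRAM.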
Step 4: The online softmax trick
This is the heart of Flash Attention. The challenge: for a row $s$ of scores, softmax is defined as
$$\mathrm{softmax}(s)_k = \frac{e^{s_k}}{\sum_{k'} e^{s_{k'}}}$$
The denominator requires summing over all keys, but we only have the current block. The solution is to keep a running maximum and sum, and rescale previous accumulations to the current block’s maximum.
```python
            m_ij = np.max(S_ij, axis=2, keepdims=True)
            P_ij = np.exp(S_ij - m_ij)
            w_ij = np.exp(m - m_ij) if j > 0 else 0.0
            m = m_ij
```
Here’s what’s happening:
- `m_ij` is the maximum score in the current block. We subtract it before exponentiating for numerical stability.
- `P_ij = exp(S_ij - m_ij)` gives us the unnormalised attention weights for this block.
- `w_ij = exp(m_prev - m_ij)` is the correction factor, where `m_prev` is the value of `m` before it is overwritten with `m_ij`. If the new block has a larger maximum than the previous block, this factor is less than 1, scaling down all previous accumulations to be consistent with the new maximum. If the previous maximum was larger, we scale up instead. Either way, all blocks end up normalised relative to a common reference point. For the first block there is nothing to rescale yet, so `w_ij` is simply 0.
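This rescaling is exact in exact arithmetic, but because `m` is reset to the latest block's maximum, `w_ij` can be far larger than 1 when an earlier block had much larger scores, and it can then overflow. The FlashAttention paper instead keeps a running maximum, `m_new = max(m_prev, m_ij)`, so every correction factor is at most 1. The sketch below works on a single query row split into chunks (the function `accumulate` and its `running_max` flag are ours) and compares the two choices, first on deliberately extreme scores and then on ordinary random ones.

```python
import numpy as np


def accumulate(scores, values, chunk, running_max):
    """Online softmax-weighted sum of the rows of `values` for one query row."""
    m, l, o = None, 0.0, np.zeros(values.shape[1])
    for j in range(0, len(scores), chunk):
        s, v = scores[j : j + chunk], values[j : j + chunk]
        # Reference point: this chunk's max (as in the walkthrough) or the running max.
        m_new = s.max() if m is None or not running_max else max(m, s.max())
        w = 0.0 if m is None else np.exp(m - m_new)  # correction for old sums
        p = np.exp(s - m_new)
        l = w * l + p.sum()
        o = w * o + p @ v
        m = m_new
    return o / l


scores = np.array([1000.0, 999.0, 0.0, 0.0])   # the first chunk dominates
values = np.array([[1.0], [2.0], [3.0], [4.0]])

with np.errstate(over="ignore", invalid="ignore"):
    last_block = accumulate(scores, values, chunk=2, running_max=False)  # exp(1000) overflows
running = accumulate(scores, values, chunk=2, running_max=True)
e = np.exp(scores - scores.max())
exact = e @ values / e.sum()
print(last_block, running, exact)

# On well-behaved scores both reference choices agree.
benign = np.random.default_rng(0).standard_normal(10)
vals = np.random.default_rng(1).standard_normal((10, 3))
print(np.allclose(accumulate(benign, vals, 3, False), accumulate(benign, vals, 3, True)))
```

The difference only shows up when score gaps approach the overflow limit of the floating-point type (about 709 for float64 and about 88 for float32).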
Step 5: Accumulating results
```python
            l = w_ij * l + np.sum(P_ij, axis=2, keepdims=True)
            O_i = w_ij * O_i + P_ij @ V_j
```
Both the sum `l` and the output `O_i` get the same treatment:
- Rescale the previous accumulation by `w_ij` (correcting for the updated maximum).
- Add the current block's contribution.

After processing all blocks, for a query row with scores $s_k$ against each key $k$, `l` contains $\sum_k e^{s_k - m}$ and `O_i` contains $\sum_k e^{s_k - m} v_k$, where $v_k$ is the $k$-th row of $V$ and $m$ is the maximum from the last block processed.
Step 6: Final normalisation
```python
        O[:, i : i + blk_q, :] = O_i / l
    return O
```
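Finally, we divide by the accumulated sum to get the properly normalised output. For a query row with scores $s_k$ and value rows $v_k$:
$$o = \frac{\sum_k e^{s_k - m} v_k}{\sum_k e^{s_k - m}} = \frac{e^{-m} \sum_k e^{s_k} v_k}{e^{-m} \sum_k e^{s_k}} = \sum_k \mathrm{softmax}(s)_k \, v_k$$
The $e^{-m}$ factors cancel between numerator and denominator, so we get the exact same result as standard attention, just computed block by block without ever materialising the full $N \times N$ matrix.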
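The cancellation means the choice of reference point does not affect the exact result; it only matters for keeping the exponentials in range. A quick numerical check on one query row (all names are ours): normalising relative to several different shifts $c$ gives the same output each time.

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(16) * 3      # scores of one query against 16 keys
v = rng.standard_normal((16, 8))     # the 16 value rows

outputs = []
for c in (0.0, s.max(), s[-1], -5.0):
    e = np.exp(s - c)                # weights relative to the reference point c
    outputs.append(e @ v / e.sum())  # the common factor exp(-c) cancels here

print(all(np.allclose(outputs[0], o) for o in outputs[1:]))
```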
Putting it all together
Here’s the complete implementation:
```python
import numpy as np


def flash_attention(Q, K, V):
    # K, V should have shape [batch, sequence_in, head_dim]
    # Q should have shape [batch, sequence_out, head_dim]
    O = np.zeros(Q.shape[:-2] + (Q.shape[-2], V.shape[-1]))
    batch_size, query_seq_len, query_dim = Q.shape
    softmax_scale = np.sqrt(query_dim)
    blk_q = 64
    blk_kv = 64
    for i in range(0, query_seq_len, blk_q):
        # Load Q_i from HBM memory; the outputs for this query block
        # are computed in the inner loop below
        Q_i = Q[:, i : i + blk_q, :] / softmax_scale
        # Allocate intermediate results, sized from Q_i so that a final
        # block with fewer than blk_q rows also works
        O_i = np.zeros((batch_size, Q_i.shape[1], V.shape[-1]))
        l = np.zeros((batch_size, Q_i.shape[1], 1))
        m = None
        for j in range(0, K.shape[-2], blk_kv):
            # Load K_j and V_j from HBM memory
            # [batch, blk_kv, head_dim]
            K_j = K[:, j : j + blk_kv, :]
            V_j = V[:, j : j + blk_kv, :]
            # Attention scores for block Q_i,K_j
            # [batch, blk_q, blk_kv]
            S_ij = Q_i @ np.swapaxes(K_j, -1, -2)
            # Softmax over current block but don't normalize yet
            m_ij = np.max(S_ij, axis=2, keepdims=True)
            P_ij = np.exp(S_ij - m_ij)
            # Carry forward factor w_ij (0 for the first block: nothing to rescale)
            # [batch, blk_q, 1]
            w_ij = np.exp(m - m_ij) if j > 0 else 0.0
            m = m_ij
            # Output normalization factor l
            # [batch, blk_q, 1]
            l = w_ij * l + np.sum(P_ij, axis=2, keepdims=True)
            # Outputs for block Q_i,K_j
            # [batch, blk_q, head_dim]
            O_i = w_ij * O_i + P_ij @ V_j
        O[:, i : i + blk_q, :] = O_i / l

    return O
```
Verification
Let’s check that the results match other implementations:
```python
import numpy as np
import torch


def ref_attention_1(Q, K, V):
    # Direct, non-blocked formula. It skips the max subtraction, which is
    # fine here because the uniform inputs keep the scores small.
    S = (Q @ np.swapaxes(K, -1, -2)) / np.sqrt(Q.shape[-1])
    S = np.exp(S)
    S = S / np.sum(S, -1)[..., None]
    return S @ V


def ref_attention_2(Q, K, V):
    # PyTorch's built-in attention as an independent reference
    return torch.nn.functional.scaled_dot_product_attention(
        query=torch.from_numpy(Q),
        key=torch.from_numpy(K),
        value=torch.from_numpy(V),
    ).numpy()


Q = np.random.uniform(size=(4, 4096, 32))
K = np.random.uniform(size=(4, 4096, 32))
V = np.random.uniform(size=(4, 4096, 32))

ours = flash_attention(Q, K, V)
np.testing.assert_allclose(ref_attention_1(Q, K, V), ours)
np.testing.assert_allclose(ref_attention_2(Q, K, V), ours)
```
Both assertions pass — our block-by-block implementation produces the same results as computing the full attention matrix, up to floating point precision.
I hope this walkthrough of Flash Attention was helpful! If you have any questions, feel free to reach out to me on Twitter.