Multi-head attention variants step-by-step in PyTorch
Self-attention is an operation in Transformer neural networks that allows tokens to incorporate information from other relevant tokens. Multi-head attention (MHA) extends this idea to multiple "heads" of attention, so that multiple kinds of relevance can be incorporated. Through three separate linear projections, every input token can determine:
- what kind of information it's looking for (through a "query" projection),
- what kind of information it provides (through a "key" projection), and
- the actual information it provides (through a "value" projection), as in the sketch below.
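These three projections are typically just linear layers applied to each token's hidden state. Here is a minimal sketch, where the `hidden_size` value and tensor shapes are illustrative choices rather than anything from a specific model:

import torch
import torch.nn as nn

hidden_size = 768  # illustrative hidden state size

# one projection each for queries, keys, and values
q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

tokens = torch.randn(2, 16, hidden_size)  # (batch_size, sequence_len, hidden_size)
q, k, v = q_proj(tokens), k_proj(tokens), v_proj(tokens)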
Transformers can be sped up by performing fewer operations through multi-query attention (MQA), where only a single key/value projection is used across all heads (Shazeer, 2019). While this variant speeds up inference, it sacrifices model performance. Looking for an in-between that improves inference speed while maintaining model performance, Ainslie et al. proposed grouped query attention (GQA), where an intermediate number of key/value projections is used.
The goal of this post is to re-implement all three variants of attention and compare our implementation to PyTorch's `torch.nn.functional.scaled_dot_product_attention()`. But first, a summary of the attention variants one more time, assuming `Nh` attention heads and `G` key/value groups:
Method | Query Heads | Key Heads | Value Heads | Quality and Speed |
---|---|---|---|---|
MHA | `Nh` | `Nh` | `Nh` | High Quality / Slowest |
GQA | `Nh` | `G` | `G` | Balanced |
MQA | `Nh` | 1 | 1 | Lower Quality / Fastest |
Step 1: bidirectional attention
Vaswani et al. defined the self-attention operation as

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $Q$ is a matrix whose rows are queries, $K$ is a matrix of keys, $V$ is a matrix of values, and $d_k$ is the dimension of the keys. In this expression, all queries can "see" all keys because we dot product every query with every key. In this way, the attention is bidirectional. One can implement this in Python as
import math
import torch
import torch.nn.functional as F

# assume q, k, v have shape (batch_size, num_heads, sequence_len, head_dim)
scaling = math.sqrt(k.size(-1))
attn = (q @ k.transpose(-1, -2)) / scaling
attn = F.softmax(attn, dim=-1) @ v
To check our implementation, we will be using the following function which compares two tensors:
def compare(t1, t2):
    print("Equal ✅" if torch.allclose(t1, t2, atol=1e-5) else "Not equal ❌")
and we'll also generate some random keys, queries and values
B = 2 # batch size
S = 256 # sequence length
E = 768 # head embedding dimension
Nh = 12 # num heads
query = torch.randn(B, Nh, S, E)
key = torch.randn(B, Nh, S, E)
value = torch.randn(B, Nh, S, E)
Test
We can check what we have so far:
def my_sdpa(q, k, v):
    scaling = math.sqrt(k.size(-1))
    attn = (q @ k.transpose(-1, -2)) / scaling
    attn = F.softmax(attn, dim=-1) @ v
    return attn
compare(my_sdpa(query, key, value),
F.scaled_dot_product_attention(query, key, value))
# Prints: Equal ✅
Step 2: causal attention
To implement the causal attention used in most LLMs, we need to prevent the query vector at position $i$ from "seeing" keys that occur after position $i$. To do this, we use an attention mask like the following:
# matrix where values of 1 participate in attention, 0 does not
attn_mask = torch.tril(torch.ones(S, S))
# tensor([[1., 0., 0., ..., 0., 0., 0.],
# [1., 1., 0., ..., 0., 0., 0.],
# [1., 1., 1., ..., 0., 0., 0.],
# ...,
# [1., 1., 1., ..., 1., 0., 0.],
# [1., 1., 1., ..., 1., 1., 0.],
# [1., 1., 1., ..., 1., 1., 1.]])
where a `1` represents an attention score (dot product) that we will keep and a `0` represents an attention score that we will toss out. Row $i$ represents the query position and column $j$ represents the key position, so this mask allows query $i$ to dot product with keys at positions $j \le i$. To actually "keep" the scores where a `1` is present and "toss" the scores where a `0` is present, we rely on the softmax operation: we replace every masked-out score with `float('-inf')` so that, after softmax, those positions receive zero weight:
def my_sdpa(q, k, v, attn_mask):
    scaling = math.sqrt(k.size(-1))
    attn = (q @ k.transpose(-1, -2)) / scaling
    # mask out scores that should not be used with -inf
    attn.masked_fill_(attn_mask == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1) @ v
    return attn
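As a quick standalone sanity check (not part of the original code above), we can confirm that a score of negative infinity receives zero weight after softmax:

scores = torch.tensor([1.0, 2.0, float('-inf')])
print(F.softmax(scores, dim=-1))
# Expected: tensor([0.2689, 0.7311, 0.0000])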
Test
compare(my_sdpa(query, key, value, attn_mask=attn_mask),
F.scaled_dot_product_attention(query, key, value, attn_mask=(attn_mask == 1)))
# Prints: Equal ✅
Note that `F.scaled_dot_product_attention`'s `attn_mask` argument expects a boolean mask, where a `True` value indicates a dot product that we "keep" to participate in attention.
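For the causal case specifically, `F.scaled_dot_product_attention` also accepts an `is_causal=True` flag that applies the same lower-triangular mask internally, so the following check (a quick aside, using the tensors defined above) should also report equality:

compare(my_sdpa(query, key, value, attn_mask=attn_mask),
        F.scaled_dot_product_attention(query, key, value, is_causal=True))
# Expected: Equal ✅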
Step 3: grouped query attention
In grouped query attention, we divide the `Nh` query heads into `G` groups, each of which shares a single key/value head. For example, suppose we have 4 query heads and 2 key/value groups; we assign query heads to groups in the following way:
Q0 Q1 Q2 Q3
K0 K0 K1 K1
where queries `Q0` and `Q1` use the key head `K0`. In our case, we defined `Nh = 12` heads and, for the purposes of testing, we will define `G = 3` groups and redefine our keys and values:
G = Nh // 4 # number of key/value groups is 12 // 4 = 3
key = torch.randn(B, G, S, E)
value = torch.randn(B, G, S, E)
Now, the main change we need to make to `my_sdpa` is to repeat our keys and values along `dim=1`, so that the first 4 query heads (making up the first "group") use key/value head 0, the next 4 query heads use key/value head 1, and so on. We can do this using `torch.repeat_interleave`:
x = torch.tensor([1,2,3])
print(x) # Prints: tensor([1, 2, 3])
print(x.repeat_interleave(2, dim=0)) # Prints: tensor([1, 1, 2, 2, 3, 3])
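Applied to our `key` tensor of shape `(B, G, S, E)`, this same operation expands the `G` key/value heads back out to `Nh` heads, so that each group of query heads gets its own copy. A quick shape check using the tensors defined above:

print(key.shape)                                    # torch.Size([2, 3, 256, 768])
print(key.repeat_interleave(Nh // G, dim=1).shape)  # torch.Size([2, 12, 256, 768])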
Let's make the change:
def my_sdpa(q, k, v, attn_mask, enable_gqa=False):
    if enable_gqa:
        _, Nh, _, _ = q.shape
        _, G, _, _ = k.shape
        assert Nh % G == 0, "number of heads in key and value must divide the number of heads in query"
        # repeat each key/value head so that every group of query heads gets its own copy
        repeat = Nh // G
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)
    scaling = math.sqrt(k.size(-1))
    attn = (q @ k.transpose(-1, -2)) / scaling
    # mask out scores that should not be used with -inf
    attn.masked_fill_(attn_mask == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1) @ v
    return attn
Test
PyTorch's `F.scaled_dot_product_attention` supports GQA with `enable_gqa=True`, so we can check our implementation easily:
compare(my_sdpa(query, key, value, attn_mask=attn_mask, enable_gqa=True),
F.scaled_dot_product_attention(query, key, value, attn_mask=(attn_mask == 1), enable_gqa=True))
# Prints: Equal ✅
Step 4: multi-query attention
Multi-query attention is a special case of GQA where there is only 1 key/value group. So the function we wrote above will work for this case!
# multi-query attention
G = 1
key = torch.randn(B, G, S, E)
value = torch.randn(B, G, S, E)
my_out = my_sdpa(query, key, value, attn_mask=attn_mask, enable_gqa=True)
compare(my_out,
        F.scaled_dot_product_attention(query, key, value, attn_mask=(attn_mask == 1), enable_gqa=True))
# Prints: Equal ✅
Summary
Our `my_sdpa` function reimplements `F.scaled_dot_product_attention` and allows us to better understand the three popular attention variants used in LLMs today. It's important to note that the shapes of our randomly generated query, key, and value tensors are just a reshape of the hidden states of shape `(batch_size, sequence_len, hidden_size)` that the Transformer operates on, where `hidden_size = num_heads * head_dim`. So, whereas a Transformer using MHA would have a key projection that looks like:
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
a Transformer using GQA with `G` groups would have:
k_proj = nn.Linear(hidden_size, G * head_dim, bias=False)
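To make the end-to-end picture concrete, here is a minimal sketch of how the projections, the reshape into heads, and the `my_sdpa` function from this post might fit together in a GQA attention layer. The module name and constructor arguments are illustrative choices, and the output projection that real implementations apply afterwards is omitted:

import torch
import torch.nn as nn

class GQAAttention(nn.Module):
    # hypothetical module: GQA projections + reshape + the my_sdpa function above
    def __init__(self, hidden_size, num_heads, num_groups, head_dim):
        super().__init__()
        self.num_heads, self.num_groups, self.head_dim = num_heads, num_groups, head_dim
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_groups * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_groups * head_dim, bias=False)

    def forward(self, hidden_states, attn_mask):
        B, S, _ = hidden_states.shape
        # project, then split the last dimension into heads: (B, S, n * head_dim) -> (B, n, S, head_dim)
        q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, S, self.num_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, S, self.num_groups, self.head_dim).transpose(1, 2)
        out = my_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=True)
        # merge heads back: (B, num_heads, S, head_dim) -> (B, S, num_heads * head_dim)
        return out.transpose(1, 2).reshape(B, S, self.num_heads * self.head_dim)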
If this was helpful, let me know on X @psandovalsegura!
@article{sandovalsegura2025multihead,
title = "Multi-head attention variants step-by-step in PyTorch.",
author = "Sandoval-Segura, Pedro",
journal = "psando.bearblog.dev",
year = "2025",
month = "May",
url = "https://psando.bearblog.dev/multi-head-attention-variants-step-by-step-in-pytorch/"
}