
Multi-head attention variants step-by-step in PyTorch

Self-attention is an operation in Transformer neural networks that allows tokens to incorporate information from other relevant tokens. Multi-head attention (MHA) extends this idea to multiple "heads" of attention, so that multiple kinds of relevance can be incorporated. Through three separate linear projections, every input token can determine what it is looking for (its query), what it offers to be matched against (its key), and what information it passes along when attended to (its value).

Transformers can be sped up by performing fewer ops through multi-query attention (MQA), where only a single key/value projection is shared across all heads (Shazeer, 2019). While this variant speeds up inference, it sacrifices model performance. Looking for an in-between that improves inference speed while maintaining model quality, Ainslie et al. (2023) proposed grouped query attention (GQA), where an intermediate number of key/value projections is used.

The goal of this post is to re-implement all three variants of attention and compare our implementation to PyTorch's torch.nn.functional.scaled_dot_product_attention(). But first, a summary of the three variants, assuming N attention heads and G key/value groups:

Method   Query Heads   Key Heads   Value Heads   Quality and Speed
MHA      N             N           N             High Quality / Slowest
GQA      N             G           G             Balanced
MQA      N             1           1             Lower Quality / Fastest
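
Concretely, the only thing that changes across the variants is the number of key/value heads. Writing B for batch size, S for sequence length, and hd for head dimension, the tensor shapes we will work with look like this:

# query is always (B, N, S, hd)
# MHA: keys/values are (B, N, S, hd) -- one key/value head per query head
# GQA: keys/values are (B, G, S, hd) -- one key/value head per group of query heads
# MQA: keys/values are (B, 1, S, hd) -- a single key/value head shared by all query heads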

Step 1: bidirectional attention

Vaswani et al. defined the self-attention operation as

\mathbf{A} = \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}}\right)\mathbf{V}

where Q is a matrix whose rows are queries, K is a matrix of keys, V is a matrix of values, and d_k is the dimension of the keys. In this expression, all queries can "see" all keys because we dot every query with every key. In this way, the attention is bidirectional. One can implement this in Python as

import math
import torch
import torch.nn.functional as F

# assume q,k,v have shape (batch_size, num_heads, sequence_len, head_dim)
scaling = math.sqrt(k.size(-1))
attn = (q @ k.transpose(-1, -2)) / scaling
attn = F.softmax(attn, dim=-1) @ v

To check our implementation, we will be using the following function which compares two tensors:

def compare(t1, t2):
    print("Equal ✅" if torch.allclose(t1, t2, atol=1e-5) else "Not equal ❌")

and we'll also generate some random keys, queries and values

B = 2   # batch size
S = 256 # sequence length
E = 768 # head embedding dimension
Nh = 12 # num heads

query = torch.randn(B, Nh, S, E)
key   = torch.randn(B, Nh, S, E)
value = torch.randn(B, Nh, S, E)

Test

We can check what we have so far:

def my_sdpa(q, k, v):
    scaling = math.sqrt(k.size(-1))
    attn = (q @ k.transpose(-1,-2)) / scaling 
    attn = F.softmax(attn, dim=-1) @ v
    return attn

compare(my_sdpa(query, key, value), 
        F.scaled_dot_product_attention(query, key, value)) 
# Prints: Equal ✅

Step 2: causal attention

To implement the causal attention used in most LLMs, we need to prevent the query vector at position i from "seeing" keys that occur after position i. To do this, we use an attention mask like the following

# matrix where values of 1 participate in attention, 0 does not
attn_mask = torch.tril(torch.ones(S, S))
# tensor([[1., 0., 0.,  ..., 0., 0., 0.],
#         [1., 1., 0.,  ..., 0., 0., 0.],
#         [1., 1., 1.,  ..., 0., 0., 0.],
#         ...,
#         [1., 1., 1.,  ..., 1., 0., 0.],
#         [1., 1., 1.,  ..., 1., 1., 0.],
#         [1., 1., 1.,  ..., 1., 1., 1.]])

where a 1 represents an attention score (dot product) that we will keep and a 0 represents an attention score that we will toss out. Row i represents the query position and column j represents the key position, so this mask allows query i to dot product only with keys at positions j ≤ i. To actually "toss" the scores where a 0 is present, we replace them with float('-inf') before the softmax, so that after the softmax those positions receive exactly zero weight.
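
For intuition, here is a toy example (values chosen arbitrarily) showing that a -inf score receives zero weight after the softmax:

scores = torch.tensor([1.0, 2.0, float('-inf')])
print(F.softmax(scores, dim=-1)) # Prints: tensor([0.2689, 0.7311, 0.0000])

With that in mind, the causal version of my_sdpa becomes: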

def my_sdpa(q, k, v, attn_mask):
    scaling = math.sqrt(k.size(-1))
    attn = (q @ k.transpose(-1,-2)) / scaling 
    # mask out scores that should not be used with -inf
    attn.masked_fill_(attn_mask == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1) @ v
    return attn

Test

compare(my_sdpa(query, key, value, attn_mask=attn_mask),
        F.scaled_dot_product_attention(query, key, value, attn_mask=(attn_mask == 1)))
# Prints: Equal ✅

Note that F.scaled_dot_product_attention's attn_mask argument expects a boolean mask where a True value indicates a dot product that we "keep" to participate in attention.
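
F.scaled_dot_product_attention also accepts a float mask, which is added to the attention scores before the softmax, so filling the disallowed positions with float('-inf') reproduces our masking. A quick extra check (my own, reusing the tensors above):

# additive float mask: 0 where attention is allowed, -inf where it is not
float_mask = torch.zeros(S, S).masked_fill(attn_mask == 0, float('-inf'))
compare(my_sdpa(query, key, value, attn_mask=attn_mask),
        F.scaled_dot_product_attention(query, key, value, attn_mask=float_mask))
# Prints: Equal ✅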

Step 3: grouped query attention

In grouped query attention, we divide the Nh query heads into groups, each of which shares a single key/value head. For example, if we had 4 query heads and 2 key/value groups, we would assign query heads to groups in the following way:

Q0 Q1 Q2 Q3
K0 K0 K1 K1

where query heads Q0 and Q1 share key head K0 (and, likewise, value head V0). In our case, we defined Nh = 12 heads and, for the purposes of testing, we will use G = 3 groups and redefine our keys and values:

G     = Nh // 4 # number of key/value groups is 12 // 4 = 3
key   = torch.randn(B, G, S, E)
value = torch.randn(B, G, S, E)

Now, the main change we need to make to my_sdpa is to repeat our keys and values along dim=1 so that the first 4 query heads (making up the first "group") use key/value head 0, the next 4 query heads use key/value head 1, and so on. We can do this using torch.repeat_interleave:

x = torch.tensor([1,2,3])
print(x)                             # Prints: tensor([1, 2, 3])
print(x.repeat_interleave(2, dim=0)) # Prints: tensor([1, 1, 2, 2, 3, 3])
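
The same trick spells out the head-to-group assignment. With Nh = 12 query heads and G = 3 groups, each run of 4 consecutive query heads shares one key/value head (a small illustration of my own):

group_of_head = torch.arange(G).repeat_interleave(Nh // G)
print(group_of_head) # Prints: tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])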

Let's make the change:

def my_sdpa(q, k, v, attn_mask, enable_gqa=False):
    if enable_gqa:
        _, Nh, _, _ = q.shape
        _, G, _, _ = k.shape
        assert Nh % G == 0, "number of heads in key and value must divide the number of heads in query"
        repeat = Nh // G
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)

    scaling = math.sqrt(k.size(-1))
    attn = (q @ k.transpose(-1,-2)) / scaling 
    # mask out scores that should not be used with -inf
    attn.masked_fill_(attn_mask == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1) @ v
    return attn

Test

PyTorch's F.scaled_dot_product_attention supports GQA via the enable_gqa=True flag (available in recent PyTorch releases), so we can check our implementation easily:

compare(my_sdpa(query, key, value, attn_mask=attn_mask, enable_gqa=True),
        F.scaled_dot_product_attention(query, key, value, attn_mask=(attn_mask == 1), enable_gqa=True))
# Prints: Equal ✅
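
Why do fewer key/value heads speed up inference? During autoregressive decoding, the keys and values of past tokens are cached, so the KV cache (and the memory traffic needed to read it) shrinks in proportion to the number of key/value heads. A rough back-of-the-envelope with the sizes above (my own arithmetic, ignoring layers and dtype):

# per-token KV-cache entries for one attention layer = 2 * num_kv_heads * head_dim
mha_cache = 2 * Nh * E  # 2 * 12 * 768 = 18432 values per token
gqa_cache = 2 * G  * E  # 2 *  3 * 768 =  4608 values per token
print(mha_cache // gqa_cache) # Prints: 4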

Step 4: multi-query attention

Multi-query attention is a special case of GQA where there is only 1 key/value group. So the function we wrote above will work for this case!

# multi-query attention
G = 1
key   = torch.randn(B, G, S, E)
value = torch.randn(B, G, S, E)

compare(my_sdpa(query, key, value, attn_mask=attn_mask, enable_gqa=True),
        F.scaled_dot_product_attention(query, key, value, attn_mask=(attn_mask == 1), enable_gqa=True))
# Prints: Equal ✅

Summary

Our my_sdpa function reimplements F.scaled_dot_product_attention and allows us to better understand the three popular attention variants used in LLMs today. It's important to note that the shapes of our randomly generated query, key, and value tensors are just a reshape of the hidden states of shape (batch_size, sequence_len, hidden_size) that the Transformer operates on, where hidden_size = num_heads * head_dim. So, whereas a Transformer using MHA would have a key projection that looks like:

k_proj = nn.Linear(hidden_size, hidden_size, bias=False)

a Transformer using GQA with G groups would have

k_proj = nn.Linear(hidden_size, G * head_dim, bias=False)
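
To make the connection explicit, here is a minimal sketch (my own illustration, not code from any particular model) of how hidden states would be projected and reshaped into the (batch_size, G, sequence_len, head_dim) key tensor that my_sdpa expects:

import torch
import torch.nn as nn

B, S = 2, 256           # batch size and sequence length, as before
Nh, head_dim = 12, 768  # number of query heads and per-head dimension, as before
G = 3                   # number of key/value groups for this sketch
hidden_size = Nh * head_dim

hidden_states = torch.randn(B, S, hidden_size)
k_proj = nn.Linear(hidden_size, G * head_dim, bias=False)

k = k_proj(hidden_states)                      # (B, S, G * head_dim)
k = k.view(B, S, G, head_dim).transpose(1, 2)  # (B, G, S, head_dim)
print(k.shape) # Prints: torch.Size([2, 3, 256, 768])

The value projection is analogous, while the query projection stays nn.Linear(hidden_size, hidden_size) so that all Nh query heads are kept.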

If this was helpful, let me know on X @psandovalsegura!

@article{sandovalsegura2025multihead,
  title   = "Multi-head attention variants step-by-step in PyTorch.",
  author  = "Sandoval-Segura, Pedro",
  journal = "psando.bearblog.dev",
  year    = "2025",
  month   = "May",
  url     = "https://psando.bearblog.dev/multi-head-attention-variants-step-by-step-in-pytorch/"
}