Attention Is All You Need — Line-by-Line Transformer Explanation

This page provides an annotated view of a Transformer implementation. Each description is followed by the corresponding Python (PyTorch) code, in the style of nn.labml.ai.

1️⃣ Scaled Dot-Product Attention

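Scaled dot-product attention scores every query in Q against every key in K with a dot product. Dividing by √dₖ keeps the scores from growing with the key dimension dₖ, which would otherwise push the softmax into saturated regions where gradients are very small.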
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
A softmax over the key dimension converts each query's scores into attention weights: non-negative values that sum to 1.
attn = torch.softmax(scores, dim=-1)
Finally, multiply the weights by V, so that each output vector is a weighted average of the value vectors.
output = torch.matmul(attn, V)
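
The three steps above can be collected into one function. This is a minimal sketch rather than a reference implementation: the function name, the optional `mask` argument, and the convention that `mask == 0` marks blocked positions are assumptions made for illustration.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return (output, attention weights) for inputs shaped (..., seq_len, d_k)."""
    d_k = Q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k).
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Assumed convention: positions where mask == 0 may not be attended to.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Softmax over the key dimension, so each row of weights sums to 1.
    attn = torch.softmax(scores, dim=-1)
    # Weighted average of the value vectors.
    return torch.matmul(attn, V), attn

torch.manual_seed(0)
Q = torch.randn(2, 5, 8)  # (batch, seq_len, d_k)
K = torch.randn(2, 5, 8)
V = torch.randn(2, 5, 8)
output, attn = scaled_dot_product_attention(Q, K, V)
print(output.shape, attn.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 5])
```

A row of the mask that blocks every key would produce NaN weights, so a mask used with this sketch should leave at least one position open per query.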

2️⃣ Multi-Head Attention

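Learned linear layers project the inputs into queries, keys, and values. Each layer maps d_model to d_model, and each head later works on its own d_k-sized slice of the result, so every head effectively has its own projection.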
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
Split the projected embeddings into num_heads heads of size d_k = d_model / num_heads, so that different heads can capture different types of relationships between tokens.
x = x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
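
To make the shapes in the split concrete, here is a small sketch with the inverse operation alongside it. The helper names `split_heads` and `merge_heads` are hypothetical and exist only for this illustration.

```python
import torch

def split_heads(x, num_heads):
    """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
    batch, seq_len, d_model = x.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    return x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

def merge_heads(x):
    """(batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
    batch, num_heads, seq_len, d_k = x.shape
    # transpose leaves the tensor non-contiguous, so call contiguous() before view.
    return x.transpose(1, 2).contiguous().view(batch, seq_len, num_heads * d_k)

x = torch.randn(2, 5, 16)  # d_model = 16
heads = split_heads(x, num_heads=4)
print(heads.shape)                         # torch.Size([2, 4, 5, 4])
print(torch.equal(merge_heads(heads), x))  # True
```

Head i sees features i * d_k through (i + 1) * d_k - 1 of every token, so splitting and merging rearrange the data without changing any values.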
Each head performs scaled dot-product attention independently.
out, attn = self.attention(Q, K, V, mask)
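The results from all heads are concatenated and projected back to the original embedding space. Because the heads are stored along dimension 1 of a single tensor, the concatenation is a transpose followed by a reshape.
out = self.w_o(out.transpose(1, 2).contiguous().view(batch, seq_len, self.num_heads * self.d_k))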
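
Putting the steps of this section together, a whole module might look roughly like the sketch below. The class layout, the `_split` helper, and the `mask == 0` convention are assumptions; only the layer names `w_q`, `w_k`, `w_v`, `w_o` and the attributes `num_heads` and `d_k` come from the fragments above.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        Q = self._split(self.w_q(q))
        K = self._split(self.w_k(k))
        V = self._split(self.w_v(v))
        # All heads are processed in one batched matmul.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Assumed convention: mask == 0 marks blocked positions.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)  # (batch, num_heads, q_len, d_k)
        batch, _, q_len, _ = out.shape
        # Concatenate the heads, then project back to d_model.
        out = out.transpose(1, 2).contiguous().view(batch, q_len, self.num_heads * self.d_k)
        return self.w_o(out), attn

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=16, num_heads=4)
x = torch.randn(2, 5, 16)
out, attn = mha(x, x, x)
print(out.shape, attn.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 4, 5, 5])
```

Passing the same tensor as `q`, `k`, and `v` gives self-attention. Passing a different tensor for `k` and `v` gives cross-attention, in which case the last dimension of `attn` is the length of that other sequence.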

3️⃣ Encoder Block

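The encoder layer applies multi-head self-attention followed by a feed-forward network. Each sub-layer's output passes through dropout, is added back to the sub-layer's input (residual connection), and is then layer-normalized. Here attn_out is the self-attention output, and ff_out is the feed-forward network applied to the normalized x.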
x = x + self.dropout(attn_out)
x = self.norm1(x)
x = x + self.dropout(ff_out)
x = self.norm2(x)
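
For context, the four lines above could sit inside a layer like the following sketch. It uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the attention module, and the feed-forward definition, the default sizes, and the name `self.ff` are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, d_model=32, num_heads=4, d_ff=64, dropout=0.1):
        super().__init__()
        # Stand-in attention module; batch_first=True means (batch, seq_len, d_model).
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        # Position-wise feed-forward network (assumed: two linear layers with ReLU).
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.mha(x, x, x)  # self-attention
        x = x + self.dropout(attn_out)   # residual connection
        x = self.norm1(x)                # normalize after the residual
        ff_out = self.ff(x)
        x = x + self.dropout(ff_out)
        x = self.norm2(x)
        return x

torch.manual_seed(0)
layer = EncoderLayer().eval()  # eval() turns dropout off
x = torch.randn(2, 5, 32)
y = layer(x)
print(y.shape)  # torch.Size([2, 5, 32])
```

Normalizing after the residual addition is the ordering the code fragments above use. Some implementations normalize before each sub-layer instead, which is a different design choice and not what is shown here.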

4️⃣ Decoder Block

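The decoder first performs masked self-attention. The target mask (tgt_mask) stops each position from attending to later tokens, so a prediction can only depend on tokens that come before it.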
masked_out, masked_attn = self.masked_mha(x, x, x, mask=tgt_mask)
It then performs encoder-decoder attention: the queries come from the decoder state x, while the keys and values come from the encoder outputs enc_out.
encdec_out, encdec_attn = self.mha(x, enc_out, enc_out, mask=src_mask)
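
The sketch below places both attention steps in a full decoder layer and checks that the causal mask works. It is a simplified illustration under several assumptions: `nn.MultiheadAttention` stands in for the attention module, dropout and the source mask are omitted, the feed-forward sub-layer and `norm1` to `norm3` are added to complete the layer, and the mask is built inside `forward`. Note that `nn.MultiheadAttention` uses a boolean mask in which `True` marks a blocked position.

```python
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model=32, num_heads=4, d_ff=64):
        super().__init__()
        self.masked_mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        tgt_len = x.size(1)
        # Causal mask: True above the diagonal blocks attention to later tokens.
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1
        )
        masked_out, _ = self.masked_mha(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + masked_out)
        # Queries from the decoder, keys and values from the encoder.
        encdec_out, _ = self.mha(x, enc_out, enc_out)
        x = self.norm2(x + encdec_out)
        return self.norm3(x + self.ff(x))

torch.manual_seed(0)
layer = DecoderLayer().eval()
enc_out = torch.randn(2, 7, 32)  # (batch, src_len, d_model)
tgt = torch.randn(2, 5, 32)      # (batch, tgt_len, d_model)
out = layer(tgt, enc_out)

# Changing the last target token must not affect earlier positions.
tgt_changed = tgt.clone()
tgt_changed[:, -1] += 1.0
out_changed = layer(tgt_changed, enc_out)
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5))  # True
```

The last position's output does change, since it attends to itself. Only the earlier positions are shielded from it.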