Intuition
Attention allows a model to focus on relevant parts of the input when producing each part of the output. Instead of compressing an entire sequence into a fixed-size vector, attention computes a weighted sum over all positions.
Scaled Dot-Product Attention
Given queries $Q$, keys $K$, and values $V$:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the dimension of the keys. Scaling by $\sqrt{d_k}$ keeps the dot products from growing large with dimension, which would otherwise push the softmax into extremely peaked distributions with vanishing gradients.
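The formula above can be sketched directly in PyTorch. This is a minimal illustration, not a library API; the function name is our own:

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch, seq_len, d_k) -- d_k is the key dimension
    d_k = k.size(-1)
    # Raw similarity between every query and every key, scaled by sqrt(d_k)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # Normalize into attention weights that sum to 1 over the key positions
    weights = torch.softmax(scores, dim=-1)
    # Each output is a weighted sum of the values
    return weights @ v, weights
```

Each row of `weights` sums to 1, so every output position is a convex combination of the value vectors.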
Multi-Head Attention
Instead of a single attention function, use $h$ parallel heads:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O$$

$$\text{where}\quad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V)$$
Each head can attend to different representation subspaces.
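In practice the per-head projections are fused into one linear layer per role, and the heads are split by reshaping. A hand-rolled sketch, assuming the embedding dimension divides evenly across heads (class and method names are ours):

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # One fused projection per role; equivalent to h per-head projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # W^O

    def forward(self, q, k, v):
        b, t, _ = q.shape

        def split(x):
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return x.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = split(self.q_proj(q)), split(self.k_proj(k)), split(self.v_proj(v))
        # Scaled dot-product attention, computed independently per head
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        weights = torch.softmax(scores, dim=-1)
        out = weights @ v                             # (batch, heads, seq, head_dim)
        out = out.transpose(1, 2).reshape(b, t, -1)   # concatenate heads
        return self.out_proj(out)
```

The reshape-and-transpose trick means no explicit loop over heads; all heads run in one batched matrix multiply.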
Self-Attention
When Q, K, V all come from the same sequence, it’s self-attention. Each token attends to every other token in the sequence.
PyTorch Implementation
```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        # x shape: (seq_len, batch, embed_dim)
        # Query, key, and value are all the same tensor: self-attention
        out, _ = self.attn(x, x, x)
        return out
```
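Using the module is a one-liner; note that `nn.MultiheadAttention` defaults to sequence-first input, so the shapes below follow the `(seq_len, batch, embed_dim)` convention (the class is repeated here so the snippet runs standalone):

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):  # as defined above
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out

attn = SelfAttention(embed_dim=16, num_heads=4)
x = torch.randn(10, 2, 16)   # (seq_len, batch, embed_dim)
out = attn(x)
print(out.shape)             # torch.Size([10, 2, 16])
```

The output has the same shape as the input, which is what lets attention layers stack.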
Why Attention Works
- Parallelizable: unlike RNNs, all positions are computed simultaneously
- Long-range dependencies: any position can directly attend to any other
- Interpretable: attention weights show what the model focuses on