attention is all you need — a walkthrough

a hands-on explanation of the transformer: what attention actually does, why the paper changed everything, and how the pieces fit together.

in 2017, vaswani et al. published a paper with one of the cockiest titles in ML history: attention is all you need. eight years later, it's the foundation under every serious language model we use. this is my attempt at explaining the architecture the way i wish someone had explained it to me the first time — from the problem it solves, through the mechanism, up to the full picture.

i'll skip the history lesson and start with the uncomfortable thing the paper actually says.

#the uncomfortable claim

before transformers, sequence models were dominated by rnns and their upgraded cousins (lstm, gru). they processed a sentence the way you'd read one — left to right, one token at a time — and carried a hidden state forward.

two things went wrong:

  • you can't parallelize a loop. training an rnn on a sentence of length n requires n sequential steps. gpus hate this.
  • memory fades. by the time you're 40 tokens in, the signal from token 1 has been bent and squeezed through 40 matrix multiplications. it's mostly gone.

vaswani's claim was: you don't need the loop. you don't need recurrence at all. every token can look at every other token directly, in one shot, using attention. that's the whole paper.

#the core mechanism: attention as weighted lookup

forget neural networks for a second. imagine a python dict:

memory = {
    "cat": [0.2, 0.8, -0.1],
    "dog": [0.3, 0.7,  0.1],
    "car": [0.9, 0.1,  0.4],
}
memory["cat"]  # exact match

this is a hard lookup: the key must match exactly. useless for anything fuzzy.

attention is the soft version of the same idea:

  • you have a query (what am i looking for?)
  • you compare it against every key (what's in memory?)
  • you take a weighted sum of the corresponding values (the actual content)

the "weights" come from similarity. keys that are more similar to the query get a bigger slice of the final output. concretely:

# pseudo-code
scores   = query @ keys.T           # similarity of q to every k
weights  = softmax(scores / sqrt(d_k))  # normalize to a distribution
output   = weights @ values         # weighted average of values

that's it. that's attention. the rest of the paper is just scaling this idea up in clever ways.

figure 1 — scaled dot-product attention. Q and K are compared via matmul, scaled by √dₖ to keep gradients sane, softmaxed into a distribution, then used to weight V.

the √d_k scaling is a small detail that matters a lot. without it, as the embedding dimension grows, the dot-product scores get bigger, the softmax saturates, gradients vanish, and training dies. dividing by √d_k is the duct tape that keeps the numbers in a range where softmax still has useful gradients.
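
here's the pseudo-code above fleshed out into runnable numpy. a minimal sketch, not the paper's reference code; the function names and shapes are mine:

import numpy as np

def softmax(x, axis=-1):
    # subtract the row max before exponentiating so large scores don't overflow
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    # Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)
    d_k = Q.shape[-1]
    scores  = Q @ K.T / np.sqrt(d_k)     # similarity of every query to every key
    weights = softmax(scores, axis=-1)   # each query's weights sum to 1
    return weights @ V                   # (n_q, d_v): weighted average of the values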

#self-attention: q, k, v all come from the same place

here's the trick. in earlier uses of attention (say, a seq2seq rnn with attention bolted on), the query came from the decoder and the keys/values came from the encoder. that's fine, but limited.

self-attention asks: what if every token in a sequence generated its own query, key, and value — and we let each token attend to every other token, including itself?

given an input matrix X of shape (sequence_length, embedding_dim), we learn three projection matrices W_Q, W_K, W_V and compute:

Q = X @ W_Q
K = X @ W_K
V = X @ W_V
attention_output = softmax(Q @ K.T / sqrt(d_k)) @ V

now every token in the sequence has produced an output that is a mix of information from every other token, weighted by relevance. token 1 can directly look at token 40 without having to travel through 39 hidden states.

this is what makes transformers parallelizable — every token's output can be computed at the same time. the gpu stops hating you.
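
to make the "one shot" part concrete, here's a toy sketch reusing scaled_dot_product_attention from earlier. the sizes are arbitrary and the random matrices stand in for learned weights:

rng = np.random.default_rng(0)
n, d_model, d_k = 6, 16, 16                  # toy sizes, not the paper's

X   = rng.normal(size=(n, d_model))          # one embedding per token
W_Q = rng.normal(size=(d_model, d_k))        # learned in practice, random here
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))

out = scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)
print(out.shape)                             # (6, 16): every token's output, no loop over positions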

#multi-head: one attention isn't enough

one set of (W_Q, W_K, W_V) matrices produces one "view" of how tokens relate to each other. that's a strong constraint: some tokens relate by syntax ("the dog chased the ball"), others by coreference ("it was red"), and you'd want attention to capture both.

the solution: run attention h times in parallel, each with its own projection matrices, then concatenate and project back down. these are heads. the paper uses h = 8.

figure 2 — multi-head attention. h heads attend in parallel with independent (WQ, WK, WV) projections. their outputs are concatenated and projected through WO to produce the final output.

in practice, different heads specialize — one might focus on adjacent tokens, another on long-range dependencies, another on syntactic roles. nobody tells them to; they figure it out. the paper includes visualizations of this emergent specialization that are still worth staring at.
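
here's a minimal sketch in the same toy numpy setup. splitting one big projection into h contiguous chunks is equivalent to giving each head its own smaller (W_Q, W_K, W_V), which is how the paper describes it; the names here are mine:

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h=8):
    # X: (n, d_model); each W: (d_model, d_model); d_model must be divisible by h
    n, d_model = X.shape
    d_head = d_model // h
    # project once, then split the last dimension into h heads
    Q = (X @ W_Q).reshape(n, h, d_head)
    K = (X @ W_K).reshape(n, h, d_head)
    V = (X @ W_V).reshape(n, h, d_head)
    heads = [scaled_dot_product_attention(Q[:, i], K[:, i], V[:, i])  # (n, d_head) each
             for i in range(h)]
    return np.concatenate(heads, axis=-1) @ W_O   # concat back to (n, d_model), project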

#but wait — where does order come from?

here's a problem. self-attention treats the input as a set. if you shuffle the tokens, the attention output is the same (just shuffled). that's wrong for language — "dog bites man" and "man bites dog" are different stories.

the paper's fix is positional encoding: before feeding embeddings into the transformer, add a vector that encodes position. they use fixed sinusoids of different frequencies:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

different dimensions oscillate at different rates, so each position gets a unique fingerprint. and because PE(pos + k) is a fixed linear transformation of PE(pos) for any offset k, the model can (in theory) learn to attend by relative position.

figure 3 — positional encoding as a heatmap. rows are positions, columns are dimensions. low dimensions (left) vary quickly; high dimensions (right) vary slowly. every position has a unique vector.

later work replaced sinusoidal pe with learned, relative, and rotary variants — but the core idea (inject positional information additively) is still standard.
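
the formulas translate almost directly into code. a minimal sketch (assuming an even d_model), with the last line reusing X from the self-attention sketch above:

def positional_encoding(max_len, d_model):
    # pe[pos, 2i] = sin(pos / 10000^(2i/d_model)), pe[pos, 2i+1] = cos(same angle)
    pos   = np.arange(max_len)[:, None]            # (max_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]      # (1, d_model/2)
    angle = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angle)
    pe[:, 1::2] = np.cos(angle)
    return pe

X_in = X + positional_encoding(*X.shape)           # added to the embeddings before layer 1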

#the full picture

with those pieces in hand, the transformer is nearly assembled. the paper uses an encoder-decoder architecture for machine translation:

  • the encoder reads the input sentence and produces a rich representation. it's a stack of N identical layers, each containing: multi-head self-attention → add & layer-norm → feed-forward → add & layer-norm.
  • the decoder generates the output sentence one token at a time. it's similar, but each layer has two attention sub-layers: masked self-attention (over the output so far) and cross-attention into the encoder's output.
  • residual connections ("add & norm") wrap every sub-layer, which is what makes the whole thing trainable at depth. (one encoder layer is sketched after the figure below.)

figure 4 — the full transformer. the encoder (left) builds a representation of the input. the decoder (right) generates output autoregressively, cross-attending to the encoder's output.
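
here's the encoder-layer sketch promised above, mostly to show what "add & norm" looks like in code. the layer norm omits its learned scale and bias, and the feed-forward weights (W1, b1, W2, b2) are placeholders i've invented for illustration:

def layer_norm(x, eps=1e-6):
    # normalize each position's vector to zero mean and unit variance
    mu  = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def encoder_layer(X, attn_weights, W1, b1, W2, b2, h=8):
    # sub-layer 1: multi-head self-attention wrapped in a residual connection + norm
    X = layer_norm(X + multi_head_attention(X, *attn_weights, h=h))
    # sub-layer 2: position-wise feed-forward (two-layer mlp with relu), same wrapping
    ff = np.maximum(0.0, X @ W1 + b1) @ W2 + b2
    return layer_norm(X + ff)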

a few important details the diagram hides:

  • the decoder's self-attention is masked — when generating token t, you can only attend to tokens 1..t-1. otherwise the model cheats and learns to copy the target. (a sketch of the mask follows this list.)
  • cross-attention is just ordinary attention where Q comes from the decoder and K, V come from the encoder. it's the bridge between the two stacks.
  • the feed-forward sub-layer is a simple two-layer mlp applied to each position independently. most of the model's parameters live here, not in attention.
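
the mask itself is simple: before the softmax, set every score that looks at a future position to -inf, so its weight comes out as zero. a sketch, reusing softmax from earlier:

def causal_mask(n):
    # True strictly above the diagonal, i.e. at every "future" position
    return np.triu(np.ones((n, n), dtype=bool), k=1)

def masked_self_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(causal_mask(len(Q)), -np.inf, scores)   # hide the future
    return softmax(scores, axis=-1) @ V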

#why this won

from a distance, the transformer is a sequence of:

  • attention layers that let tokens talk to each other
  • feed-forward layers that let each token think by itself
  • residual connections that keep gradients alive
  • layer norms that keep activations scaled

no recurrence. no convolutions. just matrix multiplications in a particular order. which happens to be exactly what gpus are good at.

what the paper got right, in hindsight:

  1. parallelism. every position in a sequence is processed at once during training, so the work maps onto big matrix multiplications instead of the n sequential steps an rnn needs.
  2. direct access. any token can influence any other in one attention step, not n.
  3. composability. the block is the same whether you stack 6 of them (the original paper) or 96 (gpt-3 scale). the architecture didn't need to change — just get bigger.

what it didn't foresee:

  • that you could throw away the encoder entirely and get gpt (decoder-only).
  • that throwing away the decoder gives you bert (encoder-only).
  • that "just scale it" would produce emergent reasoning.
  • that the quadratic O(n²) attention cost would become the bottleneck for long context — the thing everyone is trying to work around with flash attention, linear attention, sliding windows, and state-space models.

if you want to go deeper:

  • the original paper: arxiv.org/abs/1706.03762. short, dense, and less intimidating than you expect once you know the mechanism.
  • the illustrated transformer by jay alammar — the canonical visual walkthrough.
  • karpathy's "let's build gpt from scratch" video — if you want to implement every line of attention yourself. the 90 minutes you'll spend on it are the best self-contained ml investment i know of.
  • the annotated transformer — the paper, line by line, in pytorch. it was updated in 2022 and still holds up.

the most surprising thing about attention, after eight years and hundreds of gpu-years of follow-up work, is that the core idea in the paper is still what's under the hood. we made it bigger, cheaper, and longer-context, but the mechanism is unchanged.

sometimes the title really is the whole thesis.
