Transformer Architecture
Interview Priority
The Transformer is the core architecture of modern AI: BERT, GPT, ViT, and Whisper are all built on it, and it is a guaranteed interview topic. You need to be able to explain the math and intuition behind self-attention, why scaling is needed, the purpose of multi-head attention, the encoder vs. decoder distinction, and the tradeoffs between Transformers and RNNs.
Core Concepts
The Transformer (Vaswani et al., 2017, "Attention Is All You Need") replaced recurrence with self-attention, enabling:
- Parallel processing of all positions (vs. the RNN's sequential computation)
- Direct long-range dependencies between any two positions (vs. the RNN's long gradient chain)
Scaled Dot-Product Attention
The fundamental building block:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
Three matrices, projected from the same input (self-attention) or from different inputs (cross-attention):
- Query (Q): "What am I looking for?"
- Key (K): "What do I contain?"
- Value (V): "What do I contribute if there is a match?"
Intuition: it works like a database lookup. The Query takes a dot product with every Key to get a relevance score, softmax normalizes the scores, and the result is a weighted sum of the Values.
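The lookup described above can be sketched in a few lines of NumPy. This is a minimal single-head illustration, not the batched multi-head version used in practice; all dimensions are made up for the demo:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # relevance of every key to every query
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ V, weights      # weighted sum of values

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))  # 4 query positions, d_k = 8
K = rng.standard_normal((4, 8))
V = rng.standard_normal((4, 8))
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape)                 # one output vector per query position
```

Each row of `w` is a probability distribution over the 4 key positions (rows sum to 1), which is exactly what the heatmaps later in this page visualize.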
Why Scale by √d_k?
When d_k is large, dot products grow in magnitude (variance ≈ d_k for unit-variance inputs). Large dot products push softmax into its saturation region, where gradients are near zero and training becomes unstable.
Dividing by √d_k keeps the variance at 1 regardless of dimension, so softmax stays in a well-behaved range.
Common Interview Follow-up
"What happens without scaling?" The softmax output becomes very peaky (nearly one-hot), so the model attends almost exclusively to a single position, loses the ability to softly combine information, and receives near-zero gradients for most positions, making training slow and unstable.
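A quick NumPy experiment (illustrative, with arbitrary dimensions) shows both effects: dot products of unit-variance vectors really do have standard deviation ≈ √d_k, and unscaled scores produce a much sharper softmax:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 512

# Variance of a dot product of two unit-variance vectors grows with d_k:
dots = (rng.standard_normal((10000, d_k)) * rng.standard_normal((10000, d_k))).sum(axis=1)
print(dots.std())  # close to sqrt(512), i.e. about 22.6

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

scores = rng.standard_normal(16) * np.sqrt(d_k)  # typical unscaled magnitudes
print(softmax(scores).max())                 # sharply peaked distribution
print(softmax(scores / np.sqrt(d_k)).max())  # flatter after scaling
```

Dividing by √d_k acts like a temperature that keeps the softmax away from its saturated, near-one-hot regime.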
Multi-Head Attention
Run attention h times in parallel with different learned projections:

MultiHead(Q, K, V) = Concat(head₁, …, head_h) W^O, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
Why Multiple Heads?
A single attention head can only learn one attention pattern. Multi-head attention lets the model learn several kinds of relationships at once:
| Head | Might Learn |
|---|---|
| Head 1 | Syntactic: subject-verb agreement ("The cats are") |
| Head 2 | Semantic: co-reference ("John ... he") |
| Head 3 | Positional: attend to adjacent tokens |
| Head 4 | Long-range: attend to sentence beginning |
The total parameter count does not increase: each head has dimension d_k = d_model / h, and concatenating the heads restores d_model.
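A few lines of arithmetic confirm the parameter bookkeeping, assuming the original paper's d_model = 512 and h = 8:

```python
d_model, num_heads = 512, 8
d_k = d_model // num_heads  # 64 dims per head

# Concatenating the h heads restores the model dimension:
assert num_heads * d_k == d_model

# Q/K/V projection weights: h heads of shape (d_model, d_k) cost exactly
# as much as one full-width (d_model, d_model) projection:
per_head = 3 * d_model * d_k  # W_q, W_k, W_v for one head
assert num_heads * per_head == 3 * d_model * d_model
print(d_k, num_heads * per_head)
```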
Computational Complexity
| Operation | Complexity | Note |
|---|---|---|
| Self-attention | O(n² · d) | n = sequence length; quadratic in n |
| Feed-forward | O(n · d²) | Linear in n |
| Overall per layer | O(n² · d + n · d²) | Attention dominates for long sequences |
What O(n²) Means
n = 512 means 262K attention scores per head; n = 4096 means 16.7M; n = 32768 means 1.07B. This is why long-context models need efficient attention (FlashAttention, linear attention, sparse attention).
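The counts above come straight from n²; a tiny loop also shows the memory cost of materializing one raw score matrix in fp16 (ignoring batch and head dimensions, which multiply it further):

```python
for n in (512, 4096, 32768):
    scores = n * n            # attention matrix entries per head
    mib = scores * 2 / 2**20  # 2 bytes per fp16 entry
    print(f"n={n:>6}: {scores:>13,} scores, {mib:,.1f} MiB")
```

At n = 32768, a single head's score matrix alone is 2 GiB in fp16, which is why FlashAttention avoids ever materializing it.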
Positional Encoding
Self-attention is permutation-invariant — it treats input as a set, not a sequence. Positional encoding injects order information:
Sinusoidal encoding (original paper):

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Why sin/cos? Two reasons:
- It generalizes to sequences longer than any seen during training (every position gets a unique encoding)
- Relative position can be expressed by a linear function: PE(pos+k) is a linear function of PE(pos)
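A minimal NumPy implementation of the sinusoidal encoding (shapes chosen arbitrarily for the demo):

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)); PE[pos, 2i+1] = cos(same angle)."""
    pos = np.arange(max_len)[:, None]
    i = np.arange(0, d_model, 2)[None, :]
    angle = pos / (10000 ** (i / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angle)  # even dimensions
    pe[:, 1::2] = np.cos(angle)  # odd dimensions
    return pe

pe = sinusoidal_pe(128, 64)
print(pe.shape)
# Position 0 is sin(0) = 0 on even dims and cos(0) = 1 on odd dims.
```

The encoding is simply added to the token embeddings before the first layer, which is how order information enters an otherwise permutation-invariant model.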
Positional Encoding Variants
| Method | Description | Used In |
|---|---|---|
| Sinusoidal | Fixed, sin/cos | Original Transformer |
| Learned | Trainable embedding per position | BERT, GPT-2 |
| RoPE (Rotary) | Rotation in embedding space | LLaMA, GPT-NeoX |
| ALiBi | Linear bias added to attention | BLOOM, MPT |
Learned and sinusoidal encodings differ little in practice. RoPE and ALiBi target length generalization: letting the model handle sequences longer than those seen during training.
Encoder-Decoder Structure
Encoder Layer
Each encoder layer has two sub-layers:
- Multi-head self-attention: Each position attends to all positions
- Position-wise feed-forward network: FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
Both sub-layers are wrapped with a residual connection and LayerNorm: x ← LayerNorm(x + Sublayer(x))
Decoder Layer
Each decoder layer has three sub-layers:
- Masked self-attention: Each position attends to previous positions only (causal mask)
- Cross-attention: Decoder queries attend to encoder outputs (Keys and Values come from the encoder)
- Position-wise feed-forward network
Causal Mask
The decoder's self-attention must not see future tokens (the autoregressive constraint): the mask sets the scores of future positions to -inf before the softmax.
With the mask applied, the softmax assigns every future position an attention weight of exactly 0.
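A small NumPy demo of the mask mechanics: setting future scores to -inf before the softmax makes their weights exactly zero:

```python
import numpy as np

n = 4
scores = np.zeros((n, n))                        # stand-in raw attention scores
future = np.triu(np.ones((n, n), dtype=bool), 1) # True above the diagonal = future
scores[future] = -np.inf                         # block future positions

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights)
# Row t spreads attention over positions 0..t only; future weights are exactly 0.
```

With uniform raw scores, row t attends uniformly over the first t+1 positions (row 0 is [1, 0, 0, 0], row 1 is [0.5, 0.5, 0, 0], and so on), a lower-triangular pattern you can recognize in decoder attention heatmaps.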
Pre-Norm vs Post-Norm
| Variant | Formula | Advantage |
|---|---|---|
| Post-Norm (original) | x ← LayerNorm(x + Sublayer(x)) | Better final performance |
| Pre-Norm (modern) | x ← x + Sublayer(LayerNorm(x)) | More stable training, easier to scale |
Most modern models (GPT, LLaMA) use Pre-Norm: training is more stable and can converge even without learning rate warmup.
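The two orderings differ only in where LayerNorm sits. A toy NumPy sketch (with a scalar stand-in for the attention/FFN sub-layer, not a real block) makes the difference concrete:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def sublayer(x):  # stand-in for attention or the FFN
    return 0.5 * x

x = np.random.default_rng(0).standard_normal((2, 8))
post = layer_norm(x + sublayer(x))  # Post-Norm: normalize AFTER the residual add
pre = x + sublayer(layer_norm(x))   # Pre-Norm: normalize inside, raw residual path survives
print(post.shape, pre.shape)
```

In Pre-Norm, the raw x flows through the residual path untouched, giving a clean identity gradient path through the whole stack; in Post-Norm, every layer's output passes through a LayerNorm, which is what makes deep Post-Norm stacks harder to train without warmup.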
Modern Variants
Encoder-Only: BERT
| Aspect | Detail |
|---|---|
| Architecture | Encoder only (no decoder) |
| Pre-training | Masked Language Model (MLM): randomly mask 15% of tokens, predict them |
| Bidirectional | Yes: each token sees the full context (left + right) |
| Fine-tuning | Add a task-specific head (classification, NER, QA) |
| Best for | Understanding tasks (classification, NER, similarity) |
Decoder-Only: GPT
| Aspect | Detail |
|---|---|
| Architecture | Decoder only (no encoder, no cross-attention) |
| Pre-training | Causal Language Model (CLM): predict the next token |
| Autoregressive | Yes: each token sees only the left context |
| Inference | Generate one token at a time |
| Best for | Generation tasks (text, code, conversation) |
Encoder-Decoder: T5, BART
| Aspect | Detail |
|---|---|
| Architecture | Full encoder-decoder |
| Pre-training | Varies (T5: span corruption, BART: denoising) |
| Best for | Seq2seq tasks (translation, summarization) |
Comparison
| Model | Direction | Pre-training | Best For |
|---|---|---|---|
| BERT | Bidirectional | MLM + NSP | Classification, NER, similarity |
| GPT | Left-to-right | Next token prediction | Generation, few-shot, conversation |
| T5 | Encoder-decoder | Span corruption | Translation, summarization, QA |
Classic Interview Question
"How do BERT and GPT differ?" BERT is encoder-only: bidirectional, sees the full context, good for understanding. GPT is decoder-only: autoregressive, sees only the left context, good for generation. BERT cannot do text generation (it is not autoregressive), and GPT trails BERT on classification (no bidirectional context).
KV Cache (Inference Optimization)
The efficiency problem in autoregressive generation: naively, every new token requires recomputing attention over the entire sequence, which is extremely wasteful.
KV Cache: cache the Key and Value matrices of previous tokens; a new token then only computes its own Q, K, V and appends its K, V to the cache.
| Without KV Cache | With KV Cache |
|---|---|
| Generate token t: recompute attention over all t tokens from scratch | Generate token t: compute only the new K, V for token t; reuse cached K, V for tokens 1..t-1 |
| O(n³) total for n tokens | O(n²) total; work per step is linear in t |
This is why LLM inference memory scales with context length: the longer the KV cache, the more memory it consumes.
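A toy single-head NumPy sketch of the cache (random weights, no real model) shows the key point: each decode step projects only the new token, while attention still reads the full cached history:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
K_cache, V_cache = [], []

def decode_step(x_t):
    """One autoregressive step: project only the new token, reuse cached K, V."""
    q = x_t @ W_q
    K_cache.append(x_t @ W_k)    # constant projection work per step,
    V_cache.append(x_t @ W_v)    # independent of how long the sequence already is
    K, V = np.stack(K_cache), np.stack(V_cache)
    scores = K @ q / np.sqrt(d)  # attend over all cached positions
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

for _ in range(5):
    out = decode_step(rng.standard_normal(d))
print(len(K_cache), out.shape)   # cache grows by one entry per generated token
```

The cache itself is the memory cost: it grows linearly with generated length, per layer and per head, which is exactly the scaling described above.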
Efficient Attention Variants
Standard attention's O(n²) cost does not scale to long sequences. Solutions:
| Method | Complexity | How |
|---|---|---|
| FlashAttention | Still O(n²), but much faster | Tiling + kernel fusion avoid round-trips to HBM: 2-4x speedup |
| Sparse Attention | O(n√n) | Only attend to nearby + strided positions |
| Linear Attention | O(n) | Approximate softmax with a kernel trick |
| Sliding Window | O(n · w) | Each token attends only to a window of w positions |
| Multi-Query Attention | Reduces KV cache | Share K, V heads across all Q heads |
| Grouped-Query Attention | Reduces KV cache | Share K, V across groups of Q heads |
FlashAttention is currently the most practical speedup: it does not change the attention output (it is exact), only the computation order (an IO-aware optimization), and in practice nearly every LLM uses it.
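Back-of-envelope arithmetic shows why MQA/GQA matter for serving. The config below is a hypothetical LLaMA-7B-like model (32 layers, 32 query heads, head_dim 128, fp16), used purely for illustration, not exact published numbers:

```python
# KV cache size = 2 (K and V) * layers * kv_heads * head_dim * seq_len * bytes/elem
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=4096)
print(mha / 2**30, gqa / 2**30)  # GiB for full multi-head vs. grouped-query
print(mha // gqa)                # cache shrinks by the query/KV head ratio
```

With full multi-head attention this single 4K-token sequence costs 2 GiB of cache; sharing K, V across groups of 4 query heads cuts it to 0.5 GiB, which directly increases how many sequences fit on one GPU.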
Attention Heatmap
Visualize how each token attends to other tokens in multi-head attention. Each row of the heatmap shows the attention distribution for one query token.
Run the Transformer
This code uses our pure Python implementation — every operation (matmul, softmax, layer norm) written from scratch.
Forward Pass
Greedy Decoding
Inspect Attention Weights
Hands-on: Transformer in PyTorch
Self-Attention
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B = Q.size(0)
        # Project, then split into heads: (B, seq, d_model) -> (B, heads, seq, d_k)
        Q = self.W_q(Q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention, computed per head in parallel
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        # Merge heads back: (B, heads, seq, d_k) -> (B, seq, d_model)
        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
        return self.W_o(out)
```
Encoder Layer
```python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-Norm variant (modern standard): normalize before each sub-layer
        h = self.norm1(x)
        x = x + self.dropout(self.self_attn(h, h, h, mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x
```
Causal Mask
```python
def generate_causal_mask(seq_len):
    """Decoder mask: prevent attending to future tokens."""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return ~mask  # True = allowed, False = masked
```
Real-World Use Cases
Case 1: Credit Card Fraud Detection (Transformer for Sequences)
A BERT-like model on transaction sequences can learn contextual fraud patterns:
```python
import torch
import torch.nn as nn

# Each transaction = [amount, merchant_category, time_delta, ...]
# Sequence of last 100 transactions -> Transformer encoder -> pooled vector -> P(fraud)
class FraudTransformer(nn.Module):
    def __init__(self, feature_dim, d_model, nhead, num_layers):
        super().__init__()
        self.projection = nn.Linear(feature_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        x = self.projection(x)
        x = self.encoder(x, src_key_padding_mask=mask)
        # Use the last position's representation as the sequence summary
        return torch.sigmoid(self.classifier(x[:, -1]))
```
Interview follow-up: "Transformer vs. LSTM for fraud?" Transformers do better on longer sequences (direct attention instead of a long gradient chain) and train faster (parallelism). But for real-time scoring, the RNN's incremental inference can be more efficient (no KV cache needed).
Case 2: Recommender Systems (Self-Attention on User History)
SASRec (Self-Attentive Sequential Recommendation): treat the user's interaction history as a sequence, feed it to a Transformer encoder, and predict the next item.
Difference from RNN-based GRU4Rec: the Transformer can attend directly to any item in the history (not limited by recency bias), capturing long-term preferences better.
Case 3: NLP — The Foundation of LLMs
| Application | Model | How |
|---|---|---|
| Text classification | BERT + [CLS] → FC | Fine-tune on labeled data |
| Named Entity Recognition | BERT + token-level head | Each token → entity tag |
| Question Answering | BERT + start/end prediction | Predict answer span in passage |
| Text generation | GPT (autoregressive) | Next token prediction |
| Translation | T5 / BART (encoder-decoder) | Input → encoder, output from decoder |
| Embeddings | BERT / Sentence-BERT | [CLS] or mean pooling → semantic vector |
Interview Signals
What interviewers listen for:
- You can explain the intuition behind Q, K, V (the database lookup analogy)
- You know why to scale by √d_k (softmax saturation causes gradient problems)
- You can compare BERT vs. GPT (encoder vs. decoder, bidirectional vs. autoregressive)
- You know what the O(n²) complexity implies and that efficient attention variants exist
- You understand why the KV cache matters for inference
Practice
Flashcards
Why scale attention by √d_k?
When d_k is large, dot products grow in magnitude → softmax saturates → near-zero gradients → unstable training. Dividing by √d_k keeps the variance at 1 → softmax stays in a well-behaved range → proper gradient flow.
Quiz
What is the computational complexity of self-attention with respect to sequence length n?