Transformer Architecture

Interview Priority

Transformer is the core architecture of modern AI: BERT, GPT, ViT, and Whisper are all built on it. It is a guaranteed interview topic. You need to be able to explain the math and intuition behind self-attention, why scaling is needed, the purpose of multi-head attention, the differences between encoder and decoder, and the tradeoffs between Transformers and RNNs.

Core Concepts

The Transformer (Vaswani et al., 2017, "Attention Is All You Need") replaced recurrence with self-attention, enabling:

  • Parallel processing of all positions (vs. an RNN's sequential computation)
  • Direct long-range dependencies between any two positions (vs. an RNN's chain of gradients)

Scaled Dot-Product Attention

The fundamental building block:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Three matrices come from the same input (self-attention) or from different inputs (cross-attention):

  • Query (Q): "What am I looking for?"
  • Key (K): "What do I have?"
  • Value (V): "If there is a match, what do I contribute?"

Intuition: it works like a database lookup. The Query takes a dot product with every Key to get a relevance score → softmax → weighted sum of the Values.
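This lookup can be sketched in a few lines of NumPy (a minimal single-head version without batching, for illustration only):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V.

    Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v). Single head, no batch.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # relevance of each Key to each Query
    scores -= scores.max(axis=-1, keepdims=True)    # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the Keys
    return weights @ V, weights                     # weighted sum of Values

# Self-attention: Q, K, V all come from the same 3-token input
x = np.random.default_rng(0).normal(size=(3, 4))
out, w = scaled_dot_product_attention(x, x, x)
print(w.sum(axis=-1))  # each row of attention weights sums to 1
```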

Why Scale by \sqrt{d_k}?

When d_k is large, the dot products q \cdot k grow in magnitude (variance ≈ d_k). Large dot products push softmax into its saturation region → gradients ≈ 0 → unstable training.

Dividing by \sqrt{d_k} keeps the variance at 1 regardless of dimension, so softmax stays in a well-behaved range.

Common Interview Follow-Up

"What happens without scaling?" Softmax outputs become very peaky (one-hot-like) → the model attends almost exclusively to one position → it loses the ability to softly combine information → near-zero gradients for most positions → slow, unstable training.
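The variance claim behind this can be checked empirically. A quick Monte Carlo sketch (illustrative numbers, assuming i.i.d. unit-variance entries):

```python
import numpy as np

def dot_product_variance(d_k, n_samples=10_000, seed=0):
    """Sample variance of q·k for i.i.d. standard-normal q, k of dimension d_k."""
    rng = np.random.default_rng(seed)
    q = rng.normal(size=(n_samples, d_k))
    k = rng.normal(size=(n_samples, d_k))
    dots = (q * k).sum(axis=-1)
    return dots.var(), (dots / np.sqrt(d_k)).var()

# Raw dot-product variance grows like d_k; dividing by sqrt(d_k) pins it near 1
for d_k in (16, 256, 1024):
    raw, scaled = dot_product_variance(d_k)
    print(f"d_k={d_k:5d}  raw var ≈ {raw:7.1f}  scaled var ≈ {scaled:.2f}")
```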

Multi-Head Attention

Run attention h times in parallel with different learned projections:

\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O

Why Multiple Heads?

A single attention head can learn only one attention pattern. Multiple heads let the model learn several kinds of relationships simultaneously:

Head    Might Learn
Head 1  Syntactic: subject-verb agreement ("The cats are")
Head 2  Semantic: co-reference ("John ... he")
Head 3  Positional: attend to adjacent tokens
Head 4  Long-range: attend to the sentence beginning

The total parameter count does not increase: each head has dimension d_k = d_model / h, and concatenating the heads restores d_model.
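The parameter accounting is easy to verify (assuming the paper's standard d_model = 512, h = 8 configuration):

```python
# Each head works in d_k = d_model / h dimensions; concatenating h heads
# restores d_model, so the projection parameter count is unchanged.
d_model, h = 512, 8
d_k = d_model // h                       # 64 dimensions per head
assert h * d_k == d_model

# Ignoring biases: one (d_model x d_model) projection has exactly as many
# parameters as h per-head projections of shape (d_model x d_k).
per_head_params = d_model * d_k
assert h * per_head_params == d_model * d_model
print(d_k, h * per_head_params)  # 64 262144
```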

Computational Complexity

Operation          Complexity          Note
Self-attention     O(n^2 · d)          n = sequence length; quadratic in n
Feed-forward       O(n · d^2)          Linear in n
Overall per layer  O(n^2 d + n d^2)    Attention dominates for long sequences

What O(n²) Means

n = 512 → 262K attention scores per head. n = 4096 → 16.7M. n = 32768 → 1.07B. This is why long-context models need efficient attention (FlashAttention, linear attention, sparse attention).

Positional Encoding

Self-attention is permutation-invariant — it treats input as a set, not a sequence. Positional encoding injects order information:

Sinusoidal encoding (original paper):

PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \qquad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)

Why sin/cos? Two reasons:

  1. It generalizes to sequences longer than those seen during training (every position has a unique encoding)
  2. Relative position can be expressed linearly: PE_{pos+k} is a linear function of PE_{pos}
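The sinusoidal formulas translate directly into NumPy (a sketch; in the original architecture this matrix is added to the token embeddings):

```python
import numpy as np

def sinusoidal_positional_encoding(max_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)); PE[pos, 2i+1] = cos(same)."""
    pos = np.arange(max_len)[:, None]                  # (max_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]          # even dimension indices 2i
    angles = pos / np.power(10000.0, two_i / d_model)  # (max_len, d_model / 2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                       # even dims get sin
    pe[:, 1::2] = np.cos(angles)                       # odd dims get cos
    return pe

pe = sinusoidal_positional_encoding(max_len=128, d_model=64)
# Position 0: sin(0) = 0 on even dims, cos(0) = 1 on odd dims
```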

Positional Encoding Variants

Method         Description                       Used In
Sinusoidal     Fixed sin/cos                     Original Transformer
Learned        Trainable embedding per position  BERT, GPT-2
RoPE (Rotary)  Rotation in embedding space       LLaMA, GPT-NeoX
ALiBi          Linear bias added to attention    BLOOM, MPT

Learned and sinusoidal encodings differ little in practice. RoPE and ALiBi exist for length generalization: letting the model handle sequences longer than those seen during training.
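To illustrate why rotary embeddings encode relative position, here is a hypothetical minimal "rotate-half" sketch (not the exact formulation of any particular model): the dot product between a rotated query at position m and a rotated key at position n depends only on the offset n - m.

```python
import numpy as np

def rope(x, base=10000.0):
    """Minimal rotary embedding sketch: row p of x is rotated by angles p * theta_i."""
    seq_len, d = x.shape
    half = d // 2
    theta = base ** (-np.arange(half) / half)      # per-pair rotation frequencies
    angles = np.arange(seq_len)[:, None] * theta   # (seq_len, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]              # "rotate-half" pairing of dims
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)
Q = rope(np.tile(q, (6, 1)))   # the same query vector placed at positions 0..5
K = rope(np.tile(k, (6, 1)))
# Same offset (3) -> same attention score, regardless of absolute position
print(np.isclose(Q[2] @ K[5], Q[0] @ K[3]))  # True
```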

Encoder-Decoder Structure

Encoder Layer

Each encoder layer has two sub-layers:

  1. Multi-head self-attention: Each position attends to all positions
  2. Position-wise feed-forward network: \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2

Both wrapped with residual connection + LayerNorm:

\text{output} = \text{LayerNorm}(x + \text{SubLayer}(x))

Decoder Layer

Each decoder layer has three sub-layers:

  1. Masked self-attention: Each position attends only to previous positions (causal mask)
  2. Cross-attention: Decoder queries attend to encoder outputs (Keys and Values come from the encoder)
  3. Position-wise feed-forward network

Causal Mask

The decoder's self-attention must not see future tokens (the autoregressive constraint):

\text{Mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \quad \text{(allowed)} \\ -\infty & \text{if } j > i \quad \text{(blocked)} \end{cases}

After the mask is added, softmax assigns attention weight 0 to every future position.
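A tiny NumPy illustration of the mask in action (uniform scores over 4 tokens, for clarity):

```python
import numpy as np

n = 4
scores = np.zeros((n, n))                      # pretend all raw scores are equal
future = np.triu(np.ones((n, n), dtype=bool), k=1)
causal = np.where(future, -np.inf, 0.0)        # -inf wherever j > i
weights = np.exp(scores + causal)              # exp(-inf) = 0 for future positions
weights /= weights.sum(axis=-1, keepdims=True)
print(weights)
# Row i is uniform over positions 0..i: row 0 = [1, 0, 0, 0], row 1 = [0.5, 0.5, 0, 0], ...
```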

Pre-Norm vs Post-Norm

Variant               Formula         Advantage
Post-Norm (original)  LN(x + f(x))    Better final performance
Pre-Norm (modern)     x + f(LN(x))    More stable training, easier to scale

Most modern models (GPT, LLaMA) use Pre-Norm: training is more stable, and the model converges even without learning-rate warmup.

Modern Variants

Encoder-Only: BERT

Aspect         Detail
Architecture   Encoder only (no decoder)
Pre-training   Masked Language Model (MLM): randomly mask 15% of tokens, predict them
Bidirectional  Yes: each token sees the full context (left + right)
Fine-tuning    Add a task-specific head (classification, NER, QA)
Best for       Understanding tasks (classification, NER, similarity)

Decoder-Only: GPT

Aspect          Detail
Architecture    Decoder only (no encoder, no cross-attention)
Pre-training    Causal Language Model (CLM): predict the next token
Autoregressive  Yes: each token sees only left context
Inference       Generate one token at a time
Best for        Generation tasks (text, code, conversation)

Encoder-Decoder: T5, BART

Aspect        Detail
Architecture  Full encoder-decoder
Pre-training  Various (T5: span corruption; BART: denoising)
Best for      Seq2seq tasks (translation, summarization)

Comparison

Model  Direction        Pre-training           Best For
BERT   Bidirectional    MLM + NSP              Classification, NER, similarity
GPT    Left-to-right    Next-token prediction  Generation, few-shot, conversation
T5     Encoder-decoder  Span corruption        Translation, summarization, QA

Classic Interview Question

"How do BERT and GPT differ?" BERT is encoder-only (bidirectional, sees the full context, good for understanding). GPT is decoder-only (autoregressive, sees only left context, good for generation). BERT cannot generate text (it is not autoregressive), while GPT is weaker than BERT at classification (it lacks bidirectional context).

KV Cache (Inference Optimization)

Autoregressive generation has an efficiency problem: each newly generated token recomputes attention over the entire sequence, which is very wasteful.

KV Cache: cache the Key and Value matrices of previous tokens → a new token only needs to compute its own Q, K, V and append its K, V to the cache.

  • Without KV cache: generating token t computes attention over all t tokens → O(T^2) total for T tokens
  • With KV cache: generating token t computes only its own K, V and reuses the cached K, V for tokens 1..t-1 → O(T) total, linear

This is why LLM inference memory grows with context length: the longer the KV cache, the more memory it consumes.
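A single-head sketch of the caching loop (hypothetical random projections W_q, W_k, W_v; illustrative, not a production implementation):

```python
import numpy as np

def attend(q, K, V):
    """One query against t cached keys/values: q (d,), K (t, d), V (t, d)."""
    scores = K @ q / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

rng = np.random.default_rng(0)
d = 8
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
outputs = []
for step in range(5):
    x_t = rng.normal(size=d)                   # embedding of the newest token
    K_cache = np.vstack([K_cache, x_t @ W_k])  # O(1) projection work per step,
    V_cache = np.vstack([V_cache, x_t @ W_v])  # instead of reprojecting tokens 1..t
    outputs.append(attend(x_t @ W_q, K_cache, V_cache))

print(K_cache.shape)  # (5, 8): the cache grows with context length
```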

Efficient Attention Variants

Standard attention's O(n^2) cost does not scale to long sequences. Solutions:

Method                   Complexity                     How
FlashAttention           Still O(n^2), but much faster  Tiling + kernel fusion avoid round trips to HBM → 2-4x speedup
Sparse Attention         O(n√n)                         Attend only to nearby + strided positions
Linear Attention         O(n)                           Approximate softmax with a kernel trick
Sliding Window           O(n · w)                       Each token attends only to a window of w positions
Multi-Query Attention    Reduces KV cache               Share one set of K, V heads across all Q heads
Grouped-Query Attention  Reduces KV cache               Share K, V across groups of Q heads

FlashAttention is currently the most practical speedup: it does not change the attention result (it is exact), only the computation order. This IO-aware optimization is used by virtually every LLM in practice.
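The grouped-query idea from the table above can be sketched in a few lines of PyTorch (shapes here are assumptions for illustration; real models fold this into the attention kernel):

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """GQA sketch: H_q query heads share H_kv < H_q key/value heads.

    q: (B, H_q, T, d); k, v: (B, H_kv, T, d), with H_q divisible by H_kv.
    Only the H_kv heads of K and V need to be cached: that is the saving.
    """
    group = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(group, dim=1)   # broadcast each shared K head to its group
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

q = torch.randn(2, 8, 16, 32)   # 8 query heads
k = torch.randn(2, 2, 16, 32)   # 2 KV heads -> 4x smaller KV cache
v = torch.randn(2, 2, 16, 32)
out = grouped_query_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16, 32])
```

With H_kv = 1 this reduces to multi-query attention.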

Attention Heatmap

(Interactive visualization omitted: each row of the heatmap shows one query token's attention distribution over the other tokens, selectable per head.)

Run the Transformer

(Interactive demos omitted: a pure Python implementation with every operation (matmul, softmax, layer norm) written from scratch, covering the forward pass, greedy decoding, and attention-weight inspection.)

Hands-on: Transformer in PyTorch

Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B = Q.size(0)
        Q = self.W_q(Q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
        return self.W_o(out)

Encoder Layer

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-Norm variant (modern standard): normalize the input once,
        # then add the sub-layer output back to the residual stream
        h = self.norm1(x)
        x = x + self.dropout(self.self_attn(h, h, h, mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x

Causal Mask

def generate_causal_mask(seq_len):
    """Decoder mask: prevent attending to future tokens."""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return ~mask  # True = allowed, False = masked

Real-World Use Cases

Case 1: Credit Card Fraud (Transformers for Sequences)

A BERT-like model over transaction sequences can learn contextual fraud patterns:

# Each transaction = [amount, merchant_category, time_delta, ...]
# Sequence of last 100 transactions → Transformer encoder → pooled position → P(fraud)
class FraudTransformer(nn.Module):
    def __init__(self, feature_dim, d_model, nhead, num_layers):
        super().__init__()
        self.projection = nn.Linear(feature_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, 1)

    def forward(self, x, mask=None):
        x = self.projection(x)
        x = self.encoder(x, src_key_padding_mask=mask)
        return torch.sigmoid(self.classifier(x[:, -1]))  # last position

Interview follow-up: "Transformer vs LSTM for fraud?" Transformers do better on longer sequences (direct attention vs. a gradient chain) and train faster (parallelism). But for real-time scoring, an RNN's incremental inference is more efficient (no KV cache needed).

Case 2: Recommender Systems (Self-Attention on User History)

SASRec (Self-Attentive Sequential Recommendation): treat the user's interaction history as a sequence → Transformer encoder → predict the next item.

Compared with the RNN-based GRU4Rec: a Transformer can directly attend to any item in the history (not limited by recency bias), capturing long-term preferences better.

Case 3: NLP — The Foundation of LLMs

Application               Model                        How
Text classification       BERT + [CLS] → FC            Fine-tune on labeled data
Named Entity Recognition  BERT + token-level head      Each token → entity tag
Question Answering        BERT + start/end prediction  Predict answer span in passage
Text generation           GPT (autoregressive)         Next-token prediction
Translation               T5 / BART (encoder-decoder)  Input → encoder, output from decoder
Embeddings                BERT / Sentence-BERT         [CLS] or mean pooling → semantic vector

Interview Signals

What interviewers listen for:

  • You can explain the intuition behind Q, K, V (the database-lookup analogy)
  • You know why we scale by \sqrt{d_k} (softmax saturation → gradient issues)
  • You can compare BERT vs GPT (encoder vs decoder, bidirectional vs autoregressive)
  • You know what O(n^2) complexity implies and that efficient attention variants exist
  • You understand why the KV cache matters for inference

Practice

Flashcards

Why scale attention by √d_k?

When d_k is large, dot products grow in magnitude → softmax saturates → near-zero gradients → unstable training. Dividing by √d_k keeps the variance at 1 → softmax stays in a well-behaved range → proper gradient flow.

Quiz

What is the computational complexity of self-attention with respect to sequence length n? (Answer: O(n^2 · d), quadratic in n.)