RNN & LSTM
Interview Context
Although Transformers have replaced RNNs in most settings, RNN/LSTM remains a high-frequency interview topic because it tests your understanding of sequential processing, vanishing gradients, and gating mechanisms. You should also be able to explain clearly why Transformers won.
Vanilla RNN
A recurrent neural network processes sequences by maintaining a hidden state updated at each time step:

$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$
The hidden state is the network's "memory": it encodes information from all previous inputs. The same weight matrices are shared across all time steps (parameter sharing).
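A minimal sketch of this recurrence (shapes and random initialization are illustrative assumptions, not tuned values):

```python
import torch

# h_t = tanh(W_hh @ h_{t-1} + W_xh @ x_t + b_h), same weights at every step
torch.manual_seed(0)
input_size, hidden_size, seq_len = 3, 4, 5
W_xh = torch.randn(hidden_size, input_size)
W_hh = torch.randn(hidden_size, hidden_size)
b_h = torch.zeros(hidden_size)

h = torch.zeros(hidden_size)           # initial hidden state
xs = torch.randn(seq_len, input_size)  # input sequence
for x_t in xs:                         # same W_xh, W_hh reused at every step
    h = torch.tanh(W_hh @ h + W_xh @ x_t + b_h)
print(h.shape)  # stays [hidden_size] no matter how long the sequence is
```

Note how the hidden state has a fixed size regardless of `seq_len`: this is both the RNN's strength (arbitrary-length inputs) and its weakness (fixed memory capacity).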
Key Properties
| Property | Detail |
|---|---|
| Arbitrary length | Can process sequences of any length |
| Fixed hidden size | The hidden state dimension stays the same no matter how long the sequence is |
| Parameter sharing | All time steps share one set of weights → parameter count is independent of sequence length |
| Sequential | Must compute step by step → cannot parallelize → slow |
Backpropagation Through Time (BPTT)
To train an RNN, "unroll" it through time and apply backpropagation:

$$\frac{\partial L}{\partial W} = \sum_{t} \frac{\partial L_t}{\partial W}$$

Each term involves a product of Jacobians:

$$\frac{\partial h_t}{\partial h_k} = \prod_{i=k+1}^{t} \frac{\partial h_i}{\partial h_{i-1}}$$

Each factor involves the weight matrix and the activation derivative:

$$\frac{\partial h_i}{\partial h_{i-1}} = \mathrm{diag}(1 - h_i^2)\, W_{hh}$$
The Vanishing Gradient Problem
Typically $\left\|\mathrm{diag}(1 - h_i^2)\, W_{hh}\right\| < 1$ (since $|\tanh'| \le 1$), so the gradient product shrinks exponentially:

$$\left\|\frac{\partial h_t}{\partial h_k}\right\| \le \gamma^{\,t-k}, \quad \gamma < 1$$

For long sequences ($t - k$ large), the gradient effectively vanishes, so the network cannot learn long-range dependencies.
Practical Impact
A vanilla RNN typically cannot learn dependencies beyond 10-20 time steps. If the task requires remembering information from 50 or 100 steps back (long documents, dialogue), a vanilla RNN fails completely. This is the core motivation for LSTM and GRU.
Truncated BPTT: in practice, backpropagate only a fixed number of steps (e.g. 35), trading long-range learning for computational efficiency.
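The vanishing effect can be observed directly on an untrained RNN: backpropagate from the last output and compare gradient magnitudes at early vs. late input positions (a sketch; exact values depend on the random initialization):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=16, batch_first=True)
x = torch.randn(1, 100, 4, requires_grad=True)

out, _ = rnn(x)
out[:, -1].sum().backward()   # "loss" depends only on the last time step

# Gradient magnitude w.r.t. each input position
grad_per_step = x.grad.abs().sum(dim=-1).squeeze()  # [100]
print(grad_per_step[0], grad_per_step[-1])  # early step usually far smaller
```

The gradient at step 0 has passed through 99 Jacobian multiplications, while the gradient at the last step is direct; with typical initializations the early-step gradient is orders of magnitude smaller.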
LSTM (Long Short-Term Memory)
LSTMs (Hochreiter & Schmidhuber, 1997) solve the vanishing gradient problem with a cell state: a separate memory pathway controlled by carefully designed gates.
The Three Gates + Cell State
Forget Gate — decides what to erase from memory:

$$f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$$

Input Gate — decides what new information to store:

$$i_t = \sigma(W_i [h_{t-1}, x_t] + b_i), \qquad \tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)$$

Cell State Update — the critical equation:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$

Output Gate — decides what to expose from memory:

$$o_t = \sigma(W_o [h_{t-1}, x_t] + b_o), \qquad h_t = o_t \odot \tanh(c_t)$$
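The gate logic above can be run as a single hand-rolled step (a sketch with random weights and zero biases; real implementations fuse the four matmuls into one):

```python
import torch

torch.manual_seed(0)
d, h = 3, 4
W_f, W_i, W_c, W_o = (torch.randn(h, h + d) for _ in range(4))
b_f = b_i = b_c = b_o = torch.zeros(h)

x_t = torch.randn(d)
h_prev, c_prev = torch.zeros(h), torch.zeros(h)

z = torch.cat([h_prev, x_t])                 # [h_{t-1}, x_t]
f_t = torch.sigmoid(W_f @ z + b_f)           # forget gate
i_t = torch.sigmoid(W_i @ z + b_i)           # input gate
c_tilde = torch.tanh(W_c @ z + b_c)          # candidate memory
c_t = f_t * c_prev + i_t * c_tilde           # additive cell update
o_t = torch.sigmoid(W_o @ z + b_o)           # output gate
h_t = o_t * torch.tanh(c_t)                  # exposed hidden state
```

All gate values are sigmoid outputs in (0, 1), acting elementwise: each cell dimension has its own independent valve.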
Gate Intuition
| Gate | Sigmoid Output | Intuition |
|---|---|---|
| Forget | Near 1 = keep, near 0 = erase | "Is this memory still useful?" |
| Input | Near 1 = write, near 0 = ignore | "Is this new information worth remembering?" |
| Output | Near 1 = expose, near 0 = hide | "Do we need this memory right now?" |
Sigmoid outputs lie in (0, 1), acting as valves that control information flow. Tanh outputs lie in (-1, 1) and produce the candidate values.
Why LSTMs Solve Vanishing Gradients
The cell state update is additive (not multiplicative):

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \quad\Rightarrow\quad \frac{\partial c_t}{\partial c_{t-1}} \approx f_t$$

When $f_t \approx 1$ (forget gate open), gradients flow through unchanged — no multiplicative shrinkage.
This is the same principle as ResNet's skip connections: an additive path lets gradients pass straight through instead of being repeatedly multiplied and shrunk.
LSTM Parameter Count
An LSTM with hidden size $h$ and input size $d$ has $4\,(h(h+d) + h)$ parameters — 4x a vanilla RNN, because there are 4 gate weight matrices (forget, input, cell candidate, output).
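The 4x ratio is easy to verify in PyTorch (note `nn.LSTM` keeps two bias vectors per layer, `bias_ih` and `bias_hh`, so absolute counts differ slightly from the formula, but the ratio to `nn.RNN` is exactly 4):

```python
import torch.nn as nn

d, h = 128, 256
rnn = nn.RNN(d, h)    # weight_ih [h,d], weight_hh [h,h], 2 biases [h]
lstm = nn.LSTM(d, h)  # same shapes, but stacked 4x for the 4 gates

n_rnn = sum(p.numel() for p in rnn.parameters())
n_lstm = sum(p.numel() for p in lstm.parameters())
print(n_rnn, n_lstm, n_lstm / n_rnn)  # ratio is 4.0
```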
Forget Gate Bias Initialization
Important Training Detail
The forget gate bias is usually initialized to 1 or 2 (not 0). Reason: with bias = 0, sigmoid(0) = 0.5, so half the memory is forgotten from the very start and gradient flow suffers. With bias = 1, sigmoid(1) ≈ 0.73, so the cell initially tends to keep its memory, gradients flow through the cell state, and training improves. This is a key finding of Jozefowicz et al., 2015.
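In PyTorch this initialization must be done by hand. `nn.LSTM` packs gate parameters in (input, forget, cell, output) order, so the forget-gate slice of each bias vector is `[hidden : 2*hidden]`:

```python
import torch
import torch.nn as nn

hidden = 8
lstm = nn.LSTM(input_size=4, hidden_size=hidden, batch_first=True)

with torch.no_grad():
    for name, param in lstm.named_parameters():
        if "bias" in name:
            # forget-gate portion of the packed (i, f, g, o) bias vector
            param[hidden:2 * hidden].fill_(1.0)
```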
GRU (Gated Recurrent Unit)
GRU (Cho et al., 2014) simplifies LSTM by merging forget + input gates into one update gate and combining cell + hidden state:
Update Gate:

$$z_t = \sigma(W_z [h_{t-1}, x_t])$$

Reset Gate:

$$r_t = \sigma(W_r [h_{t-1}, x_t])$$

Hidden State:

$$\tilde{h}_t = \tanh(W [r_t \odot h_{t-1}, x_t]), \qquad h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$
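One GRU step, hand-rolled as a sketch (biases omitted; using the convention $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$ — some references swap the roles of $z_t$ and $1 - z_t$):

```python
import torch

torch.manual_seed(0)
d, h = 3, 4
W_z, W_r, W_h = (torch.randn(h, h + d) for _ in range(3))

x_t = torch.randn(d)
h_prev = torch.zeros(h)

z_t = torch.sigmoid(W_z @ torch.cat([h_prev, x_t]))        # update gate
r_t = torch.sigmoid(W_r @ torch.cat([h_prev, x_t]))        # reset gate
h_tilde = torch.tanh(W_h @ torch.cat([r_t * h_prev, x_t])) # candidate
h_t = (1 - z_t) * h_prev + z_t * h_tilde                   # interpolation
```

The update gate plays the role of the LSTM's forget and input gates at once: keeping old state and writing new state are tied to a single valve.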
GRU vs LSTM
| Aspect | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| States | Hidden + Cell (separate memory) | Hidden only (dual purpose) |
| Parameters | $4(h(h+d)+h)$ | $3(h(h+d)+h)$ — 25% fewer |
| Performance | Slightly better on long sequences | Comparable on most tasks |
| Training speed | Slower | Faster |
| When to use | Very long dependencies, memory-critical | Default choice when LSTM not clearly better |
When asked "GRU vs LSTM?" in an interview: GRU is simpler and faster, with comparable performance on most tasks. LSTM may be slightly better on very long sequences (the separate cell state provides more memory capacity). In practice, try GRU first and switch to LSTM only if it is not enough.
Bidirectional RNNs
Standard RNNs only use past context. Bidirectional RNNs run two separate RNNs — forward and backward — and concatenate:
Each position sees both past and future context → output dimension = 2 × hidden_size.
| Use Case | Bidirectional? | Why |
|---|---|---|
| NER / POS tagging | Yes | Needs context on both sides to disambiguate |
| Sentiment analysis | Yes | The whole sentence is already known |
| Language modeling | No | Autoregressive: cannot look at the future |
| Speech recognition (offline) | Yes | The full audio is already recorded |
| Real-time translation | No | Must process incrementally |
Bidirectional ≠ Always Better
Bidirectional RNNs only apply when the entire sequence is known up front. Autoregressive generation (language models, translation decoders) cannot use them: the backward RNN needs future tokens, which do not yet exist at generation time.
Sequence-to-Sequence (Seq2Seq)
Encoder-decoder architecture for tasks like machine translation:
- Encoder: Process input sequence → final hidden state $c$ (the context vector)
- Decoder: Generate output sequence one token at a time, conditioned on $c$
The Bottleneck Problem
Compressing the entire input sequence into a single fixed-size vector is a severe bottleneck: the longer the sequence, the more information is lost.
The attention mechanism (Bahdanau et al., 2014) solved this: at every step the decoder can attend to all encoder hidden states, not just the last one. This was the precursor to the Transformer.
Each decoder step gets its own context vector and can dynamically focus on the relevant parts of the input.
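A sketch of that per-step context computation (dot-product scoring for brevity; Bahdanau's original used an additive MLP score — the shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, h = 6, 8
enc_states = torch.randn(T, h)   # all encoder hidden states
dec_state = torch.randn(h)       # current decoder hidden state

scores = enc_states @ dec_state       # [T] similarity of each position
weights = F.softmax(scores, dim=0)    # attention distribution (sums to 1)
context = weights @ enc_states        # [h] context vector for THIS step
```

A different `dec_state` at the next step yields different `weights`, hence a different `context` — this is exactly what the fixed bottleneck vector could not do.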
Practical Architecture Patterns
Stacking Layers
```python
import torch.nn as nn

# Multi-layer LSTM (stacked)
lstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=3,        # 3 stacked LSTM layers
    batch_first=True,
    dropout=0.3,         # dropout between layers (not after the last)
    bidirectional=True,
)
# Output shape: [batch, seq_len, 2 × 256] (bidirectional)
# Hidden shape: [2 × 3, batch, 256]      (2 directions × 3 layers)
```
Packing Variable-Length Sequences
Real-world sequences have different lengths. Padding + packing handles this efficiently:
```python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# padded_input: [batch, max_len, features], zero-padded
# lengths: actual length of each sequence
packed = pack_padded_sequence(padded_input, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)
output, _ = pad_packed_sequence(output, batch_first=True)
# Packing tells the LSTM to skip padded positions → correct gradients
```
Without packing, the LSTM does meaningless computation on padding tokens, wasting compute and potentially polluting the hidden state.
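An end-to-end runnable version of the pad-then-pack flow, with made-up lengths:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Three variable-length sequences of 4-dim features
seqs = [torch.randn(n, 4) for n in (5, 3, 2)]
lengths = torch.tensor([5, 3, 2])
padded = pad_sequence(seqs, batch_first=True)   # [3, 5, 4], zero-padded

lstm = nn.LSTM(4, 8, batch_first=True)
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(out, batch_first=True)
print(out.shape, out_lengths)  # [3, 5, 8], original lengths recovered
```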
Common Patterns
| Pattern | Architecture | Use Case |
|---|---|---|
| Many-to-one | LSTM → take last hidden → FC | Sentiment analysis, classification |
| Many-to-many | LSTM → FC at each step | NER, POS tagging |
| Encoder-decoder | LSTM encoder → LSTM decoder | Translation, summarization |
| Encoder + attention | LSTM encoder → attention → decoder | Pre-Transformer MT |
Why Transformers Replaced RNNs
| Aspect | RNNs | Transformers |
|---|---|---|
| Parallelization | Sequential — step $t$ depends on step $t-1$ | Fully parallel — all positions computed simultaneously |
| Long-range dependencies | Difficult even with LSTM gates | Direct attention between any two positions |
| Training speed | Slow (sequential + limited GPU utilization) | Much faster on GPUs (parallel matrix ops) |
| Memory capacity | Fixed hidden state size | Attention scales with sequence length |
| Gradient flow | Vanishing/exploding (mitigated but not eliminated by gating) | Direct paths via residual connections |
| Positional info | Implicit (order comes from sequential processing) | Must be explicitly added (positional encoding) |
When RNNs Still Win
| Scenario | Why RNN | Example |
|---|---|---|
| Streaming / online | Data arrives one step at a time → RNN naturally incremental | Real-time speech, live sensor data |
| Very long sequences | Transformer attention is $O(n^2)$; an RNN is $O(1)$ per step | Genome sequences (millions of bases) |
| Edge / mobile | RNNs smaller, simpler, less memory | On-device keyboards, embedded systems |
| Sequential decision | Hidden state = natural "world model" | RL environments |
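The streaming advantage is concrete: an RNN fed one step at a time, carrying its hidden state forward, produces exactly the same outputs as a full-sequence pass (a sketch with a GRU):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(4, 8, batch_first=True)
x = torch.randn(1, 10, 4)

full_out, _ = gru(x)                  # offline: whole sequence at once

h = None
stream_outs = []
for t in range(10):                   # online: one step per call
    out, h = gru(x[:, t:t + 1], h)    # hidden state carried across calls
    stream_outs.append(out)
stream_out = torch.cat(stream_outs, dim=1)
print(torch.allclose(full_out, stream_out, atol=1e-5))  # True
```

A Transformer has no equivalent trick: each new token requires attending over (a cache of) all previous tokens.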
RNN vs Transformer in Interviews
Don't say "RNNs are useless now." The correct answer: "Transformers are better for most NLP/sequence tasks (parallelism + long-range dependencies). But RNNs still have advantages in streaming scenarios (real-time inference), very long sequences (O(n) vs O(n²)), and resource-constrained environments."
Real-World Use Cases
Case 1: Credit Card Fraud Detection on Transaction Sequences
Fraud patterns are often sequential: no single transaction looks wrong, but the pattern across a series of transactions is anomalous.
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# Transaction features: [amount, merchant_cat, time_diff, is_international, ...]
# Each user has a variable-length sequence of recent transactions
class FraudLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, lengths):
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        return torch.sigmoid(self.fc(h_n.squeeze(0)))  # P(fraud)
```
Interview follow-up: "Why not a Transformer?" You could, but fraud-detection sequences are usually short (the last 50-100 transactions), so an LSTM is sufficient. Real-time scoring also needs incremental processing (update the hidden state as each new transaction arrives), which is more natural with an RNN's online inference.
Case 2: Recommender Systems — Session-Based Recommendation
Predict the next item a user will click from their browsing session (click sequence):
| Approach | Model | How |
|---|---|---|
| GRU4Rec | GRU | Each click in the session updates the hidden state → predict the next item |
| SASRec | Transformer | Self-attention over the click sequence → predict the next item |
| BERT4Rec | Bidirectional Transformer | Masked item prediction (like MLM) |
GRU4Rec is the seminal work in session-based recommendation. Transformer-based methods (SASRec, BERT4Rec) are usually better on longer sessions.
Case 3: Time Series Forecasting
LSTM used to be the standard for time series forecasting (pre-Transformer era):
```python
import torch.nn as nn

class TimeSeriesLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, forecast_horizon):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_dim, forecast_horizon)

    def forward(self, x):
        # x: [batch, lookback_window, features]
        output, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])  # use the last layer's final hidden state
```
Interview follow-up: "LSTM vs traditional time series methods (ARIMA, Prophet)?" LSTMs can handle multivariate inputs and non-linear patterns. But ARIMA/Prophet are often better and more interpretable on univariate, well-structured series. The modern trend is Temporal Fusion Transformer or N-BEATS.
Hands-on: RNN & LSTM in PyTorch
Basic LSTM
```python
import torch
import torch.nn as nn

seq_len, batch, input_size, hidden_size = 50, 1, 3, 8

# Vanilla RNN
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
x = torch.randn(batch, seq_len, input_size)
output_rnn, h_rnn = rnn(x)
# h_rnn: [1, 1, 8] — single hidden state

# LSTM
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
output_lstm, (h_lstm, c_lstm) = lstm(x)
# h_lstm: hidden state, c_lstm: cell state

# Bidirectional LSTM
bilstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True, bidirectional=True)
output_bi, _ = bilstm(x)
# output_bi: [1, 50, 16] — 2 × hidden_size
```
Text Classification with LSTM
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_dim * 2, num_classes),  # 2x for bidirectional
        )

    def forward(self, x, lengths):
        emb = self.embedding(x)
        packed = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n: [2, batch, hidden] for a single bidirectional layer
        # Concatenate the forward and backward final hidden states
        h = torch.cat([h_n[0], h_n[1]], dim=1)  # [batch, hidden*2]
        return self.fc(h)
```
Interview Signals
What interviewers listen for:
- You can explain the mathematical cause of vanishing gradients (Jacobian product → exponential shrinkage)
- You know that the LSTM's additive cell state update parallels ResNet skip connections
- You can compare LSTM vs GRU tradeoffs (not just "GRU is faster")
- You know why Transformers won, and also where RNNs are still useful
- You understand the limitation of bidirectional RNNs (unusable for autoregressive generation)
Practice
Flashcards
What do the LSTM's three gates each control?
Forget gate: decides what to erase from the cell state (sigmoid ≈ 0 → forget). Input gate: decides what new information to write into the cell state. Output gate: decides which part of the cell state to expose as the hidden state output. All three gates use sigmoid (0-1 → valve control).
Quiz
Why can't vanilla RNNs learn long-range dependencies effectively?