RNN & LSTM

Interview Context

Although Transformers have replaced RNNs in most settings, RNN/LSTM remains a high-frequency interview topic because it tests your understanding of sequential processing, vanishing gradients, and gating mechanisms. You also need to explain clearly why Transformers won.

Vanilla RNN

A recurrent neural network processes sequences by maintaining a hidden state updated at each time step:

$$\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t + \mathbf{b}_h) \qquad \mathbf{y}_t = \mathbf{W}_{hy}\mathbf{h}_t + \mathbf{b}_y$$

The hidden state $\mathbf{h}_t$ is the network's "memory": it encodes information from all previous inputs. The same weight matrices are shared across all time steps (parameter sharing).
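The recurrence is easy to sketch. Below is a minimal scalar version of one RNN step (the weights are arbitrary illustrative values, not trained parameters):

```python
import math

def rnn_step(h_prev, x_t, w_hh=0.5, w_xh=1.0, b_h=0.0):
    """One vanilla RNN step: h_t = tanh(w_hh * h_prev + w_xh * x_t + b_h)."""
    return math.tanh(w_hh * h_prev + w_xh * x_t + b_h)

# Unroll over a short input sequence, threading the hidden state through
h = 0.0
for x_t in [1.0, -0.5, 0.25]:
    h = rnn_step(h, x_t)
# h now encodes the whole sequence in a single number (the "memory")
```

Note that the same `w_hh`, `w_xh`, `b_h` are reused at every step: that is parameter sharing.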

Key Properties

| Property | Detail |
|---|---|
| Arbitrary length | Handles sequences of any length |
| Fixed hidden size | Hidden state dimension stays the same no matter how long the sequence is |
| Parameter sharing | All time steps share one set of weights, so parameter count is independent of sequence length |
| Sequential | Must compute step by step, so it cannot be parallelized and is slow |

Backpropagation Through Time (BPTT)

To train an RNN, "unroll" it through time and apply backpropagation:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^{T} \frac{\partial \mathcal{L}_t}{\partial \mathbf{W}_{hh}}$$

Each term involves a product of Jacobians:

$$\frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_k} = \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t} \prod_{i=k+1}^{t} \frac{\partial \mathbf{h}_i}{\partial \mathbf{h}_{i-1}}$$

Each factor $\frac{\partial \mathbf{h}_i}{\partial \mathbf{h}_{i-1}} = \text{diag}(\tanh'(\mathbf{z}_i)) \cdot \mathbf{W}_{hh}$ involves the weight matrix and the activation derivative.

The Vanishing Gradient Problem

Since $|\tanh'(x)| \leq 1$ and typically $\|\mathbf{W}_{hh}\| < 1$, the gradient product shrinks exponentially:

$$\left\|\prod_{i=k+1}^{t}\frac{\partial \mathbf{h}_i}{\partial \mathbf{h}_{i-1}}\right\| \leq \gamma^{t-k}, \quad \gamma < 1$$

For long sequences ($t - k \gg 1$) the gradient effectively vanishes, so the network cannot learn long-range dependencies.
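The shrinkage is easy to see numerically. A toy sketch with scalar Jacobian factors (the pre-activation 0.5 and weight 0.9 are arbitrary illustrative values):

```python
import math

def tanh_prime(z):
    return 1.0 - math.tanh(z) ** 2  # tanh'(z), always <= 1

w_hh = 0.9   # "spectral norm" below 1 in this scalar toy
grad = 1.0
for _ in range(50):                  # 50 time steps back
    grad *= tanh_prime(0.5) * w_hh   # each factor |tanh'(z) * w_hh| < 1

print(grad)  # vanishingly small after 50 steps
```

Each factor is about 0.7 here, so after 50 steps the gradient has shrunk by roughly $0.7^{50}$, far below anything useful for learning.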

Practical Impact

A vanilla RNN typically cannot learn dependencies longer than about 10-20 time steps. Tasks that need information from 50 or 100 steps back (long documents, dialogue) make a vanilla RNN fail completely. This is the core motivation for LSTM and GRU.

Truncated BPTT: in practice, backpropagate only a fixed number of steps (e.g. 35), trading long-range learning for computational efficiency.
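In PyTorch, truncated BPTT is commonly implemented by detaching the hidden state at chunk boundaries. A minimal sketch (the chunk length 35, tensor sizes, and the dummy squared-output loss are all illustrative):

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(2, 105, 4)  # [batch, seq_len, features]

h = None
chunk = 35  # truncation length: backprop at most 35 steps
for start in range(0, x.size(1), chunk):
    out, h = rnn(x[:, start : start + chunk], h)
    loss = out.pow(2).mean()  # dummy loss for illustration
    loss.backward()
    h = h.detach()            # cut the graph so the next chunk starts fresh
```

The `detach()` call is what enforces the truncation: gradients never flow across a chunk boundary, while the hidden state values still carry forward.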

LSTM (Long Short-Term Memory)

LSTMs (Hochreiter & Schmidhuber, 1997) address the vanishing gradient problem with the cell state $\mathbf{c}_t$, a separate memory pathway with carefully controlled gates.

The Three Gates + Cell State

Forget Gate — decides what to erase from memory:

$$\mathbf{f}_t = \sigma(\mathbf{W}_f[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$$

Input Gate — decides what new information to store:

$$\mathbf{i}_t = \sigma(\mathbf{W}_i[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) \qquad \tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$$

Cell State Update — the critical equation:

$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$

Output Gate — decides what to expose from memory:

$$\mathbf{o}_t = \sigma(\mathbf{W}_o[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o) \qquad \mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$
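The four equations compose into a single step. Here is a scalar sketch of one LSTM step (all weights are set to 1 purely for illustration; a real cell has a separate weight matrix per gate):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(h_prev, c_prev, x_t):
    f = sigmoid(h_prev + x_t + 1.0)       # forget gate (bias 1: keep by default)
    i = sigmoid(h_prev + x_t)             # input gate
    c_tilde = math.tanh(h_prev + x_t)     # candidate values
    o = sigmoid(h_prev + x_t)             # output gate
    c = f * c_prev + i * c_tilde          # additive cell state update
    h = o * math.tanh(c)                  # expose part of the memory
    return h, c

h, c = 0.0, 0.0
for x_t in [1.0, -0.5, 0.25]:
    h, c = lstm_step(h, c, x_t)
```

Note how `c` is only ever scaled by `f` and added to, never squashed through a weight matrix: that is the gradient highway.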

Gate Intuition

| Gate | Sigmoid Output | Intuition |
|---|---|---|
| Forget $\mathbf{f}_t$ | Near 1 = keep, near 0 = erase | "Is this memory still useful?" |
| Input $\mathbf{i}_t$ | Near 1 = write, near 0 = ignore | "Is this new information worth remembering?" |
| Output $\mathbf{o}_t$ | Near 1 = expose, near 0 = hide | "Do we need this memory right now?" |

Sigmoid outputs lie in (0, 1), acting as valves on information flow. Tanh outputs lie in (-1, 1) and produce the candidate values.

Why LSTMs Solve Vanishing Gradients

The cell state update is additive (not multiplicative):

$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \mathbf{f}_t$$

When $\mathbf{f}_t \approx 1$ (forget gate open), gradients flow through unchanged, with no multiplicative shrinkage.

This is the same principle as ResNet's skip connections: an additive path lets the gradient pass through directly instead of being repeatedly multiplied and shrunk.

LSTM Parameter Count

An LSTM with hidden size $h$ and input size $x$ has $4 \times (h \times (h + x) + h)$ parameters, 4x a vanilla RNN, because there are four gate weight matrices (forget, input, cell candidate, output).
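A quick sanity check of the formula (one aside: PyTorch's `nn.LSTM` stores two bias vectors per layer, `bias_ih` and `bias_hh`, so its actual count is larger than this textbook formula by $4h$):

```python
def lstm_params(h, x):
    """Textbook LSTM count: 4 gate blocks of h*(h+x) weights plus h biases each."""
    return 4 * (h * (h + x) + h)

def rnn_params(h, x):
    """Vanilla RNN: a single block of the same shape."""
    return h * (h + x) + h

h, x = 256, 128
print(lstm_params(h, x))                       # parameter count for h=256, x=128
print(lstm_params(h, x) // rnn_params(h, x))   # exactly 4x the vanilla RNN
```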

Forget Gate Bias Initialization

Important Training Detail

The forget gate bias is usually initialized to 1 or 2 (not 0). Reason: with bias = 0, sigmoid(0) = 0.5, so the cell forgets half its memory from the very start and gradient flow suffers. With bias = 1, sigmoid(1) ≈ 0.73, so the cell initially tends to keep its memory, gradients flow through the cell state, and training improves. This is an important finding from Jozefowicz et al., 2015.
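In PyTorch this initialization can be applied by hand. `nn.LSTM` lays out the gate blocks in the order (input, forget, cell, output), so the forget-gate slice is the second quarter of each bias vector (the sizes below are illustrative):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=128, hidden_size=256)

for name, param in lstm.named_parameters():
    if "bias" in name:                 # bias_ih_l0 and bias_hh_l0, each of size 4*hidden
        hidden = param.size(0) // 4
        with torch.no_grad():
            param[hidden : 2 * hidden].fill_(1.0)  # forget-gate block -> 1
```

Since both bias vectors are summed inside the cell, setting each to 1 gives an effective forget bias of 2; set only one of them if you want exactly 1.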

GRU (Gated Recurrent Unit)

GRU (Cho et al., 2014) simplifies the LSTM by merging the forget and input gates into a single update gate and combining the cell and hidden states:

Update Gate: $\mathbf{z}_t = \sigma(\mathbf{W}_z[\mathbf{h}_{t-1}, \mathbf{x}_t])$

Reset Gate: $\mathbf{r}_t = \sigma(\mathbf{W}_r[\mathbf{h}_{t-1}, \mathbf{x}_t])$

Hidden State:

$$\tilde{\mathbf{h}}_t = \tanh(\mathbf{W}[\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t]) \qquad \mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$$
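The update gate makes the new hidden state an interpolation between the old state and the candidate. A scalar sketch (all weights set to 1 for illustration):

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gru_step(h_prev, x_t):
    z = sigmoid(h_prev + x_t)               # update gate: how much to rewrite
    r = sigmoid(h_prev + x_t)               # reset gate: how much old state feeds the candidate
    h_tilde = math.tanh(r * h_prev + x_t)   # candidate hidden state
    return (1.0 - z) * h_prev + z * h_tilde # interpolate old vs new

h = 0.0
for x_t in [1.0, -0.5, 0.25]:
    h = gru_step(h, x_t)
```

When `z` is near 0 the old state passes through almost untouched, which is the GRU's version of the LSTM's gradient highway.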

GRU vs LSTM

| Aspect | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| States | Hidden + cell (separate memory) | Hidden only (dual purpose) |
| Parameters | $4(h(h+x)+h)$ | $3(h(h+x)+h)$, 25% fewer |
| Performance | Slightly better on long sequences | Comparable on most tasks |
| Training speed | Slower | Faster |
| When to use | Very long dependencies, memory-critical | Default choice when LSTM is not clearly better |

A common interview question is "GRU vs LSTM?" GRU is simpler and faster, with comparable performance on most tasks. LSTM may be slightly better on very long sequences because the separate cell state provides more memory capacity. In practice, try GRU first and switch to LSTM only if it falls short.

Bidirectional RNNs

Standard RNNs only use past context. Bidirectional RNNs run two separate RNNs — forward and backward — and concatenate:

$$\mathbf{h}_t = [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]$$

Each position sees both past and future context, so the output dimension is 2 × hidden_size.

| Use Case | Bidirectional? | Why |
|---|---|---|
| NER / POS tagging | Yes | Needs left and right context to decide the tag |
| Sentiment analysis | Yes | The whole sentence is already known |
| Language modeling | No | Autoregressive; cannot look at the future |
| Speech recognition (offline) | Yes | The whole audio clip is already recorded |
| Real-time translation | No | Must process incrementally |

Bidirectional ≠ Always Better

Bidirectional RNNs only work when the entire sequence is available up front. Autoregressive generation (language models, translation decoders) cannot use them: the backward RNN needs future tokens, which do not exist yet at generation time.

Sequence-to-Sequence (Seq2Seq)

Encoder-decoder architecture for tasks like machine translation:

  1. Encoder: process the input sequence into a final hidden state $\mathbf{c}$ (the context vector)
  2. Decoder: generate the output sequence one token at a time, conditioned on $\mathbf{c}$

$$\mathbf{c} = \text{Encoder}(x_1, \ldots, x_T) \qquad y_t = \text{Decoder}(\mathbf{c}, y_1, \ldots, y_{t-1})$$

The Bottleneck Problem

Compressing the entire input sequence into a single fixed-size vector $\mathbf{c}$ creates a severe bottleneck: the longer the sequence, the more information is lost.

The attention mechanism (Bahdanau et al., 2014) solved this: at each step the decoder can attend to all encoder hidden states, not just the last one. This was the precursor to the Transformer.

$$\alpha_{tj} = \text{softmax}(e_{tj}), \quad \mathbf{c}_t = \sum_j \alpha_{tj} \mathbf{h}_j$$

Each decoder step $t$ gets its own context vector $\mathbf{c}_t$, dynamically focusing on the relevant parts of the input.
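The attention computation is just a softmax over alignment scores followed by a weighted sum. A minimal sketch with scalar encoder states and made-up scores (the values are purely illustrative):

```python
import math

def attend(scores, encoder_states):
    """Softmax the alignment scores e_tj, then take a weighted sum of the h_j."""
    m = max(scores)
    exps = [math.exp(e - m) for e in scores]   # numerically stable softmax
    total = sum(exps)
    alphas = [e / total for e in exps]         # attention weights, sum to 1
    context = sum(a * h for a, h in zip(alphas, encoder_states))
    return alphas, context

# One decoder step: scores against 4 encoder positions
alphas, c_t = attend([0.1, 2.0, -1.0, 0.5], [0.9, -0.3, 0.7, 0.2])
```

Because the weights sum to 1, the context vector is a convex combination of the encoder states, dominated by the positions with the highest scores.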

Practical Architecture Patterns

Stacking Layers

```python
import torch.nn as nn

# Multi-layer LSTM (stacked)
lstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=3,       # 3 stacked LSTM layers
    batch_first=True,
    dropout=0.3,        # dropout between layers (not after the last)
    bidirectional=True,
)
# Output shape: [batch, seq_len, 2 * 256] (bidirectional)
# Hidden shape: [2 * 3, batch, 256] (2 directions x 3 layers)
```

Packing Variable-Length Sequences

Real-world sequences have different lengths. Padding + packing handles this efficiently:

```python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# padded_input: [batch, max_len, features], zero-padded to a common length
# lengths: actual length of each sequence
packed = pack_padded_sequence(padded_input, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)
output, _ = pad_packed_sequence(output, batch_first=True)
# Packing tells the LSTM to skip padded positions -> correct gradients
```

Without packing, the LSTM performs meaningless computation on the padding tokens, wasting compute and potentially polluting the hidden state.

Common Patterns

| Pattern | Architecture | Use Case |
|---|---|---|
| Many-to-one | LSTM → take last hidden → FC | Sentiment analysis, classification |
| Many-to-many | LSTM → FC at each step | NER, POS tagging |
| Encoder-decoder | LSTM encoder → LSTM decoder | Translation, summarization |
| Encoder + attention | LSTM encoder → attention → decoder | Pre-Transformer MT |

Why Transformers Replaced RNNs

| Aspect | RNNs | Transformers |
|---|---|---|
| Parallelization | Sequential: step $t$ depends on $t-1$ | Fully parallel: all positions computed simultaneously |
| Long-range dependencies | Difficult even with LSTM gates | Direct attention between any two positions |
| Training speed | Slow (sequential, limited GPU utilization) | Much faster on GPUs (parallel matrix ops) |
| Memory capacity | Fixed hidden state size | Attention scales with sequence length |
| Gradient flow | Vanishing/exploding (even with gating) | Direct paths via residual connections |
| Positional info | Implicit (order comes from sequential processing) | Must be added explicitly (positional encoding) |

When RNNs Still Win

| Scenario | Why RNN | Example |
|---|---|---|
| Streaming / online | Data arrives one step at a time; an RNN is naturally incremental | Real-time speech, live sensor data |
| Very long sequences | Transformer attention is $O(n^2)$; an RNN is $O(n)$ per step | Genome sequences (millions of bases) |
| Edge / mobile | RNNs are smaller, simpler, and use less memory | On-device keyboards, embedded systems |
| Sequential decision | Hidden state is a natural "world model" | RL environments |

RNN vs Transformer in Interviews

Don't say "RNNs are useless now." A better answer: "For most NLP/sequence tasks, Transformers are the better choice (parallelism plus long-range dependencies). But RNNs still have advantages in streaming scenarios (real-time inference), very long sequences (O(n) vs O(n²)), and resource-constrained environments."

Real-World Use Cases

Case 1: Credit Card Fraud Detection on Transaction Sequences

Fraud patterns are often sequential: the problem is usually not a single transaction, but an anomalous pattern across a series of them.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# Transaction sequence: [amount, merchant_cat, time_diff, is_international, ...]
# Each user has a variable-length sequence of recent transactions
class FraudLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, lengths):
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        return torch.sigmoid(self.fc(h_n.squeeze(0)))  # P(fraud)
```

Interview follow-up: "Why not a Transformer?" You could use one, but fraud-detection sequences are usually short (the most recent 50-100 transactions), so an LSTM is sufficient. Real-time scoring also calls for incremental processing (update the hidden state as each new transaction arrives), which an RNN's online inference handles more naturally.

Case 2: Recommender Systems and Session-Based Recommendation

Use the user's browsing session (click sequence) to predict the next item they will click:

| Approach | Model | How |
|---|---|---|
| GRU4Rec | GRU | Each click in the session updates the hidden state → predict the next item |
| SASRec | Transformer | Self-attention over the click sequence → predict the next item |
| BERT4Rec | Bidirectional Transformer | Masked item prediction (like MLM) |

GRU4Rec was the pioneering work in session-based recommendation. Transformer-based methods (SASRec, BERT4Rec) are usually better on longer sessions.

Case 3: Time Series Forecasting

LSTM was once the standard for time series forecasting (in the pre-Transformer era):

```python
import torch.nn as nn

class TimeSeriesLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, forecast_horizon):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_dim, forecast_horizon)

    def forward(self, x):
        # x: [batch, lookback_window, features]
        output, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])  # use the last layer's hidden state
```

Interview follow-up: "How does an LSTM compare with traditional time series methods (ARIMA, Prophet)?" An LSTM can handle multivariate inputs and non-linear patterns, but ARIMA/Prophet are often better and more interpretable on univariate, well-structured series. The modern trend is toward Temporal Fusion Transformer or N-BEATS.

Hands-on: RNN & LSTM in PyTorch

Basic LSTM

```python
import torch
import torch.nn as nn

seq_len, batch, input_size, hidden_size = 50, 1, 3, 8

# Vanilla RNN
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
x = torch.randn(batch, seq_len, input_size)
output_rnn, h_rnn = rnn(x)
# h_rnn: [1, 1, 8], a single hidden state

# LSTM
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
output_lstm, (h_lstm, c_lstm) = lstm(x)
# h_lstm: hidden state, c_lstm: cell state

# Bidirectional LSTM
bilstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True, bidirectional=True)
output_bi, _ = bilstm(x)
# output: [1, 50, 16], i.e. 2 * hidden_size
```

Text Classification with LSTM

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_dim * 2, num_classes),  # 2x for bidirectional
        )

    def forward(self, x, lengths):
        emb = self.embedding(x)
        packed = pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n: [2, batch, hidden] for a single bidirectional layer
        # Concatenate the forward and backward final hidden states
        h = torch.cat([h_n[0], h_n[1]], dim=1)  # [batch, hidden * 2]
        return self.fc(h)
```

Interview Signals

What interviewers listen for:

  • You can explain the mathematical cause of vanishing gradients (Jacobian product → exponential shrinkage)
  • You know that the LSTM's additive cell state update mirrors ResNet skip connections
  • You can compare the LSTM/GRU tradeoffs (not just "GRU is faster")
  • You know why Transformers won, and also where RNNs are still useful
  • You understand the limitation of bidirectional RNNs (unusable for autoregressive generation)

Practice

Flashcards


What does each of the three LSTM gates control?

Forget gate: decides what to erase from the cell state (sigmoid ≈ 0 → forget). Input gate: decides what new information to write into the cell state. Output gate: decides which part of the cell state to expose in the hidden state output. All three gates use a sigmoid (range 0-1, acting as a valve).


Quiz

Question 1/10

Why can't vanilla RNNs learn long-range dependencies effectively?
