1. 前言
上一篇我们已经从原理上认识了LSTM(长短期记忆网络):
它是门控循环神经网络
它引入了独立的记忆单元
C_t它通过遗忘门、输入门、输出门管理信息流
它比基础 RNN 更适合处理长期依赖
这一篇就继续按照李沐的节奏,把这些公式真正落到代码上。
这一节最核心的任务就是看清楚:
LSTM 从零实现时到底要多写哪些参数
H_t和C_t在代码里怎么同时维护三个门和候选记忆如何逐步计算
PyTorch 里的
nn.LSTM又把哪些部分封装掉了
你会发现,LSTM 代码虽然比 RNN、GRU 更长一些,但逻辑其实非常清晰:
先算门,再更新记忆单元,最后从记忆单元里读出隐藏状态。
2. LSTM 从零实现需要解决什么
如果把这一节拆开看,核心其实还是 4 件事,只是比 GRU 多维护一个状态。
2.1 初始化更多参数
LSTM 需要为:
输入门
遗忘门
输出门
候选记忆
分别准备参数。
2.2 同时维护两个状态
不是只有隐藏状态H_t,还要维护记忆单元C_t。
2.3 按照门控公式更新状态
先更新记忆单元,再更新隐藏状态。
2.4 保持语言模型接口一致
虽然内部更复杂,但外部仍然要能接:
输入序列
初始状态
输出结果
最终状态
所以从工程角度看,LSTM 还是一个“循环单元替换”的问题。
3. 先回顾 LSTM 的核心公式
写代码前,先把最重要的几条公式再摆一下。
输入门
I_t = σ(X_t W_xi + H_{t-1} W_hi + b_i)遗忘门
F_t = σ(X_t W_xf + H_{t-1} W_hf + b_f)输出门
O_t = σ(X_t W_xo + H_{t-1} W_ho + b_o)候选记忆
C_t_tilde = tanh(X_t W_xc + H_{t-1} W_hc + b_c)当前记忆单元
C_t = F_t ⊙ C_{t-1} + I_t ⊙ C_t_tilde当前隐藏状态
H_t = O_t ⊙ tanh(C_t)这六步,就是整段 LSTM 前向传播的主线。
4. 从零实现:先初始化参数
LSTM 比 GRU 还要多一组门,所以参数自然更多。
常见写法如下:
def get_lstm_params(vocab_size, num_hiddens, device): num_inputs = num_outputs = vocab_size def normal(shape): return torch.randn(size=shape, device=device) * 0.01 def three(): return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), torch.zeros(num_hiddens, device=device)) W_xi, W_hi, b_i = three() # 输入门 W_xf, W_hf, b_f = three() # 遗忘门 W_xo, W_ho, b_o = three() # 输出门 W_xc, W_hc, b_c = three() # 候选记忆 W_hq = normal((num_hiddens, num_outputs)) b_q = torch.zeros(num_outputs, device=device) params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] for param in params: param.requires_grad_(True) return params这段代码一眼看上去比 RNN、GRU 都长,但结构很规律。
5. 为什么这里是四组核心参数
因为 LSTM 每个时间步要计算四个关键中间量:
输入门
遗忘门
输出门
候选记忆
每一个都需要:
输入到该单元的权重
隐藏状态到该单元的权重
对应偏置
所以会有四组:
(W_x?, W_h?, b_?)这和 GRU 的三组参数是一个套路,只是 LSTM 多了一组。
所以你可以这样记:
RNN:一组
GRU:三组
LSTM:四组
门控越细,参数组越多。
6. LSTM 的状态初始化为什么和前面不同
LSTM 最大的代码差异之一,就是状态不止一个了。
基础 RNN 和 GRU 只返回:
隐藏状态
H
但 LSTM 需要同时维护:
隐藏状态
H记忆单元
C
所以初始化通常写成:
def init_lstm_state(batch_size, num_hiddens, device): return (torch.zeros((batch_size, num_hiddens), device=device), torch.zeros((batch_size, num_hiddens), device=device))也就是说,初始状态是一个二元组:
(H_0, C_0)而且两者初始通常都设为 0。
7. 为什么要同时初始化H和C
因为这两个状态职责不同。
H
当前时刻的隐藏表示,更多用于输出和短期信息表达。
C
长期记忆通道,更多用于跨时间步保存稳定信息。
所以在 LSTM 里,“状态”不是一个向量,而是一对向量。
这点在代码里非常关键,后面前向传播时一定要记得同时更新它们。
8. LSTM 前向传播是这一节的核心
常见从零实现写法如下:
def lstm(inputs, state, params): [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params H, C = state outputs = [] for X in inputs: I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i) F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f) O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o) C_tilde = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c) C = F * C + I * C_tilde H = O * torch.tanh(C) Y = torch.mm(H, W_hq) + b_q outputs.append(Y) return torch.cat(outputs, dim=0), (H, C)这段代码就是 LSTM 从零实现最关键的主体。
9. 输入门代码怎么理解
先看:
I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)它对应输入门公式:
I_t = σ(X_t W_xi + H_{t-1} W_hi + b_i)它控制的是:
当前时刻新候选记忆,允许写入多少。
如果I大,说明当前输入带来的新信息很重要。
如果I小,说明当前新内容不太值得写进长期记忆。
10. 遗忘门代码怎么理解
再看:
F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)它对应:
F_t = σ(X_t W_xf + H_{t-1} W_hf + b_f)它控制的是:
旧记忆
C_{t-1}中还要留下多少。
所以它本质上是在做“保留还是遗忘”的判断。
如果F接近 1,就说明旧记忆大部分保留。
如果接近 0,就说明旧记忆大部分被清掉。
11. 输出门代码怎么理解
这一句:
O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)对应:
O_t = σ(X_t W_xo + H_{t-1} W_ho + b_o)它的作用是:
决定当前记忆单元里有多少信息,要通过隐藏状态对外输出。
你可以理解为:
记忆单元
C是“仓库”输出门
O决定当前要“拿出多少货”
12. 候选记忆代码怎么理解
这一句:
C_tilde = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)对应:
C_t_tilde = tanh(X_t W_xc + H_{t-1} W_hc + b_c)它是当前时刻生成的一份“新记忆草稿”。
注意,它不会直接成为C_t,
而是还要由输入门决定写进去多少。
所以:
C_tilde是候选内容I是写入强度
这和 GRU 里的候选隐藏状态有一点相似,但 LSTM 的结构更细。
13. 记忆单元更新是最关键的一步
这一句一定要看懂:
C = F * C + I * C_tilde它对应:
C_t = F_t ⊙ C_{t-1} + I_t ⊙ C_t_tilde意思非常明确:
当前记忆 = 保留下来的旧记忆 + 写入的新候选记忆
这条式子其实就是 LSTM 成功的核心。
因为它给长期记忆开辟了一条非常直接的传播路径。
旧记忆可以在门控控制下平滑流动,而不是每一步都被强制重算。
14. 隐藏状态更新为什么还要依赖C
再看:
H = O * torch.tanh(C)这对应:
H_t = O_t ⊙ tanh(C_t)它说明:
隐藏状态
H并不是独立算的它是从当前记忆单元
C中读出来的
所以 LSTM 里真正更核心的状态其实是C,
而H更像是当前时刻对外暴露的“工作结果”。
你可以理解为:
C更偏长期存储H更偏当前输出接口
15. 输出层为什么和前面一样
这一句:
Y = torch.mm(H, W_hq) + b_q和 RNN、GRU 完全一样。
原因也一样:
LSTM 改进的是内部记忆更新方式,
但语言模型最终仍然需要:
根据当前隐藏状态,对整个词表做分类打分
所以输出头部分并没有本质变化。
也就是说,循环单元可以换,但语言模型输出逻辑是统一的。
16. 从零实现时可以继续复用手写模型容器
和前面一样,通常可以直接用类似的封装:
net = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_lstm_params, init_lstm_state, lstm)这点特别漂亮。
你会发现:
容器类不需要改
输入 one-hot 不需要改
输出预测接口不需要改
只要换掉:
参数初始化函数
状态初始化函数
单步递推函数
模型就从 RNN / GRU 切换成了 LSTM。
这说明什么?
说明循环模型之间的差别,核心就在于内部单元设计。
17. 简洁实现:PyTorch 里的nn.LSTM
从零实现看懂之后,简洁实现就非常自然了。
PyTorch 直接封装了:
nn.LSTM最常见写法类似:
lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)它的对外用法和nn.RNN、nn.GRU非常相似:
Y, (H, C) = lstm_layer(X, (H0, C0))注意这里最大的区别:
返回的不再只是一个
state,而是(H, C)两部分。
这正对应了我们前面从零实现时手动维护的两个状态。
18. 为什么nn.LSTM的状态结构最特别
因为 LSTM 不像 RNN / GRU 只有一个隐藏状态。
它有两个状态:
H:隐藏状态C:记忆单元
所以在简洁实现里,初始化状态时也要写成:
(H0, C0)而不能只传一个张量。
这是使用nn.LSTM时最容易和nn.GRU、nn.RNN混的地方之一。
19. 简洁实现里的模型封装怎么处理
和简洁版 RNN、GRU 一样,LSTM 外面通常也会再包一层语言模型头。
也就是说:
nn.LSTM负责生成每个时间步的隐藏表示nn.Linear负责把隐藏表示映射到词表大小
所以整体模型结构仍然是:
输入 token
one-hot
LSTM
reshape
linear
输出词表打分
外层套路并没有变,变的是内部状态变复杂了。
20. LSTM 代码和 GRU 代码最本质的区别在哪里
这是这篇最该说透的一点。
GRU
一个状态
H两个门
一个候选状态
最终状态是旧
H和新候选状态的融合
LSTM
两个状态
H和C三个门
一个候选记忆
先更新
C,再从C读出H
所以从代码结构上,LSTM 比 GRU 最大的不同就是:
它把“记忆存储”和“输出表达”拆开了。
这也是为什么它更系统、更经典,但代码也更长一点。
21. LSTM 训练流程和前面有什么变化吗
整体训练逻辑几乎不变。
还是:
输入 token 序列
输出下一个 token 的预测
用交叉熵损失
反向传播
梯度裁剪
参数更新
变化的不是训练外壳,而是:
状态更新公式更复杂,且要同时维护
(H, C)
所以,如果你前面已经把 RNN、GRU 训练管道搞清楚了,
LSTM 在训练框架上其实不会让你特别陌生。
22. 这一节最该掌握什么
如果从学习重点来看,最关键的是这几件事。
22.1 明白为什么要有四组参数
因为输入门、遗忘门、输出门、候选记忆都要单独计算。
22.2 看懂C和H的职责分工
C:长期记忆H:当前输出表征
22.3 看懂C = F * C + I * C_tilde
这是 LSTM 最核心的一行代码。
22.4 知道nn.LSTM的状态是(H, C)
这是和nn.RNN、nn.GRU最大的接口差别。
22.5 把它和 GRU 对照起来理解
这样你不会只记住“LSTM 更复杂”,而是真知道复杂在哪。
23. 本节总结
这一节我们学习了 LSTM 的代码实现,核心内容可以总结为以下几点。
23.1 LSTM 从零实现需要四组核心参数
分别对应:
输入门
遗忘门
输出门
候选记忆
23.2 LSTM 同时维护隐藏状态和记忆单元
这让长期记忆流动更稳定。
23.3 记忆单元更新是 LSTM 的关键
旧记忆和新候选记忆在门控作用下融合。
23.4 隐藏状态是从记忆单元中读出来的
因此H和C职责不同。
23.5nn.LSTM是对这一整套机制的框架封装
使用时最需要注意的是状态结构为(H, C)。
24. 学习感悟
LSTM 代码这一节非常有代表性,因为它让你真正看到:
一个模型变强,不一定是因为网络更深了,而可能是因为“状态管理更精细了”。
RNN 像是在“顺着时间流动”;
GRU 像是在“加门控做筛选”;
而 LSTM 则进一步把记忆流程分成:
擦除
写入
读取
这已经非常像一个真正的记忆系统了。
所以学完 LSTM 代码之后,你对“循环神经网络为什么能记忆”这件事,理解会明显更深一层。