news 2026/4/30 22:12:57

动手学深度学习——LSTM代码

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
动手学深度学习——LSTM代码

1. 前言

上一篇我们已经从原理上认识了LSTM(长短期记忆网络)

  • 它是门控循环神经网络

  • 它引入了独立的记忆单元C_t

  • 它通过遗忘门、输入门、输出门管理信息流

  • 它比基础 RNN 更适合处理长期依赖

这一篇就继续按照李沐的节奏,把这些公式真正落到代码上。

这一节最核心的任务就是看清楚:

  • LSTM 从零实现时到底要多写哪些参数

  • H_tC_t在代码里怎么同时维护

  • 三个门和候选记忆如何逐步计算

  • PyTorch 里的nn.LSTM又把哪些部分封装掉了

你会发现,LSTM 代码虽然比 RNN、GRU 更长一些,但逻辑其实非常清晰:

先算门,再更新记忆单元,最后从记忆单元里读出隐藏状态。


2. LSTM 从零实现需要解决什么

如果把这一节拆开看,核心其实还是 4 件事,只是比 GRU 多维护一个状态。

2.1 初始化更多参数

LSTM 需要为:

  • 输入门

  • 遗忘门

  • 输出门

  • 候选记忆

分别准备参数。

2.2 同时维护两个状态

不是只有隐藏状态H_t,还要维护记忆单元C_t

2.3 按照门控公式更新状态

先更新记忆单元,再更新隐藏状态。

2.4 保持语言模型接口一致

虽然内部更复杂,但外部仍然要能接:

  • 输入序列

  • 初始状态

  • 输出结果

  • 最终状态

所以从工程角度看,LSTM 还是一个“循环单元替换”的问题。


3. 先回顾 LSTM 的核心公式

写代码前,先把最重要的几条公式再摆一下。

输入门

I_t = σ(X_t W_xi + H_{t-1} W_hi + b_i)

遗忘门

F_t = σ(X_t W_xf + H_{t-1} W_hf + b_f)

输出门

O_t = σ(X_t W_xo + H_{t-1} W_ho + b_o)

候选记忆

C_t_tilde = tanh(X_t W_xc + H_{t-1} W_hc + b_c)

当前记忆单元

C_t = F_t ⊙ C_{t-1} + I_t ⊙ C_t_tilde

当前隐藏状态

H_t = O_t ⊙ tanh(C_t)

这六步,就是整段 LSTM 前向传播的主线。


4. 从零实现:先初始化参数

LSTM 比 GRU 还要多一组门,所以参数自然更多。

常见写法如下:

def get_lstm_params(vocab_size, num_hiddens, device): num_inputs = num_outputs = vocab_size def normal(shape): return torch.randn(size=shape, device=device) * 0.01 def three(): return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), torch.zeros(num_hiddens, device=device)) W_xi, W_hi, b_i = three() # 输入门 W_xf, W_hf, b_f = three() # 遗忘门 W_xo, W_ho, b_o = three() # 输出门 W_xc, W_hc, b_c = three() # 候选记忆 W_hq = normal((num_hiddens, num_outputs)) b_q = torch.zeros(num_outputs, device=device) params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] for param in params: param.requires_grad_(True) return params

这段代码一眼看上去比 RNN、GRU 都长,但结构很规律。


5. 为什么这里是四组核心参数

因为 LSTM 每个时间步要计算四个关键中间量:

  • 输入门

  • 遗忘门

  • 输出门

  • 候选记忆

每一个都需要:

  • 输入到该单元的权重

  • 隐藏状态到该单元的权重

  • 对应偏置

所以会有四组:

(W_x?, W_h?, b_?)

这和 GRU 的三组参数是一个套路,只是 LSTM 多了一组。

所以你可以这样记:

  • RNN:一组

  • GRU:三组

  • LSTM:四组

门控越细,参数组越多。


6. LSTM 的状态初始化为什么和前面不同

LSTM 最大的代码差异之一,就是状态不止一个了。

基础 RNN 和 GRU 只返回:

  • 隐藏状态H

但 LSTM 需要同时维护:

  • 隐藏状态H

  • 记忆单元C

所以初始化通常写成:

def init_lstm_state(batch_size, num_hiddens, device): return (torch.zeros((batch_size, num_hiddens), device=device), torch.zeros((batch_size, num_hiddens), device=device))

也就是说,初始状态是一个二元组:

(H_0, C_0)

而且两者初始通常都设为 0。


7. 为什么要同时初始化HC

因为这两个状态职责不同。

H

当前时刻的隐藏表示,更多用于输出和短期信息表达。

C

长期记忆通道,更多用于跨时间步保存稳定信息。

所以在 LSTM 里,“状态”不是一个向量,而是一对向量。

这点在代码里非常关键,后面前向传播时一定要记得同时更新它们。


8. LSTM 前向传播是这一节的核心

常见从零实现写法如下:

def lstm(inputs, state, params): [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params H, C = state outputs = [] for X in inputs: I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i) F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f) O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o) C_tilde = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c) C = F * C + I * C_tilde H = O * torch.tanh(C) Y = torch.mm(H, W_hq) + b_q outputs.append(Y) return torch.cat(outputs, dim=0), (H, C)

这段代码就是 LSTM 从零实现最关键的主体。


9. 输入门代码怎么理解

先看:

I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)

它对应输入门公式:

I_t = σ(X_t W_xi + H_{t-1} W_hi + b_i)

它控制的是:

当前时刻新候选记忆,允许写入多少。

如果I大,说明当前输入带来的新信息很重要。
如果I小,说明当前新内容不太值得写进长期记忆。


10. 遗忘门代码怎么理解

再看:

F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)

它对应:

F_t = σ(X_t W_xf + H_{t-1} W_hf + b_f)

它控制的是:

旧记忆C_{t-1}中还要留下多少。

所以它本质上是在做“保留还是遗忘”的判断。

如果F接近 1,就说明旧记忆大部分保留。
如果接近 0,就说明旧记忆大部分被清掉。


11. 输出门代码怎么理解

这一句:

O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)

对应:

O_t = σ(X_t W_xo + H_{t-1} W_ho + b_o)

它的作用是:

决定当前记忆单元里有多少信息,要通过隐藏状态对外输出。

你可以理解为:

  • 记忆单元C是“仓库”

  • 输出门O决定当前要“拿出多少货”


12. 候选记忆代码怎么理解

这一句:

C_tilde = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)

对应:

C_t_tilde = tanh(X_t W_xc + H_{t-1} W_hc + b_c)

它是当前时刻生成的一份“新记忆草稿”。

注意,它不会直接成为C_t
而是还要由输入门决定写进去多少。

所以:

  • C_tilde是候选内容

  • I是写入强度

这和 GRU 里的候选隐藏状态有一点相似,但 LSTM 的结构更细。


13. 记忆单元更新是最关键的一步

这一句一定要看懂:

C = F * C + I * C_tilde

它对应:

C_t = F_t ⊙ C_{t-1} + I_t ⊙ C_t_tilde

意思非常明确:

当前记忆 = 保留下来的旧记忆 + 写入的新候选记忆

这条式子其实就是 LSTM 成功的核心。

因为它给长期记忆开辟了一条非常直接的传播路径。
旧记忆可以在门控控制下平滑流动,而不是每一步都被强制重算。


14. 隐藏状态更新为什么还要依赖C

再看:

H = O * torch.tanh(C)

这对应:

H_t = O_t ⊙ tanh(C_t)

它说明:

  • 隐藏状态H并不是独立算的

  • 它是从当前记忆单元C中读出来的

所以 LSTM 里真正更核心的状态其实是C
H更像是当前时刻对外暴露的“工作结果”。

你可以理解为:

  • C更偏长期存储

  • H更偏当前输出接口


15. 输出层为什么和前面一样

这一句:

Y = torch.mm(H, W_hq) + b_q

和 RNN、GRU 完全一样。

原因也一样:

LSTM 改进的是内部记忆更新方式,
但语言模型最终仍然需要:

根据当前隐藏状态,对整个词表做分类打分

所以输出头部分并没有本质变化。

也就是说,循环单元可以换,但语言模型输出逻辑是统一的。


16. 从零实现时可以继续复用手写模型容器

和前面一样,通常可以直接用类似的封装:

net = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_lstm_params, init_lstm_state, lstm)

这点特别漂亮。

你会发现:

  • 容器类不需要改

  • 输入 one-hot 不需要改

  • 输出预测接口不需要改

只要换掉:

  • 参数初始化函数

  • 状态初始化函数

  • 单步递推函数

模型就从 RNN / GRU 切换成了 LSTM。

这说明什么?

说明循环模型之间的差别,核心就在于内部单元设计


17. 简洁实现:PyTorch 里的nn.LSTM

从零实现看懂之后,简洁实现就非常自然了。

PyTorch 直接封装了:

nn.LSTM

最常见写法类似:

lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)

它的对外用法和nn.RNNnn.GRU非常相似:

Y, (H, C) = lstm_layer(X, (H0, C0))

注意这里最大的区别:

返回的不再只是一个state,而是(H, C)两部分。

这正对应了我们前面从零实现时手动维护的两个状态。


18. 为什么nn.LSTM的状态结构最特别

因为 LSTM 不像 RNN / GRU 只有一个隐藏状态。
它有两个状态:

  • H:隐藏状态

  • C:记忆单元

所以在简洁实现里,初始化状态时也要写成:

(H0, C0)

而不能只传一个张量。

这是使用nn.LSTM时最容易和nn.GRUnn.RNN混的地方之一。


19. 简洁实现里的模型封装怎么处理

和简洁版 RNN、GRU 一样,LSTM 外面通常也会再包一层语言模型头。

也就是说:

  • nn.LSTM负责生成每个时间步的隐藏表示

  • nn.Linear负责把隐藏表示映射到词表大小

所以整体模型结构仍然是:

  1. 输入 token

  2. one-hot

  3. LSTM

  4. reshape

  5. linear

  6. 输出词表打分

外层套路并没有变,变的是内部状态变复杂了。


20. LSTM 代码和 GRU 代码最本质的区别在哪里

这是这篇最该说透的一点。

GRU

  • 一个状态H

  • 两个门

  • 一个候选状态

  • 最终状态是旧H和新候选状态的融合

LSTM

  • 两个状态HC

  • 三个门

  • 一个候选记忆

  • 先更新C,再从C读出H

所以从代码结构上,LSTM 比 GRU 最大的不同就是:

它把“记忆存储”和“输出表达”拆开了。

这也是为什么它更系统、更经典,但代码也更长一点。


21. LSTM 训练流程和前面有什么变化吗

整体训练逻辑几乎不变。

还是:

  • 输入 token 序列

  • 输出下一个 token 的预测

  • 用交叉熵损失

  • 反向传播

  • 梯度裁剪

  • 参数更新

变化的不是训练外壳,而是:

状态更新公式更复杂,且要同时维护(H, C)

所以,如果你前面已经把 RNN、GRU 训练管道搞清楚了,
LSTM 在训练框架上其实不会让你特别陌生。


22. 这一节最该掌握什么

如果从学习重点来看,最关键的是这几件事。

22.1 明白为什么要有四组参数

因为输入门、遗忘门、输出门、候选记忆都要单独计算。

22.2 看懂CH的职责分工

  • C:长期记忆

  • H:当前输出表征

22.3 看懂C = F * C + I * C_tilde

这是 LSTM 最核心的一行代码。

22.4 知道nn.LSTM的状态是(H, C)

这是和nn.RNNnn.GRU最大的接口差别。

22.5 把它和 GRU 对照起来理解

这样你不会只记住“LSTM 更复杂”,而是真知道复杂在哪。


23. 本节总结

这一节我们学习了 LSTM 的代码实现,核心内容可以总结为以下几点。

23.1 LSTM 从零实现需要四组核心参数

分别对应:

  • 输入门

  • 遗忘门

  • 输出门

  • 候选记忆

23.2 LSTM 同时维护隐藏状态和记忆单元

这让长期记忆流动更稳定。

23.3 记忆单元更新是 LSTM 的关键

旧记忆和新候选记忆在门控作用下融合。

23.4 隐藏状态是从记忆单元中读出来的

因此HC职责不同。

23.5nn.LSTM是对这一整套机制的框架封装

使用时最需要注意的是状态结构为(H, C)


24. 学习感悟

LSTM 代码这一节非常有代表性,因为它让你真正看到:

一个模型变强,不一定是因为网络更深了,而可能是因为“状态管理更精细了”。

RNN 像是在“顺着时间流动”;
GRU 像是在“加门控做筛选”;
而 LSTM 则进一步把记忆流程分成:

  • 擦除

  • 写入

  • 读取

这已经非常像一个真正的记忆系统了。

所以学完 LSTM 代码之后,你对“循环神经网络为什么能记忆”这件事,理解会明显更深一层。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/15 21:17:33

Windows 10/11下达梦数据库8.0安装避坑指南(附常见错误解决方案)

Windows 10/11下达梦数据库8.0安装避坑指南(附常见错误解决方案) 在国产数据库生态快速发展的今天,达梦数据库作为核心产品之一,正被越来越多的企业采用。但对于初次接触达梦的技术人员来说,Windows环境下的安装过程往…

作者头像 李华
网站建设 2026/4/17 2:55:58

网安靶场平台大盘点(2026版)

网安靶场平台大盘点(2026版) 摘要:网络安全的核心是实战,而靶场平台正是网安从业者的“练兵场”——无论是零基础新手入门、转行从业者积累实战经验,还是资深工程师提升攻防能力,靶场都是不可或缺的核心工…

作者头像 李华
网站建设 2026/4/16 3:01:25

Anti-UAV 反无人机数据集实战:YOLO格式转换脚本解析与优化

1. 反无人机数据集与YOLO格式转换的必要性 在无人机检测领域,高质量的数据集是模型训练的基础。Anti-UAV这类专业数据集包含了丰富的无人机飞行场景,但原始数据往往采用JSON等通用格式存储标注信息,而YOLO系列算法需要特定的文本标注格式。这…

作者头像 李华
网站建设 2026/4/16 17:28:05

从汽车到工厂:深入浅出解析PTP在TSN和AUTOSAR中的实现差异

从汽车到工厂:深入浅出解析PTP在TSN和AUTOSAR中的实现差异 在工业自动化和汽车电子领域,时间同步技术正成为支撑下一代智能系统的关键基础设施。想象一下,当一辆自动驾驶汽车以120公里时速行驶时,其传感器、控制器和执行器之间的时…

作者头像 李华
网站建设 2026/4/14 18:01:21

Web漏洞全景解析:从原理溯源到实战攻防的进阶指南

web类型漏洞Web 类型漏洞是指在 Web 应用程序的设计、开发或配置过程中产生的安全缺陷,攻击者可以利用这些缺陷执行未授权的操作,例如窃取数据、控制服务器或破坏服务。以下是一些最常见和最具危害性的 Web 漏洞类型:💉 注入 (Inj…

作者头像 李华
网站建设 2026/4/16 8:07:16

ROS2 实时性能调优实战:从内核到应用的确定性延迟达成

1. 从内核层开始:打造实时系统的基石 第一次在机器人手臂上部署ROS2时,我遇到了一个诡异现象:明明代码逻辑没问题,机械臂却总在特定角度出现"卡顿"。用示波器抓取信号后发现,某些控制指令的延迟会突然从200μ…

作者头像 李华