news 2026/5/6 8:59:32

别只刷题了!用Python和PyTorch复现那些‘经典’的深度学习期末考题(附代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别只刷题了!用Python和PyTorch复现那些‘经典’的深度学习期末考题(附代码)

别只刷题了!用Python和PyTorch复现那些‘经典’的深度学习期末考题(附代码)

深度学习理论考试总让人头疼——公式推导、参数计算、概念辨析,稍不留神就会陷入"纸上谈兵"的困境。但换个角度想,这些试题本质上是检验我们对核心算法的理解程度。与其死记硬背,不如打开Jupyter Notebook,用代码将这些抽象问题具象化。本文将带你用PyTorch重新演绎五类经典考题,从反向传播实现到LSTM结构拆解,让理论在代码中"活"起来。

1. 反向传播的代码解剖

考试要求推导三层网络梯度?PyTorch的自动微分机制能让我们直观验证计算结果。先构建一个无激活函数的简易网络:

import torch # 试题参数初始化 x = torch.tensor([1.0], requires_grad=True) w1 = torch.tensor([0.5], requires_grad=True) w2 = torch.tensor([0.3], requires_grad=True) w3 = torch.tensor([0.2], requires_grad=True) # 前向传播 y = x * w1 * w2 * w3 print(f"前向输出值: {y.item()}") # 输出: 0.03

现在模拟考题要求计算梯度:

# 设置损失梯度为1.0 y.backward(torch.tensor([1.0])) print(f"dL/dw1: {w1.grad.item()}") # 0.06 (w2*w3*x) print(f"dL/dw2: {w2.grad.item()}") # 0.1 (w1*w3*x) print(f"dL/dw3: {w3.grad.item()}") # 0.15 (w1*w2*x)

对比手工计算:

  • $\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w_1} = 1.0 \times (w_2 w_3 x) = 0.3 \times 0.2 \times 1 = 0.06$

关键发现:通过.grad属性可以直接验证梯度计算是否正确。实践中建议用torch.autograd.grad()函数更灵活地获取特定梯度:

dy_dw1 = torch.autograd.grad(outputs=y, inputs=w1, retain_graph=True)[0]

2. 卷积网络参数实战

考试常见的卷积参数计算题,用PyTorch的nn.Conv2d可以直观验证。题目给出:

  • 输入尺寸:32×32×3
  • 卷积核:10个5×5,stride=1, padding=0
import torch.nn as nn conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5, stride=1, padding=0) input = torch.randn(1, 3, 32, 32) # batch=1 output = conv(input) print(f"输出特征图尺寸: {output.shape[2:]}") # torch.Size([1, 10, 28, 28])

参数数量计算:

  • 每个5×5卷积核有$5 \times 5 \times 3 = 75$个权重
  • 10个卷积核共$75 \times 10 = 750$权重参数
  • 每个卷积核1个偏置,共$10$个偏置参数
  • 总计$750 + 10 = 760$个可训练参数

提示:使用sum(p.numel() for p in conv.parameters())可自动统计参数总量

添加池化层验证:

pool = nn.MaxPool2d(kernel_size=2, stride=2) pool_output = pool(output) print(f"池化后尺寸: {pool_output.shape[2:]}") # torch.Size([1, 10, 14, 14])

3. Word2Vec模型对比实现

CBOW和Skip-gram的结构差异常出现在简答题中。下面用PyTorch实现两者的核心区别:

class CBOW(nn.Module): def __init__(self, vocab_size, embedding_dim): super().__init__() self.embeddings = nn.Embedding(vocab_size, embedding_dim) self.linear = nn.Linear(embedding_dim, vocab_size) def forward(self, context): # context: [batch, 2*window_size] embeds = self.embeddings(context).mean(dim=1) # 上下文词向量平均 return self.linear(embeds) class SkipGram(nn.Module): def __init__(self, vocab_size, embedding_dim): super().__init__() self.embeddings = nn.Embedding(vocab_size, embedding_dim) self.linear = nn.Linear(embedding_dim, vocab_size) def forward(self, target): # target: [batch, 1] embeds = self.embeddings(target).squeeze(1) return self.linear(embeds)

架构对比表

特性CBOWSkip-gram
输入上下文词索引目标词索引
输出目标词概率分布上下文词概率分布
计算效率适合高频词适合低频词
数学本质上下文词向量的均值预测目标词目标词向量预测上下文词分布

负采样实现示例:

# 负采样损失函数 neg_loss = -torch.log(torch.sigmoid(pos_score)) - \ torch.sum(torch.log(torch.sigmoid(-neg_scores)))

4. LSTM长期依赖解决方案

用代码揭示LSTM如何解决RNN的梯度消失问题。关键在门控机制:

class CustomLSTMCell(nn.Module): def __init__(self, input_size, hidden_size): super().__init__() # 输入门、遗忘门、输出门、候选记忆 self.gates = nn.Linear(input_size + hidden_size, 4*hidden_size) def forward(self, x, hc): h, c = hc gates = self.gates(torch.cat([x, h], dim=1)) i, f, o, g = gates.chunk(4, 1) # 拆分为四个部分 i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o) g = torch.tanh(g) new_c = f * c + i * g # 记忆更新公式 new_h = o * torch.tanh(new_c) return new_h, new_c

门控机制解析

  1. 遗忘门($f_t$)控制历史记忆保留量
  2. 输入门($i_t$)调节新记忆的写入比例
  3. 候选记忆($\tilde{C}_t$)存储当前时间步的新信息
  4. 输出门($o_t$)决定隐藏状态的输出比例

可视化门控信号变化:

plt.plot(forget_gate_history, label='Forget Gate') plt.plot(input_gate_history, label='Input Gate') plt.legend()

5. Attention机制代码演绎

Seq2Seq的注意力改进方案,通过代码展示其工作原理:

class Attention(nn.Module): def __init__(self, enc_dim, dec_dim): super().__init__() self.attn = nn.Linear(enc_dim + dec_dim, dec_dim) self.v = nn.Parameter(torch.rand(dec_dim)) def forward(self, hidden, encoder_outputs): # hidden: [1, batch, dec_dim] # encoder_outputs: [seq_len, batch, enc_dim] seq_len = encoder_outputs.shape[0] hidden = hidden.repeat(seq_len, 1, 1) # 沿序列维度复制 energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2))) energy = energy.permute(1, 2, 0) # [batch, dec_dim, seq_len] v = self.v.repeat(encoder_outputs.size(1), 1).unsqueeze(1) attention = torch.bmm(v, energy).squeeze(1) # [batch, seq_len] return torch.softmax(attention, dim=1)

注意力计算流程

  1. 将解码器隐藏状态与所有编码器输出拼接
  2. 通过全连接层和tanh激活计算能量值
  3. 使用可学习参数$v$计算注意力分数
  4. 应用softmax归一化得到注意力权重

对比传统Seq2Seq与Attention的效果差异:

# 传统Seq2Seq output, hidden = decoder(input, hidden) # 加入Attention后 attn_weights = attention(hidden, encoder_outputs) context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs.transpose(0, 1)) output, hidden = decoder(input, torch.cat((context, hidden), dim=2))
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/12 3:37:00

照片变3D模型就这么简单!Face3D.ai Pro保姆级教程,从安装到导出

照片变3D模型就这么简单!Face3D.ai Pro保姆级教程,从安装到导出 1. 环境准备与快速部署 1.1 系统要求检查 在开始之前,请确认你的设备满足以下最低配置要求: 操作系统:Linux(推荐Ubuntu 18.04及以上&am…

作者头像 李华
网站建设 2026/4/11 17:20:07

超厉害的AI教材写作工具,低查重快速产出高质量教材!

在整理教材的过程中,我们常常遇到棘手的难题,这项工作简直像是一种“精细活”。其中最大的挑战就是如何找到平衡与衔接的点!一方面,我们总是担心会遗漏重要的核心知识点;另一方面,如何控制好难度的递进关系…

作者头像 李华
网站建设 2026/4/11 22:24:07

相亲软件靠谱,还是知名品牌靠谱?我给你讲明白

各位单身的朋友,今天咱们不绕弯子,直接聊聊当下市面上那些五花八门的相亲平台——有工具型的,有连锁型的,也有主打创新模式的。我会把它们的优缺点掰开揉碎了说,最后给大家一个最实在的推荐。先说说两款工具型平台&…

作者头像 李华
网站建设 2026/4/11 18:31:35

react native如何发送蓝牙命令

使用react-native-ble-plx插件: import { createContext, useState, useEffect, useContext, useRef } from react; import { BleManager } from react-native-ble-plx; import * as Location from expo-location; import { Platform, PermissionsAndroid, ToastAn…

作者头像 李华
网站建设 2026/4/11 14:12:31

Maomi.In | .NET 全能多语言解决方案兆

AI Agent 时代的沙箱需求 从 Copilot 到 Agent:执行能力的质变 在生成式 AI 的早期阶段,应用主要以“Copilot”形式存在,AI 仅作为辅助生成建议。然而,随着 AutoGPT、BabyAGI 以及 OpenAI Code Interpreter(现为 Advan…

作者头像 李华
网站建设 2026/4/12 4:49:19

3步掌握BiliTools:从B站资源收藏到高效管理的完整指南

3步掌握BiliTools:从B站资源收藏到高效管理的完整指南 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools 想…

作者头像 李华