news 2026/6/19 10:13:59

别再死记ResNet18结构图了!用PyTorch代码逐层拆解,手把手带你理解残差连接

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记ResNet18结构图了!用PyTorch代码逐层拆解,手把手带你理解残差连接

用PyTorch代码逐层拆解ResNet18:从张量流动理解残差连接

在深度学习领域,ResNet18作为经典卷积神经网络架构,其创新性的残差连接设计彻底改变了深层网络训练的范式。但传统学习方式往往停留在结构图记忆层面,难以真正理解数据在网络中的流动与变换机制。本文将带您通过PyTorch代码实现,逐层解剖ResNet18的运作原理,让残差连接的设计思想变得触手可及。

1. 残差网络的核心设计思想

残差网络(Residual Network)的诞生源于一个看似简单却影响深远的问题:为什么增加网络深度反而会导致性能下降?传统观点认为更深的网络理应具有更强的特征提取能力,但实验数据却显示56层的网络比20层的网络在ImageNet数据集上表现更差。这种"退化现象"直接挑战了深度学习的理论基础。

何恺明团队通过残差学习框架解决了这一难题。其核心创新在于将传统的直接拟合目标函数H(x)转变为拟合残差函数F(x) = H(x)-x。这种转变的物理意义在于:如果恒等映射已经是最优解,那么将残差推向零比堆叠非线性层来拟合恒等映射要容易得多。

用PyTorch代码可以直观展示这一思想:

import torch import torch.nn as nn # 传统网络块 class PlainBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(out_channels) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) return self.relu(out) # 残差网络块 class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(out_channels) # 捷径连接 self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride), nn.BatchNorm2d(out_channels) ) def forward(self, x): residual = self.shortcut(x) out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out += residual # 关键残差连接 return self.relu(out)

关键差异在于残差块中的out += residual操作,这使得网络可以学习输入与输出之间的差异而非直接映射。实验表明,这种设计让训练极深层网络成为可能:

网络类型层数Top-1错误率训练稳定性
普通网络1828%中等
残差网络1825%
普通网络3432%
残差网络3424%

2. ResNet18的完整架构实现

ResNet18由四个主要阶段(Stage)组成,每个阶段包含若干残差块,整体架构如下:

class ResNet18(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.in_channels = 64 # 初始卷积层 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 四个阶段 self.layer1 = self._make_layer(64, 64, 2, stride=1) # Stage1 self.layer2 = self._make_layer(64, 128, 2, stride=2) # Stage2 self.layer3 = self._make_layer(128, 256, 2, stride=2) # Stage3 self.layer4 = self._make_layer(256, 512, 2, stride=2) # Stage4 # 分类头 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512, num_classes) def _make_layer(self, in_channels, out_channels, blocks, stride): layers = [] # 第一个块可能需要下采样 layers.append(ResidualBlock(in_channels, out_channels, stride)) # 后续块保持通道数和尺寸 for _ in range(1, blocks): layers.append(ResidualBlock(out_channels, out_channels, stride=1)) return nn.Sequential(*layers) def forward(self, x): # 初始处理 x = self.conv1(x) # [b,3,224,224] -> [b,64,112,112] x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) # [b,64,112,112] -> [b,64,56,56] # 四个阶段 x = self.layer1(x) # [b,64,56,56] -> [b,64,56,56] x = self.layer2(x) # [b,64,56,56] -> [b,128,28,28] x = self.layer3(x) # [b,128,28,28] -> [b,256,14,14] x = self.layer4(x) # [b,256,14,14] -> [b,512,7,7] # 分类 x = self.avgpool(x) # [b,512,7,7] -> [b,512,1,1] x = torch.flatten(x, 1) # [b,512,1,1] -> [b,512] x = self.fc(x) # [b,512] -> [b,num_classes] return x

通过代码中的张量形状注释,我们可以清晰看到数据在网络中的流动过程。特别值得注意的是四个阶段之间的过渡:

  1. Stage1:保持56x56的空间分辨率,仅进行特征提取
  2. Stage2:通过stride=2的卷积将特征图尺寸减半,同时通道数翻倍
  3. Stage3Stage4:重复相同模式,逐步提取更高层次的特征

3. 残差块的两种类型与1x1卷积的作用

ResNet18中实际存在两种残差块设计,区别主要在于是否改变特征图的尺寸和通道数:

类型一:恒等残差块(Identity Block)

  • 输入输出尺寸相同
  • 通道数保持不变
  • 直接使用原始输入作为捷径连接
# 示例:Stage1中的第二个残差块 x = torch.randn(1, 64, 56, 56) # 模拟输入 block = ResidualBlock(64, 64) # 创建残差块 out = block(x) print(out.shape) # torch.Size([1, 64, 56, 56])

类型二:卷积残差块(Convolutional Block)

  • 当需要改变特征图尺寸或通道数时使用
  • 在捷径连接中添加1x1卷积和批归一化
  • 主路径中使用stride=2的卷积进行下采样
# 示例:Stage2的第一个残差块 x = torch.randn(1, 64, 56, 56) # 模拟输入 block = ResidualBlock(64, 128, stride=2) # 创建残差块 out = block(x) print(out.shape) # torch.Size([1, 128, 28, 28])

1x1卷积在残差网络中扮演着关键角色:

  • 通道数调整:当主路径改变通道数时,捷径连接需要通过1x1卷积匹配维度
  • 参数效率:相比3x3卷积,1x1卷积能以更少的参数实现通道变换
  • 非线性增强:配合ReLU激活,1x1卷积可以引入额外的非线性能力

下表对比了两种残差块的结构差异:

特性恒等残差块卷积残差块
输入输出尺寸保持不变通常减半
输入输出通道数保持不变通常翻倍
捷径连接直接连接1x1卷积
使用场景同一阶段内部阶段过渡时
典型位置Stage1的所有块Stage2/3/4的第一个块

4. 实战:可视化ResNet18中的张量流动

理解网络结构的最佳方式是通过实际运行观察中间结果的变换。我们可以使用PyTorch的hook机制捕获各层的输出:

def register_hooks(model): features = {} def get_hook(name): def hook(module, input, output): features[name] = output.shape return hook # 注册hook到各关键层 model.conv1.register_forward_hook(get_hook('conv1')) model.maxpool.register_forward_hook(get_hook('maxpool')) for i, layer in enumerate([model.layer1, model.layer2, model.layer3, model.layer4]): layer.register_forward_hook(get_hook(f'layer{i+1}')) return features # 实例化模型并运行 model = ResNet18() features = register_hooks(model) x = torch.randn(1, 3, 224, 224) # 模拟输入图像 out = model(x) # 打印特征图尺寸变化 for name, shape in features.items(): print(f"{name}: {shape}")

典型输出结果展示了数据在ResNet18中的完整流动路径:

conv1: torch.Size([1, 64, 112, 112]) maxpool: torch.Size([1, 64, 56, 56]) layer1: torch.Size([1, 64, 56, 56]) layer2: torch.Size([1, 128, 28, 28]) layer3: torch.Size([1, 256, 14, 14]) layer4: torch.Size([1, 512, 7, 7])

要进一步理解残差连接的实际作用,我们可以比较有/无残差连接时的梯度流动:

# 梯度可视化工具函数 def plot_gradients(model, use_residual=True): model.train() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) criterion = nn.CrossEntropyLoss() # 模拟训练步骤 x = torch.randn(1, 3, 224, 224) target = torch.randint(0, 1000, (1,)) optimizer.zero_grad() out = model(x) loss = criterion(out, target) loss.backward() # 收集各层梯度均值 grads = [] for name, param in model.named_parameters(): if 'weight' in name and 'conv' in name: grads.append(param.grad.abs().mean().item()) # 绘制梯度分布 plt.figure(figsize=(10, 5)) plt.plot(grads, marker='o') plt.title(f'梯度流动 (残差连接: {use_residual})') plt.xlabel('网络深度') plt.ylabel('平均梯度值') plt.grid() plt.show() # 比较两种架构 resnet = ResNet18() plainnet = PlainNet18() # 假设实现的普通18层网络 plot_gradients(resnet, use_residual=True) plot_gradients(plainnet, use_residual=False)

残差网络通常显示出更均匀的梯度分布,验证了其缓解梯度消失问题的能力。而普通深层网络的后几层梯度往往显著变小,导致训练困难。

5. ResNet18的实际应用与变体调整

理解了基础架构后,我们可以针对不同任务调整ResNet18:

图像分类任务调整

# 更换分类头适应CIFAR-10(10类) model = ResNet18(num_classes=10) # 修改输入层适应32x32小图像 model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) model.maxpool = nn.Identity() # 移除初始下采样

特征提取任务调整

# 移除分类头,输出最后卷积层特征 class FeatureExtractor(nn.Module): def __init__(self, base_model): super().__init__() self.features = nn.Sequential( base_model.conv1, base_model.bn1, base_model.relu, base_model.maxpool, base_model.layer1, base_model.layer2, base_model.layer3, base_model.layer4 ) def forward(self, x): return self.features(x) extractor = FeatureExtractor(ResNet18()) features = extractor(torch.randn(1, 3, 224, 224)) print(features.shape) # torch.Size([1, 512, 7, 7])

轻量化调整

# 减少通道数创建轻量版 class LiteResNet18(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.in_channels = 32 # 原为64 self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) self.bn1 = nn.BatchNorm2d(32) self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(32, 32, 2) self.layer2 = self._make_layer(32, 64, 2, stride=2) self.layer3 = self._make_layer(64, 128, 2, stride=2) self.layer4 = self._make_layer(128, 256, 2, stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(256, num_classes) # _make_layer实现与之前类似

ResNet系列的成功催生了许多变体,它们在基础架构上进行了不同改进:

变体主要改进适用场景
ResNeXt分组卷积增加基数(Cardinality)需要更高准确率的任务
Wide ResNet增加每层通道数,减少深度计算资源受限的环境
Res2Net在单个残差块内构建分层次特征细粒度识别任务
ResNet-D改进下采样路径设计需要更稳定训练的场景

在实际项目中,选择合适变体需要考虑:

  • 计算资源:移动端应用可能更适合轻量变体
  • 数据特性:细粒度识别任务可能受益于多尺度架构
  • 部署环境:某些硬件对特定操作(如分组卷积)有优化
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/19 10:08:23

COM3D2.MaidFiddler完全手册:实时女仆编辑器的实战指南

COM3D2.MaidFiddler完全手册:实时女仆编辑器的实战指南 【免费下载链接】COM3D2.MaidFiddler Maid Fiddler for COM3D2 -- a real-time value editor for COM3D2 项目地址: https://gitcode.com/gh_mirrors/co/COM3D2.MaidFiddler COM3D2.MaidFiddler是一款专…

作者头像 李华
网站建设 2026/6/6 3:07:54

告别理论!ADC0809八通道采集的三种数据读取方式详解(查询/中断/定时)

ADC0809八通道采集的三种数据读取方式实战解析 在嵌入式系统开发中,模拟信号采集是连接物理世界与数字系统的关键环节。ADC0809作为经典的8位8通道模数转换芯片,至今仍在许多工业控制、仪器仪表和教学实验中广泛应用。但很多开发者在实际项目中常遇到一个…

作者头像 李华
网站建设 2026/6/6 3:06:34

用Python搞定激光雷达地图坐标转换:从局部XY到WGS84经纬度的保姆级教程

激光雷达地图坐标转换实战:从局部XY到WGS84的高精度工程指南当无人机掠过城市上空或机器人穿梭于复杂环境时,激光雷达扫描生成的二维地图就像一张数字化的藏宝图。但如何将图纸上的XY坐标点转化为真实世界的经纬度?这不仅是测绘工程师的日常挑…

作者头像 李华
网站建设 2026/6/7 19:37:55

从一次Ping不通的故障说起:深入理解MTU、MSS和VLAN Tag对云网络的影响

云网络故障排查实战:当MTU与VLAN Tag成为隐形杀手深夜的告警铃声划破了运维中心的宁静——某金融云平台的跨机房虚拟机突发通信异常。同子网的两台关键业务虚拟机之间,ICMP探测全部超时,但诡异的是,基于TCP 8080端口的业务请求却时…

作者头像 李华