news 2026/4/16 12:21:45

第P3周:Pytorch实现天气识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P3周:Pytorch实现天气识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客

  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 显示图片

4. 划分数据集

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、 个人总结

过拟合的确认方法

解决方案

1. 正则化措施

2. 数据增强优化

3. 批归一化(BN)应用

4. 早停策略

5. 动态学习率调整

6. 优化器升级

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import torchvision.transforms as transforms import torchvision from torchvision import transforms, datasets import os,PIL,pathlib,random device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

data_dir = './data/' data_dir = pathlib.Path(data_dir) data_paths = list(data_dir.glob('*')) classeNames = [str(path).split("\\")[1] for path in data_paths] classeNames
  • 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象。
  • 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
  • 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classeNames
  • 第四步:打印classeNames列表,显示每个文件所属的类别名称。

3. 显示图片

import matplotlib.pyplot as plt from PIL import Image # 指定图像文件夹路径 image_folder = './data/cloudy/' # 获取文件夹中的所有图像文件 image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))] # 创建Matplotlib图像 fig, axes = plt.subplots(3, 8, figsize=(16, 6)) # 使用列表推导式加载和显示图像 for ax, img_file in zip(axes.flat, image_files): img_path = os.path.join(image_folder, img_file) img = Image.open(img_path) ax.imshow(img) ax.axis('off') # 显示图像 plt.tight_layout() plt.show()

total_datadir = './data/' # 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863 train_transforms = transforms.Compose([ transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸 transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间 transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。 ]) total_data = datasets.ImageFolder(total_datadir,transform=train_transforms) total_data

4. 划分数据集

train_size = int(0.8 * len(total_data)) test_size = len(total_data) - train_size train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size]) train_dataset, test_dataset train_size,test_size batch_size = 32 train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) for X, y in test_dl: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break

二、构建简单的CNN网络

import torch.nn.functional as F class Network_bn(nn.Module): def __init__(self): super(Network_bn, self).__init__() """ nn.Conv2d()函数: 第一个参数(in_channels)是输入的channel数量 第二个参数(out_channels)是输出的channel数量 第三个参数(kernel_size)是卷积核大小 第四个参数(stride)是步长,默认为1 第五个参数(padding)是填充大小,默认为0 """ self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0) self.bn1 = nn.BatchNorm2d(12) self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0) self.bn2 = nn.BatchNorm2d(12) self.pool1 = nn.MaxPool2d(2,2) self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0) self.bn4 = nn.BatchNorm2d(24) self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0) self.bn5 = nn.BatchNorm2d(24) self.pool2 = nn.MaxPool2d(2,2) self.fc1 = nn.Linear(24*50*50, len(classeNames)) def forward(self, x): x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x))) x = self.pool1(x) x = F.relu(self.bn4(self.conv4(x))) x = F.relu(self.bn5(self.conv5(x))) x = self.pool2(x) x = x.view(-1, 24*50*50) x = self.fc1(x) return x device = "cuda" if torch.cuda.is_available() else "cpu" print("Using {} device".format(device)) model = Network_bn().to(device) model

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-4 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 20 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、 个人总结

过拟合的确认方法

当模型出现疑似过拟合时,可通过以下方法进一步确认:

  1. 增加训练轮次:持续训练时若验证集准确率下降而损失上升,则确认过拟合
  2. 降低模型复杂度:如减少全连接层或通道数后验证指标提升,说明原模型过于复杂
  3. 增强数据多样性:添加数据增强后若验证指标改善,表明原训练数据不足导致模型死记硬背

解决方案

1. 正则化措施

  • Dropout:在全连接层前添加nn.Dropout(0.5),通过随机丢弃神经元强制学习鲁棒特征
  • L2正则化:优化器中设置weight_decay=1e-4,抑制过大权重,防止噪声拟合

2. 数据增强优化

  • 方法:使用torchvision.transforms实现随机裁剪、翻转、颜色变换等
  • 作用:提升数据多样性,促使模型学习通用特征而非样本细节,增强泛化能力

3. 批归一化(BN)应用

  • 原理:对卷积层输出进行归一化处理(均值≈0,方差≈1)后缩放平移
  • 优势
    • 稳定层间输入分布,加速收敛并缓解梯度问题
    • 具有正则化效果,类似"mini-batch级数据增强"

4. 早停策略

  • 实施:持续监控验证集准确率,当性能不再提升(超过patience轮次)时终止训练
  • 价值
    • 防止过度拟合训练噪声
    • 节省计算资源,自动获取最佳泛化模型

5. 动态学习率调整

  • 机制:当验证性能停滞时,按因子(如0.5)降低学习率
  • 效益
    • 实现更精细的参数优化
    • 平缓的参数更新可降低过拟合风险

6. 优化器升级

  • 改进方案:将SGD替换为Adam优化器并配合权重衰减
  • 原因:当前SGD学习率(1e-4)偏低导致收敛缓慢,且缺乏动量和正则化支持
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/13 13:52:07

学术圈公认最好用的十大降ai率产品全测评

家人们,现在学校查得是真严,不仅重复率,还得降ai率,学校规定必须得20%以下... 折腾了半个月,终于把市面上各类方法试了个遍,坑踩了不少,智商税也交了。今天这就把这份十大降AI工具合集掏心窝子…

作者头像 李华
网站建设 2026/4/15 16:53:42

企业数据API对接稳定性挑战与高可用架构实践指南

在数字化转型浪潮席卷全球的今天,企业数据API(Application Programming Interface)已成为连接内部系统、第三方服务与合作伙伴生态的核心纽带。然而,随着API调用量的指数级增长,企业面临着严峻的技术挑战:A…

作者头像 李华
网站建设 2026/4/15 7:17:21

通信协议仿真:5G NR协议仿真_(12).5G NR仿真中的移动性管理

5G NR仿真中的移动性管理 1. 移动性管理概述 移动性管理是5G NR(New Radio)协议中的一个重要组成部分,它确保用户在移动过程中能够保持无缝的连接和服务质量。移动性管理涉及多个方面,包括小区选择与重选、切换、重定向、连接恢复…

作者头像 李华
网站建设 2026/4/15 4:18:20

Chart.js 极地图

Chart.js 极地图 引言 极地图(Polar Chart)是一种展示数据分布和关系的图表类型,它通过极坐标系统来展示数据。在众多图表库中,Chart.js 是一个功能强大且易于使用的 JavaScript 图表库。本文将详细介绍如何使用 Chart.js 创建极地图,并探讨其在数据可视化中的应用。 极…

作者头像 李华
网站建设 2026/4/16 12:09:07

只要十分钟,AI率从89%降到13%!2025年度十大降AI工具推荐

家人们,现在学校查得是真严,不仅重复率,还得降ai率,学校规定必须得20%以下... 折腾了半个月,终于把市面上各类方法试了个遍,坑踩了不少,智商税也交了。今天这就把这份十大降AI工具合集掏心窝子…

作者头像 李华
网站建设 2026/4/16 2:28:48

九尾狐AI:传统企业AI转型实战白皮书

——从「技术恐惧」到「订单暴涨」的落地指南第一章:行业困局与趋势1.1 传统企业的AI转型痛点在数字经济浪潮下,企业AI培训已成为传统行业破局的关键赛道,但80%的中小企业仍面临「转型堰塞湖」:认知断层:佛山某雕塑公司…

作者头像 李华