news 2026/4/15 23:10:26

第P2周:CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P2周:CIFAR10彩色图片识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 数据可视化

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、个人总结

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import matplotlib.pyplot as plt import torchvision device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

train_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), download=True) test_ds = torchvision.datasets.CIFAR10('data', train=False, transform=torchvision.transforms.ToTensor(), download=True) batch_size = 32 train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size) imgs, labels = next(iter(train_dl)) imgs.shape

3. 数据可视化

import numpy as np plt.figure(figsize=(20, 5)) for i, imgs in enumerate(imgs[:20]): npimg = imgs.numpy().transpose((1, 2, 0)) plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')

二、构建简单的CNN网络

import torch.nn.functional as F num_classes = 10 class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, kernel_size=3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3) self.pool3 = nn.MaxPool2d(kernel_size=2) self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, start_dim=1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary # 将模型转移到GPU中(我们模型运行均在GPU中进行) model = Model().to(device) summary(model)

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-2 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 10 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、个人总结

逐渐熟悉CNN模型构建过程,并逐步理解其原理。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/16 9:21:32

Spring Retry 全维度详解(结合 OpenFeign 实战)

目录 一、核心设计理念 二、快速入门:核心依赖与启用 1. 引入依赖 2. 启用 Spring Retry 三、核心使用方式 方式 1:声明式重试(注解方式,推荐) 1. 基础示例(结合 OpenFeign) 2. 注解参数…

作者头像 李华
网站建设 2026/4/15 22:08:16

无需训练数据!EmotiVoice实现秒级声音克隆的秘密

无需训练数据!EmotiVoice实现秒级声音克隆的秘密 在智能语音助手越来越“懂人心”的今天,我们是否曾期待过——它开口说话时,用的是亲人的嗓音?或是喜欢的主播语气?甚至,在讲笑话时真的能“笑出声”&#x…

作者头像 李华
网站建设 2026/4/16 10:59:19

EmotiVoice语音合成情感渐变功能:从平静到激动平滑过渡

EmotiVoice语音合成情感渐变功能:从平静到激动平滑过渡 在虚拟主播声情并茂地讲述故事、游戏角色因剧情转折突然爆发怒吼的那一刻,你是否曾好奇——这些声音是如何生成的?它们为何听起来如此真实而富有感染力?随着AI语音技术的发展…

作者头像 李华
网站建设 2026/4/16 9:25:22

告别低效!3种工具大幅提升大文件下载测试效率

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 开发一个对比测试工具,能同时运行curl、wget和aria2三种下载方式,自动记录各自的下载速度、CPU占用和内存消耗。要求可视化展示对比结果,支持导出…

作者头像 李华
网站建设 2026/4/16 9:20:30

揭秘大数据领域规范性分析的关键流程

揭秘大数据领域规范性分析的关键流程:从原理到实践 摘要/引言 在大数据时代,如何从海量的数据中提取有价值的信息并做出明智的决策成为了众多企业和组织关注的焦点。规范性分析作为大数据分析的重要组成部分,旨在为决策者提供具体的行动建议&…

作者头像 李华