简述:torchvision
一、Torchvision 是什么
Torchvision 是 PyTorch 官方配套的计算机视觉专用库,专门处理图像任务。
作用:提供常用数据集、图像预处理、经典模型、可视化工具
核心内容:
内置数据集:MNIST、CIFAR、ImageNet、COCO 等
经典模型:ResNet、VGG、U-Net、FasterRCNN 等(直接用预训练权重)
图像变换:裁剪、归一化、翻转、增强
工具:图像读取、显示、保存
二、3 个简单实用例子
例子 1:图像预处理(最常用)
fromtorchvisionimporttransformsfromPILimportImage# 定义一套图像预处理流程transform=transforms.Compose([transforms.Resize((224,224)),# 改尺寸transforms.ToTensor(),# 转张量transforms.Normalize(mean=[0.5],std=[0.5])# 归一化])# 加载并处理图片img=Image.open("test.jpg")img_tensor=transform(img)print(img_tensor.shape)# 输出: torch.Size([3, 224, 224])例子 2:直接用预训练模型 ResNet
fromtorchvisionimportmodels# 加载训练好的ResNet18model=models.resnet18(pretrained=True)model.eval()# 设为推理模式# 输入一张图importtorch x=torch.randn(1,3,224,224)# 1张图,3通道out=model(x)print(out.shape)# 输出分类结果例子 3:加载官方数据集(CIFAR10)
fromtorchvisionimportdatasets,transformsfromtorch.utils.dataimportDataLoader transform=transforms.ToTensor()# 自动下载加载CIFAR10dataset=datasets.CIFAR10(root="./data",train=True,download=True,transform=transform)# 批量加载loader=DataLoader(dataset,batch_size=32,shuffle=True)# 取一批数据images,labels=next(iter(loader))print(images.shape)# torch.Size([32, 3, 32, 32])三、一句话总结
Torchvision 是 PyTorch 视觉必备工具包,负责图像预处理、模型、数据集,让图像分类 / 检测 / 分割开箱即用,非常适合做变化检测、影像识别等任务。
本blog地址:https://blog.csdn.net/hsg77