news 2026/4/22 11:50:45

用CycleGAN实现马变斑马:5分钟搞定无配对数据风格迁移(附PyTorch代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
用CycleGAN实现马变斑马:5分钟搞定无配对数据风格迁移(附PyTorch代码)

用CycleGAN实现马变斑马:5分钟搞定无配对数据风格迁移(附PyTorch代码)

想象一下,你手头有一批马的图片,但需要它们变成斑马的样子——不是简单的滤镜处理,而是保留马的姿态、背景,只改变纹理和颜色特征。传统方法需要成对的马和斑马图片(同一匹马在不同场景下的两种形态),但现实中这种数据几乎不存在。这就是CycleGAN的用武之地:它能在没有配对数据的情况下,实现两个图像域之间的高质量转换。

1. 环境准备与数据加载

1.1 安装必要依赖

推荐使用Python 3.8+和PyTorch 1.10+环境。以下命令可快速安装所需库:

pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install matplotlib opencv-python tqdm

1.2 准备数据集

我们从Kaggle获取两个公开数据集:

  • 马图像数据集(包含1067张图片)
  • 斑马图像数据集(包含1334张图片)
from torchvision.datasets import ImageFolder from torchvision import transforms transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) horse_dataset = ImageFolder('horse2zebra/trainA', transform=transform) zebra_dataset = ImageFolder('horse2zebra/trainB', transform=transform)

提示:数据集无需成对匹配,但建议两个域的图片数量相近,且内容类型相似(如都是动物全身照)

2. 模型架构实现

2.1 生成器设计

CycleGAN采用改进的U-Net结构,包含下采样、残差块和上采样三部分:

import torch.nn as nn class ResidualBlock(nn.Module): def __init__(self, in_features): super().__init__() self.block = nn.Sequential( nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), nn.InstanceNorm2d(in_features), nn.ReLU(inplace=True), nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), nn.InstanceNorm2d(in_features) ) def forward(self, x): return x + self.block(x) class Generator(nn.Module): def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9): super().__init__() # 初始卷积块 model = [ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), nn.InstanceNorm2d(64), nn.ReLU(inplace=True) ] # 下采样 in_features = 64 out_features = in_features*2 for _ in range(2): model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True) ] in_features = out_features out_features = in_features*2 # 残差块 for _ in range(n_residual_blocks): model += [ResidualBlock(in_features)] # 上采样 out_features = in_features//2 for _ in range(2): model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True) ] in_features = out_features out_features = in_features//2 # 输出层 model += [ nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh() ] self.model = nn.Sequential(*model) def forward(self, x): return self.model(x)

2.2 判别器设计

使用PatchGAN结构,对图像的局部区域进行真伪判断:

class Discriminator(nn.Module): def __init__(self, input_nc=3): super().__init__() def discriminator_block(in_filters, out_filters, normalize=True): layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] if normalize: layers.append(nn.InstanceNorm2d(out_filters)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *discriminator_block(input_nc, 64, normalize=False), *discriminator_block(64, 128), *discriminator_block(128, 256), *discriminator_block(256, 512), nn.ZeroPad2d((1, 0, 1, 0)), nn.Conv2d(512, 1, 4, padding=1) ) def forward(self, img): return self.model(img)

3. 训练策略与损失函数

3.1 对抗损失

标准的GAN损失函数,用于让生成图像逼近目标域分布:

criterion_GAN = torch.nn.MSELoss() # 计算生成器对抗损失 def compute_generator_loss(fake_pred): return criterion_GAN(fake_pred, torch.ones_like(fake_pred)) # 计算判别器损失 def compute_discriminator_loss(real_pred, fake_pred): real_loss = criterion_GAN(real_pred, torch.ones_like(real_pred)) fake_loss = criterion_GAN(fake_pred, torch.zeros_like(fake_pred)) return (real_loss + fake_loss) * 0.5

3.2 循环一致性损失

确保转换后的图像能还原回原始图像:

criterion_cycle = torch.nn.L1Loss() def compute_cycle_loss(real_img, cycled_img): return criterion_cycle(cycled_img, real_img) * 10.0 # λ=10

3.3 身份损失

保持输入图像的颜色分布:

criterion_identity = torch.nn.L1Loss() def compute_identity_loss(input_img, identity_img): return criterion_identity(identity_img, input_img) * 0.5 # λ=0.5

4. 完整训练流程

4.1 初始化模型与优化器

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") G_A2B = Generator().to(device) # 马→斑马 G_B2A = Generator().to(device) # 斑马→马 D_A = Discriminator().to(device) # 判别马 D_B = Discriminator().to(device) # 判别斑马 optimizer_G = torch.optim.Adam( list(G_A2B.parameters()) + list(G_B2A.parameters()), lr=0.0002, betas=(0.5, 0.999) ) optimizer_D = torch.optim.Adam( list(D_A.parameters()) + list(D_B.parameters()), lr=0.0002, betas=(0.5, 0.999) )

4.2 训练循环关键代码

for epoch in range(200): for i, (real_A, real_B) in enumerate(zip(horse_loader, zebra_loader)): # 前向传播 fake_B = G_A2B(real_A) cycled_A = G_B2A(fake_B) fake_A = G_B2A(real_B) cycled_B = G_A2B(fake_A) # 身份映射 identity_A = G_B2A(real_A) identity_B = G_A2B(real_B) # 判别器输出 pred_real_A = D_A(real_A) pred_fake_A = D_A(fake_A.detach()) pred_real_B = D_B(real_B) pred_fake_B = D_B(fake_B.detach()) # 生成器损失 loss_G_A2B = compute_generator_loss(D_B(fake_B)) loss_G_B2A = compute_generator_loss(D_A(fake_A)) loss_cycle_A = compute_cycle_loss(real_A, cycled_A) loss_cycle_B = compute_cycle_loss(real_B, cycled_B) loss_id_A = compute_identity_loss(real_A, identity_A) loss_id_B = compute_identity_loss(real_B, identity_B) loss_G = (loss_G_A2B + loss_G_B2A + loss_cycle_A + loss_cycle_B + loss_id_A + loss_id_B) # 判别器损失 loss_D_A = compute_discriminator_loss(pred_real_A, pred_fake_A) loss_D_B = compute_discriminator_loss(pred_real_B, pred_fake_B) loss_D = (loss_D_A + loss_D_B) * 0.5 # 反向传播 optimizer_G.zero_grad() loss_G.backward() optimizer_G.step() optimizer_D.zero_grad() loss_D.backward() optimizer_D.step()

4.3 训练技巧与参数设置

参数推荐值作用说明
学习率0.0002Adam优化器初始学习率
β10.5Adam动量参数
Batch Size1-4受限于显存,小批量也能工作
残差块数量9平衡模型容量与训练难度
λ_cycle10循环一致性损失权重
λ_identity0.5身份损失权重

注意:训练初期(约前10个epoch)可暂时禁用身份损失,待模型初步收敛后再加入

5. 效果展示与应用扩展

5.1 转换效果对比

经过200个epoch的训练后,我们得到以下典型转换效果:

  • 马→斑马转换

    • 保留原始姿态和背景
    • 成功添加斑马条纹
    • 自然调整颜色至斑马特征
  • 斑马→马转换

    • 去除条纹纹理
    • 调整为单色毛发
    • 保持四肢结构和场景不变

5.2 其他应用场景

只需更换训练数据,同一套代码可用于:

  • 风景照片的季节转换(夏↔冬)
  • 素描与照片的相互转换
  • 不同艺术风格间的迁移
  • 医学图像域适应(CT↔MRI)
# 快速测试模型 def test_transform(image_path): image = Image.open(image_path) transform = transforms.Compose([ transforms.Resize(256), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) return transform(image).unsqueeze(0) horse_img = test_transform("test_horse.jpg").to(device) zebra_result = G_A2B(horse_img) save_image(zebra_result*0.5+0.5, "converted_zebra.jpg")

在实际项目中,我发现调整身份损失的权重对颜色保持特别关键。当处理人像照片时,λ_identity=0.5能有效防止肤色异常变化,而对于风景照则可以适当降低到0.2-0.3。另一个实用技巧是在训练后期(最后20%的epoch)将学习率线性衰减到零,这能显著提升生成图像的细节质量。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/22 11:49:45

FPGA约束文件(XDC)的‘潜规则’:除了引脚和时序,你更该注意的语法细节

FPGA约束文件(XDC)的语法哲学:从工具使用者到规则制定者的思维跃迁 当我们第一次接触XDC文件时,往往把它当作普通的配置文件对待——简单记录引脚位置和时序要求。但随着项目复杂度提升,这种认知会让我们陷入各种难以排查的约束失效陷阱。实际…

作者头像 李华
网站建设 2026/4/22 11:47:53

终极指南:如何彻底解除极域电子教室控制,重获电脑自由

终极指南:如何彻底解除极域电子教室控制,重获电脑自由 【免费下载链接】JiYuTrainer 极域电子教室防控制软件, StudenMain.exe 破解 项目地址: https://gitcode.com/gh_mirrors/ji/JiYuTrainer 你是否曾在课堂上被极域电子教室的全屏广播锁住电脑…

作者头像 李华
网站建设 2026/4/22 11:47:40

draw.io桌面版:革命性的跨平台绘图解决方案

draw.io桌面版:革命性的跨平台绘图解决方案 【免费下载链接】drawio-desktop Official electron build of draw.io 项目地址: https://gitcode.com/GitHub_Trending/dr/drawio-desktop draw.io桌面版是一款基于Electron构建的专业级图表绘制工具,…

作者头像 李华