Turning Horses into Zebras with CycleGAN: Unpaired Style Transfer in 5 Minutes (PyTorch Code Included)
Imagine you have a collection of horse photos that need to look like zebras: not a simple filter effect, but a transformation that preserves the horse's pose and background while changing only texture and color. Traditional approaches require paired horse/zebra images (the same horse photographed in both forms in the same scene), and such data essentially does not exist. This is where CycleGAN comes in: it learns high-quality translation between two image domains without any paired data.
1. Environment Setup and Data Loading
1.1 Installing Dependencies
Python 3.8+ and PyTorch 1.10+ are recommended. The following commands install the required libraries:
```bash
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
pip install matplotlib opencv-python tqdm
```

1.2 Preparing the Dataset
We use two public datasets from Kaggle:
- Horse images (1,067 pictures)
- Zebra images (1,334 pictures)
```python
from torchvision.datasets import ImageFolder
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Note: ImageFolder expects one level of class subdirectories,
# e.g. horse2zebra/trainA/horse/*.jpg and horse2zebra/trainB/zebra/*.jpg
horse_dataset = ImageFolder('horse2zebra/trainA', transform=transform)
zebra_dataset = ImageFolder('horse2zebra/trainB', transform=transform)
```

Tip: the two domains do not need to be paired, but it helps if they contain roughly the same number of images and the same kind of content (e.g. full-body animal shots in both).
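The training loop in section 4 reads batches from `horse_loader` and `zebra_loader`. A minimal sketch of building them from the datasets above (batch size and worker count are my own defaults, adjust to your GPU):

```python
from torch.utils.data import DataLoader

# batch_size=1 matches the original CycleGAN setup and fits on small GPUs
horse_loader = DataLoader(horse_dataset, batch_size=1, shuffle=True, num_workers=2)
zebra_loader = DataLoader(zebra_dataset, batch_size=1, shuffle=True, num_workers=2)
```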
2. Model Architecture
2.1 Generator
The generator follows CycleGAN's ResNet-style design with three stages: downsampling convolutions, a stack of residual blocks, and upsampling:
```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9):
        super().__init__()
        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]
        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2
        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]
        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2
        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
```

2.2 Discriminator
The discriminator is a PatchGAN: instead of judging the whole image at once, it classifies local patches as real or fake:
```python
class Discriminator(nn.Module):
    def __init__(self, input_nc=3):
        super().__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(input_nc, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)
```

3. Training Strategy and Loss Functions
3.1 Adversarial Loss
The standard GAN loss pushes generated images toward the target domain's distribution; here it is used in its least-squares form, hence the MSE criterion:
```python
import torch

criterion_GAN = torch.nn.MSELoss()

# Generator adversarial loss: fool the discriminator into predicting "real"
def compute_generator_loss(fake_pred):
    return criterion_GAN(fake_pred, torch.ones_like(fake_pred))

# Discriminator loss: real predictions toward 1, fake predictions toward 0
def compute_discriminator_loss(real_pred, fake_pred):
    real_loss = criterion_GAN(real_pred, torch.ones_like(real_pred))
    fake_loss = criterion_GAN(fake_pred, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) * 0.5
```

3.2 Cycle-Consistency Loss
It ensures that an image translated to the other domain can be mapped back to the original:
```python
criterion_cycle = torch.nn.L1Loss()

def compute_cycle_loss(real_img, cycled_img):
    return criterion_cycle(cycled_img, real_img) * 10.0  # λ_cycle = 10
```

3.3 Identity Loss
It helps preserve the color distribution of the input image:
```python
criterion_identity = torch.nn.L1Loss()

def compute_identity_loss(input_img, identity_img):
    return criterion_identity(identity_img, input_img) * 0.5  # λ_identity = 0.5
```

4. Full Training Pipeline
4.1 Initializing the Models and Optimizers
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G_A2B = Generator().to(device)      # horse -> zebra
G_B2A = Generator().to(device)      # zebra -> horse
D_A = Discriminator().to(device)    # discriminates horses
D_B = Discriminator().to(device)    # discriminates zebras

optimizer_G = torch.optim.Adam(
    list(G_A2B.parameters()) + list(G_B2A.parameters()),
    lr=0.0002, betas=(0.5, 0.999)
)
optimizer_D = torch.optim.Adam(
    list(D_A.parameters()) + list(D_B.parameters()),
    lr=0.0002, betas=(0.5, 0.999)
)
```

4.2 Core Training Loop
```python
for epoch in range(200):
    for i, ((real_A, _), (real_B, _)) in enumerate(zip(horse_loader, zebra_loader)):
        # ImageFolder yields (image, label) pairs; we only need the images
        real_A, real_B = real_A.to(device), real_B.to(device)

        # Forward passes
        fake_B = G_A2B(real_A)
        cycled_A = G_B2A(fake_B)
        fake_A = G_B2A(real_B)
        cycled_B = G_A2B(fake_A)

        # Identity mappings
        identity_A = G_B2A(real_A)
        identity_B = G_A2B(real_B)

        # Discriminator predictions (fakes are detached so the discriminator
        # loss does not backpropagate into the generators)
        pred_real_A = D_A(real_A)
        pred_fake_A = D_A(fake_A.detach())
        pred_real_B = D_B(real_B)
        pred_fake_B = D_B(fake_B.detach())

        # Generator losses
        loss_G_A2B = compute_generator_loss(D_B(fake_B))
        loss_G_B2A = compute_generator_loss(D_A(fake_A))
        loss_cycle_A = compute_cycle_loss(real_A, cycled_A)
        loss_cycle_B = compute_cycle_loss(real_B, cycled_B)
        loss_id_A = compute_identity_loss(real_A, identity_A)
        loss_id_B = compute_identity_loss(real_B, identity_B)
        loss_G = (loss_G_A2B + loss_G_B2A +
                  loss_cycle_A + loss_cycle_B +
                  loss_id_A + loss_id_B)

        # Discriminator losses
        loss_D_A = compute_discriminator_loss(pred_real_A, pred_fake_A)
        loss_D_B = compute_discriminator_loss(pred_real_B, pred_fake_B)
        loss_D = (loss_D_A + loss_D_B) * 0.5

        # Backward passes
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
```

4.3 Training Tips and Hyperparameters
| Parameter | Recommended value | Description |
|---|---|---|
| Learning rate | 0.0002 | Initial learning rate for the Adam optimizer |
| β1 | 0.5 | Adam momentum parameter |
| Batch size | 1-4 | Limited by GPU memory; small batches work fine |
| Residual blocks | 9 | Balances model capacity and training difficulty |
| λ_cycle | 10 | Weight of the cycle-consistency loss |
| λ_identity | 0.5 | Weight of the identity loss |
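It also pays to write out a preview image and the generator weights every few epochs, so mode collapse or color drift can be caught early. A minimal sketch of such a helper; the `save_progress` name and the `checkpoints/` directory are my own choices, not part of the original code:

```python
import os
from torchvision.utils import save_image

os.makedirs("checkpoints", exist_ok=True)

def save_progress(epoch, real_A, fake_B):
    # Un-normalize from [-1, 1] back to [0, 1] before writing to disk
    preview = torch.cat([real_A, fake_B.detach()], dim=0) * 0.5 + 0.5
    save_image(preview, f"checkpoints/epoch_{epoch:03d}_preview.png")
    torch.save(G_A2B.state_dict(), f"checkpoints/G_A2B_epoch_{epoch:03d}.pth")
    torch.save(G_B2A.state_dict(), f"checkpoints/G_B2A_epoch_{epoch:03d}.pth")

# Call it e.g. every 10 epochs at the end of the epoch loop:
# if epoch % 10 == 0:
#     save_progress(epoch, real_A, fake_B)
```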
Note: during roughly the first 10 epochs you can disable the identity loss and add it back once the model has started to converge, as shown in the sketch below.
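A minimal sketch of that schedule; it replaces the identity-mapping and identity-loss lines inside the training loop above (the 10-epoch threshold is the rough value from the note, tune it for your data):

```python
# Inside the inner training loop, in place of the identity-mapping block:
if epoch >= 10:
    identity_A = G_B2A(real_A)
    identity_B = G_A2B(real_B)
    loss_id_A = compute_identity_loss(real_A, identity_A)
    loss_id_B = compute_identity_loss(real_B, identity_B)
else:
    # Skipping the identity term early on also saves two generator forward passes
    loss_id_A = torch.tensor(0.0, device=device)
    loss_id_B = torch.tensor(0.0, device=device)
```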
5. Results and Extensions
5.1 Conversion Results
After 200 epochs of training, typical results look like this:
Horse → zebra:
- The original pose and background are preserved
- Zebra stripes are added convincingly
- The coloring shifts naturally toward that of a zebra
Zebra → horse:
- The stripe texture is removed
- The coat becomes a uniform single color
- The limb structure and the scene remain unchanged
5.2 Other Applications
With nothing more than a change of training data, the same code can handle:
- Season transfer for landscape photos (summer ↔ winter)
- Converting between sketches and photographs
- Transfer between different artistic styles
- Domain adaptation for medical images (CT ↔ MRI)
```python
from PIL import Image
from torchvision.utils import save_image

# Quick single-image test
def test_transform(image_path):
    image = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    return transform(image).unsqueeze(0)

G_A2B.eval()
with torch.no_grad():
    horse_img = test_transform("test_horse.jpg").to(device)
    zebra_result = G_A2B(horse_img)

# Map the output from [-1, 1] back to [0, 1] before saving
save_image(zebra_result * 0.5 + 0.5, "converted_zebra.jpg")
```

In real projects I have found that tuning the identity loss weight is particularly important for color preservation. When working with portrait photos, λ_identity = 0.5 effectively prevents unnatural skin-tone shifts, while for landscapes it can be lowered to 0.2-0.3. Another practical trick is to linearly decay the learning rate to zero over the last 20% of the epochs, which noticeably improves the detail quality of the generated images.
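A minimal sketch of that decay schedule with `torch.optim.lr_scheduler.LambdaLR`, assuming 200 epochs in total so the decay begins at epoch 160:

```python
n_epochs, decay_start = 200, 160

def lr_lambda(epoch):
    # Multiplier stays at 1.0 until decay_start, then falls linearly to 0 at n_epochs
    return 1.0 - max(0, epoch - decay_start) / float(n_epochs - decay_start)

scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lr_lambda)

# Call scheduler_G.step() and scheduler_D.step() once per epoch,
# after the inner batch loop finishes.
```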