一、问题现象:WGAN-GP在AMP训练中完全失效
我们在MindSpore上复现WGAN-GP(带有梯度惩罚的Wasserstein GAN)模型。在FP32精度下,训练正常,判别器(Critic)损失能稳步下降,生成器(Generator)能学习到有效分布。然而,当启用自动混合精度以加速训练和节省显存时,训练过程完全崩溃:
# 启用AMP O2级别 (几乎全部算子使用FP16) from mindspore import amp network = Generator() critic = Critic() # 将网络和损失函数转换为AMP net_with_loss = MyWGANGPLoss(network, critic) optimizer_g = nn.Adam(network.trainable_params(), learning_rate=1e-4) optimizer_c = nn.Adam(critic.trainable_params(), learning_rate=4e-4) net_with_loss, optimizer_g, optimizer_c = amp.build_train_network( net_with_loss, [optimizer_g, optimizer_c], level="O2", loss_scale_manager=DynamicLossScaleManager() # 使用动态损失缩放 )启用AMP后,出现以下现象:
- 判别器的损失值在最初几次迭代后迅速变为一个极大的负数(例如-1e8),之后不再变化。
- 生成器的损失同样停滞。
- 生成的图片始终是噪声,没有任何学习迹象。
- 关键线索:在训练日志中,偶尔会出现
[WARNING] OVERFLOW!提示,但频率极低。
这表面上看像是梯度爆炸或消失,但在FP32下正常,说明问题与AMP的精度转换直接相关。
二、根因分析:梯度下溢与Loss Scale机制
混合精度训练的核心是用FP16做前向和反向传播,用FP32保存主权重。但FP16的取值范围(约 5.96e-8 ~ 65504)远小于FP32,在反向传播中,梯度值可能小于FP16能表示的最小正值,从而在转换为FP16时变为0,即梯度下溢。
MindSpore的AMP通过损失缩放(Loss Scaling) 来解决梯度下溢问题:在计算损失函数后,将其乘以一个较大的系数(如loss_scale=1024),等比例放大后续的梯度,使其避开FP16的下溢区。反向传播完成后,再将梯度除以相同的loss_scale,更新FP32权重。
我们的问题在于:WGAN-GP的梯度惩罚(Gradient Penalty)项计算,使得某些梯度分量变得极其微小,超出了默认LossScaleManager的处理能力。
- 梯度惩罚的计算: WGAN-GP需要在真实数据和生成数据的插值点处计算判别器输出的梯度范数。这个计算涉及二阶导,容易产生非常小的梯度值。
- 默认
DynamicLossScaleManager的行为: 它监控梯度是否溢出(Overflow,即梯度变为inf或nan)。如果发生溢出,则降低loss_scale;如果连续一段时间没有溢出,则提高loss_scale。但它对梯度下溢(Underflow)不敏感! 梯度下溢变为0,不会被识别为“溢出”,因此管理器不会主动调高loss_scale来应对。 - 下溢的后果: 当判别器某些层的梯度因下溢而变为0时,这些层的参数无法更新。判别器“局部瘫痪”,导致其提供不了有效的梯度信号给生成器,整个对抗训练过程失败。损失函数出现的巨大负值,可能是由数值不稳定或未更新的参数导致的异常计算。
三、诊断与定位:使用AMP调试模式
MindSpore AMP提供了调试接口,可以输出各算子的梯度统计信息,帮助我们定位下溢发生的具体位置。
# 方法1:在build_train_network时设置debug_level net_with_loss, optimizer_g, optimizer_c = amp.build_train_network( net_with_loss, [optimizer_g, optimizer_c], level="O2", loss_scale_manager=DynamicLossScaleManager(), # 启用调试,输出梯度信息 debug_level=1 # 或 2 获取更详细信息 ) # 方法2:在训练循环中,手动检查梯度 # 在自定义的训练步骤中,可以在计算梯度后,遍历参数查看 grads = amp.get_grads(net_with_loss, loss, optimizer_g.parameters) for grad in grads: if grad is not None: # 检查梯度中极小值的比例 if (grad.abs() < 1e-7).any(): print(f"发现极小梯度: {grad.name}, min={grad.min()}, max={grad.max()}")运行带有调试信息的训练,观察日志输出。可以发现在计算梯度惩罚项相关的反向传播路径中,某些Gradients的max和min值在FP16表示下已经接近于0,而同时loss_scale的值保持在一个较低水平(例如128)且长期不变。这证实了梯度下溢正在发生,而动态损失缩放管理器并未采取有效行动。
四、解决方案:自定义损失缩放与训练策略调整
我们需要一个更积极的策略来对抗梯度下溢。
方案一:定制更激进的DynamicLossScaleManager
默认的DynamicLossScaleManager对下溢不敏感。我们可以继承并重写其更新逻辑,将梯度幅值过小视为需要提高loss_scale的信号。
class CustomDynamicLossScaleManager(amp.DynamicLossScaleManager): def __init__(self, init_scale=2**24, scale_factor=2, scale_window=2000): super().__init__(init_scale, scale_factor, scale_window) self.gradient_norm_threshold_low = 1e-6 # 梯度范数下限,低于此值认为可能下溢 self.steps_since_last_scale = 0 def update_loss_scale(self, gradients): """ 重写更新逻辑,同时检测溢出和下溢 gradients: 当前迭代的梯度列表 """ # 1. 检查梯度溢出 (继承父类逻辑) is_overflow = self._check_overflow(gradients) # 假设有这个方法检查inf/nan if is_overflow: # 溢出,降低scale self.loss_scale = max(self.loss_scale / self.scale_factor, 1) self.steps_since_last_scale = 0 print(f"[OVERFLOW] Loss scale decreased to {self.loss_scale}") else: # 2. 检查梯度幅值是否过小 (新增逻辑) total_norm = 0.0 for grad in gradients: if grad is not None: total_norm += (grad ** 2).sum().asnumpy() # 计算梯度L2范数 total_norm = np.sqrt(total_norm) if total_norm < self.gradient_norm_threshold_low: # 梯度范数太小,可能下溢,提高scale self.loss_scale *= self.scale_factor self.steps_since_last_scale = 0 print(f"[UNDERFLOW RISK] Gradient norm {total_norm:.2e} is too low. Loss scale increased to {self.loss_scale}") else: # 正常,按窗口期递增 self.steps_since_last_scale += 1 if self.steps_since_last_scale >= self.scale_window: self.loss_scale *= self.scale_factor self.steps_since_last_scale = 0 print(f"[NORMAL] Loss scale increased to {self.loss_scale}") return is_overflow注意: 上述代码为概念演示。实际中需要更精细地获取梯度,并确保与MindSpore的Tensor格式兼容。核心思想是监控梯度范数,当其异常偏小时,主动提高loss_scale。
方案二:调整梯度惩罚计算与混合精度策略
有时,单独调整Loss Scale还不够,需要调整模型或训练策略。
- 在FP32下计算梯度惩罚: 这是最直接有效的方法。强制WGAN-GP损失函数中计算梯度范数的部分在FP32精度下进行,避免该敏感部分受FP16精度限制。
class WGANGPLossFP32Safe(nn.Cell): def construct(self, real_data, fake_data, critic_net): # ... 其他损失计算 ... # 插值点 alpha = ops.UniformReal()((real_data.shape[0], 1, 1, 1)) interpolates = alpha * real_data + (1 - alpha) * fake_data # 关键:将插值点转换为FP32再进行梯度计算 interpolates = ops.Cast()(interpolates, mstype.float32) # 计算判别器对插值点的输出 disc_interpolates = critic_net(interpolates) # 计算梯度(此处会自动在FP32下进行) gradients = ops.GradOperation()(disc_interpolates, interpolates) gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() # 将梯度惩罚项转换回与整体损失相同的精度 gradient_penalty = ops.Cast()(gradient_penalty, mstype.float16) # ... 合并损失 ...2. 使用amp.custom_mixed_precision进行更细粒度控制: 如果问题出在特定层(如LayerNorm),可以指定该层使用FP32计算。
from mindspore import amp # 指定某些cell使用FP32 network = amp.custom_mixed_precision(network, custom_white_list=[nn.LayerNorm, MySensitiveModule])方案三:使用更大的初始loss_scale并配合梯度裁剪
对于WGAN,梯度裁剪本身是稳定训练的标准操作。在AMP下,可以将其与较大的固定loss_scale结合。
# 使用较大的固定loss_scale,并启用梯度裁剪 loss_scale_manager = amp.FixedLossScaleManager(loss_scale=1024.0) # 或更大,如8192 # 在优化器中配置梯度裁剪 optimizer_g = nn.Adam(network.trainable_params(), learning_rate=1e-4, grad_clip=1.0) optimizer_c = nn.Adam(critic.trainable_params(), learning_rate=4e-4, grad_clip=1.0)较大的固定loss_scale可以抬升大部分梯度,避免下溢;梯度裁剪则可以防止因loss_scale过大导致的少数梯度爆炸。这是一种简单粗暴但往往有效的策略。