高二的时候社团设计本来打算从底层搓一个GAN的,但是后来训练起来太慢了,就彻底咕咕了。现在突然想起来,所以就用pytouch重新写了一下。还有当初想用flet写的一个项目,这下融合到了一起,也算是填了高中的一个坑吧。

找不到当初在哪里学的GAN了,这里就随便贴一个详细点的解释:适合小白学习的GAN(生成对抗网络)算法超详细解读_gan网络-CSDN博客

(再贴一个当初学Adam梯度下降的视频:【论文必读#2: ADAM算法】史上最火梯度下降算法是如何炼成的?_哔哩哔哩_bilibili

简言之,GAN就是两个模型(正向卷积和反向卷积)的博弈和学习

这是开发中的kernel的一些常用大小:

参数 (kernel_size, stride, padding) 作用
(3,1,1) 保持输入尺寸不变(常见)
(3,2,1) 尺寸缩小 2×(常见)
(4,2,1) 尺寸缩小 2×,但影响范围稍大
(5,1,2) 尺寸不变,但感受野大一些
(5,2,2) 尺寸缩小 2×,感受野更大

在模型中用的便是(4,2,1)的卷积核,易于根据图片大小动态修改卷积次数,来适配各种大小图片

在训练中,生成器还容易崩溃,通过如下方法可以避免训练崩溃
(1) 降低判别器的学习率

  • 试着降低 D 的学习率,比如 lr=0.0001,保持 G 的学习率在 0.0002 左右:
  • optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

(2) 让 G 训练更多次

  • 可以尝试让 G 训练 2~5 次,而 D 训练 1 次:
  • for _ in range(2): train_generator() # 让 G 训练两次
  • train_discriminator()

(3) 使用 Label Smoothing(标签平滑)

  • D 预测的真实样本标签不是 1.0 而是 0.9
  • real_labels = torch.full((batch_size,), 0.9) fake_labels = torch.full((batch_size,), 0.0)

(4) 添加噪声

  • D 的输入增加高斯噪声,让 D 不那么容易分辨真假:
  • real_data += 0.05 * torch.randn_like(real_data)

(5) 使用 WGAN-GP

改用 Wasserstein GAN(WGAN)或者 WGAN-GP,可以减少梯度消失的问题:

loss = -torch.mean(D(real)) + torch.mean(D(fake)) # WGAN 损失


一沙一世界,一花一天堂。君掌盛无边,刹那成永恒。