高二的时候社团设计本来打算从底层搓一个GAN的,但是后来训练起来太慢了,就彻底咕咕了。现在突然想起来,所以就用pytouch重新写了一下。还有当初想用flet写的一个项目,这下融合到了一起,也算是填了高中的一个坑吧。
找不到当初在哪里学的GAN了,这里就随便贴一个详细点的解释:适合小白学习的GAN(生成对抗网络)算法超详细解读_gan网络-CSDN博客
(再贴一个当初学Adam梯度下降的视频:【论文必读#2: ADAM算法】史上最火梯度下降算法是如何炼成的?_哔哩哔哩_bilibili
简言之,GAN就是两个模型(正向卷积和反向卷积)的博弈和学习
这是开发中的kernel的一些常用大小:
参数 (kernel_size, stride, padding) 作用
(3,1,1) 保持输入尺寸不变(常见)
(3,2,1) 尺寸缩小 2×(常见)
(4,2,1) 尺寸缩小 2×,但影响范围稍大
(5,1,2) 尺寸不变,但感受野大一些
(5,2,2) 尺寸缩小 2×,感受野更大
在模型中用的便是(4,2,1)的卷积核,易于根据图片大小动态修改卷积次数,来适配各种大小图片
在训练中,生成器还容易崩溃,通过如下方法可以避免训练崩溃
(1) 降低判别器的学习率
- 试着降低
D
的学习率,比如lr=0.0001
,保持G
的学习率在0.0002
左右: optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
(2) 让 G 训练更多次
- 可以尝试让
G
训练 2~5 次,而D
训练 1 次: for _ in range(2): train_generator() # 让 G 训练两次
train_discriminator()
(3) 使用 Label Smoothing(标签平滑)
- 让
D
预测的真实样本标签不是1.0
而是0.9
: real_labels = torch.full((batch_size,), 0.9) fake_labels = torch.full((batch_size,), 0.0)
(4) 添加噪声
- 给
D
的输入增加高斯噪声,让D
不那么容易分辨真假: real_data += 0.05 * torch.randn_like(real_data)
(5) 使用 WGAN-GP
改用 Wasserstein GAN(WGAN)或者 WGAN-GP,可以减少梯度消失的问题:
loss = -torch.mean(D(real)) + torch.mean(D(fake)) # WGAN 损失
Comments | NOTHING