首页 理论教育 深度合成:GAN的训练

深度合成:GAN的训练

时间:2023-11-18 理论教育 版权反馈
【摘要】:利用GAN生成手写数字的训练流程如下。1]batch_size个g_loss=self.combined.train_on_batchgan=GAN()gan. train在代码中同样可以看到训练分为两个步骤,首先训练判别器,这里通过把MNIST数据集中的真实图像及对应标签1和生成图像及对应标签0输入训练判别器网络,代码中设置half_batch=64,表明每次随机选择64张真实图像和64张生成图像。因此,GAN的训练极其困难,想要让GAN正常运行,需要对模型架构和超参数进行仔细设计,而且训练的时间成本非常高。

深度合成:GAN的训练

利用GAN生成手写数字的训练流程如下。训练时需要MNIST图片库,代表真实的手写数字图像让GAN进行参考学习,这里使用G代表生成器网络,D代表判别器网络,E代表数学期望:

第一步,对生成器和判别器的参数进行初始化

第二步,循环迭代:

(1)固定生成器的参数,训练更新判别器的参数:

①在MNIST图片样本库中采样n个样本X;

②在先验分步(如正态分布)中采样n个样本Z;

③将样本Z,送入生成器得到生成图像Y;

④用梯度上升法更新参数,优化目标为E{logD(X)}+E{log(1-D(Y))}最大。

(2)固定判别器的参数,训练更新生成器的参数:

①重新生成n个样本Z’;

②利用梯度下降法更新生成器,优化目标为E{logD(G(Z’))}最小。

def train(self, epochs, batch_size=128,sample_interval=50):

#加载数据集,训练集矩阵,训练集标签,测试集矩阵,测试集标签,

(x_train,_),(_,_)=mnist.load_data()

#将像素值归整到-1到1之间

x_train=(x_train.astype(np.float32)-127.5)/127.5

x_train=np.expand_dims(x_train, axis=3)#扩展维度>>(60000,28,28,1)

half_batch=int(batch_size/2)#64=128/2

#创建标签

valid=np.ones((batch_size,1))

fake=np.zeros((batch_size,1))

for epoch in range(epochs):

#训练判别器,随机选择半批图像

#在0~60000之间,随机生成64个数构成列表索引

idx=np.random.randint(0,x_train.shape[0],half_batch)

#随机选择图片

imgs=x_train[idx]

noise=np.random.normal(0,1,(half_batch,100))

#生成半批新的图像

gen_imgs=self.generator.predict(noise)

#训练判别器

#手动将一个个batch的数据送入网络中训练(www.xing528.com)

d_loss_real=self.discriminator.train_on_batch(imgs, valid)

d_loss_fake=self.discriminator.train_on_batch(gen_imgs, fake)

d_loss=0.5∗np.add(d_loss_real, d_loss_fake)

#训练生成器

noise=np.random.normal(0,1,(batch_size,100))

#生成器希望判别器将生成的样本标记为1

valid_y=np.array([1]∗batch_size)#[1,1,1……1]batch_size个

g_loss=self.combined.train_on_batch(noise, valid_y)

gan=GAN()

gan. train(epochs=30000,batch_size=256,sample_interval=200)

在代码中同样可以看到训练分为两个步骤,首先训练判别器,这里通过把MNIST数据集中的真实图像及对应标签1和生成图像及对应标签0输入训练判别器网络,代码中设置half_batch=64,表明每次随机选择64张真实图像和64张生成图像。训练判别器的目的是让它学会准确区分真实图像和生成图像。

在训练生成器过程中,由于需要使用判别器模型,因此采用对抗网络(combine)进行训练,此时把判别器模型冻结,让生成网络伪造一批图像,把它作为真实图像输入对抗网络,并由此输入判别器网络得到识别结果。

值得注意的是,GAN优化的损失函数目标值是不固定的,它是一个动态变化的系统,其最优化过程寻找的不是一个最小值,而是判别能力和生成能力之间的平衡。因此,GAN的训练极其困难,想要让GAN正常运行,需要对模型架构和超参数进行仔细设计,而且训练的时间成本非常高。

为及时查看训练过程中生成图像的质量,可以利用以下代码把生成器构造的图像保存绘制出来,当满足需要时也可提前终止网络训练。

def sample_images(self, epoch):

r, c=5,5

noise=np.random.normal(0,1,(r∗c, self.latent_dim))

gen_imgs=self.generator.predict(noise)

gen_imgs=0.5∗gen_imgs+0.5#-1到1归整到0-1

fig, axs=plt.subplots(r, c)

cnt=0

for i in range(r):

for j in range(c):

axs[i, j]. imshow(gen_imgs[cnt,:,:,0],cmap='gray')

axs[i, j]. axis('off')

cnt+=1

fig. savefig("images/%d.png"%epoch)

plt. close()

图7-7列出了不同迭代次数时的生成图像。

图7-7 不同迭代次数的生成图像

可以看到迭代500次后各图像还是无法辨认,当迭代到7000次时,生成器生成的手写数字图片已经很形象了。当生成器持续迭代改进后,生成的图像质量会越来越好,迭代到40000次时,生成的图片已经令人无法辨识真假了。

免责声明:以上内容源自网络,版权归原作者所有,如有侵犯您的原创版权请告知,我们将尽快删除相关内容。

我要反馈