随着人工智能技术的飞速发展,其在各行各业的应用越来越广泛。在时尚设计领域,AI技术同样展现出了巨大的潜力。本文将介绍一个基于人工智能的深度学习系统,该系统能够通过更好地理解消费者需求来革新时尚设计行业。
在本系列文章中,将展示一个由AI驱动的深度学习系统。将使用以下工具和库:
假设已经熟悉深度学习的概念,以及Jupyter Notebooks 和 TensorFlow。如果是Jupyter Notebooks的新手,可以从这个教程开始。也可以下载项目代码。
生成器的训练是通过减少假图像和真实图像之间的损失和误差来完成的(log(D(x)) + log(D(G(z))))。将选择大量的迭代次数,因为这种网络需要多次迭代来减少真实和假图像之间的误差。将从40个迭代开始训练,并看看这会带来什么结果。将在自定义数据集上训练网络。参数和变量定义如下:
# 跟踪进度的列表
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
for epoch in range(num_epochs):
for i, data in enumerate(dataloader, 0):
# 更新D网络:最大化 log(D(x)) + log(1 - D(G(z)))
netD.zero_grad()
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
D_x = output.mean().item()
noise = torch.randn(b_size, nz, 1, 1, device=device)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.mean().item()
errD = errD_real + errD_fake
optimizerD.step()
# 更新G网络:最大化 log(D(G(z)))
netG.zero_grad()
label.fill_(real_label)
output = netD(fake).view(-1)
errG = criterion(output, label)
errG.backward()
D_G_z2 = output.mean().item()
optimizerG.step()
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
G_losses.append(errG.item())
D_losses.append(errD.item())
if iters % 500 == 0 or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
iters += 1
如所见,在40个迭代后,所有假批次的平均鉴别器输出D(G(Z))降低到了一个非常吸引人的值。有了这个,GAN已经足够熟练地生成类似于数据集中的图像。如果想要更好的图像,需要增加迭代次数并重新训练。
还可以绘制训练期间生成器和鉴别器损失的图表。
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
在训练期间可视化生成的图像,PyTorch提供了一个函数,可以将生成的图像作为动画视频进行可视化。
%%capture
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
在GAN训练完成后,可以通过以下代码获取它生成的一批时尚图像。
real_batch = next(iter(dataloader))
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0)))
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()
看起来GAN能够生成一些与训练数据集中的图像相似的时尚图像。
以下是一些快速简单的方法来进一步提高GAN的性能: