AI在时尚设计中的应用

随着人工智能技术的飞速发展,其在各行各业的应用越来越广泛。在时尚设计领域,AI技术同样展现出了巨大的潜力。本文将介绍一个基于人工智能的深度学习系统,该系统能够通过更好地理解消费者需求来革新时尚设计行业。

在本系列文章中,将展示一个由AI驱动的深度学习系统。将使用以下工具和库:

  • Jupyter Notebook 作为集成开发环境(IDE)
  • TensorFlow 2.0
  • NumPy
  • Matplotlib
  • DeepFashion 数据集的一个自定义子集,以减少计算和内存开销

假设已经熟悉深度学习的概念,以及Jupyter Notebooks 和 TensorFlow。如果是Jupyter Notebooks的新手,可以从这个教程开始。也可以下载项目代码。

训练GAN

生成器的训练是通过减少假图像和真实图像之间的损失和误差来完成的(log(D(x)) + log(D(G(z))))。将选择大量的迭代次数,因为这种网络需要多次迭代来减少真实和假图像之间的误差。将从40个迭代开始训练,并看看这会带来什么结果。将在自定义数据集上训练网络。参数和变量定义如下:

  • G_losses: 生成器损失,通过在训练生成器期间计算所有生成图像的损失总和
  • D_losses: 鉴别器损失,通过计算真实和假批次的所有损失总和
  • D(G(z)): 所有假批次的平均鉴别器输出
  • D(x): 鉴别器对所有真实批次的平均输出(跨批次)

# 跟踪进度的列表 img_list = [] G_losses = [] D_losses = [] iters = 0 print("Starting Training Loop...") for epoch in range(num_epochs): for i, data in enumerate(dataloader, 0): # 更新D网络:最大化 log(D(x)) + log(1 - D(G(z))) netD.zero_grad() real_cpu = data[0].to(device) b_size = real_cpu.size(0) label = torch.full((b_size,), real_label, dtype=torch.float, device=device) output = netD(real_cpu).view(-1) errD_real = criterion(output, label) errD_real.backward() D_x = output.mean().item() noise = torch.randn(b_size, nz, 1, 1, device=device) fake = netG(noise) label.fill_(fake_label) output = netD(fake.detach()).view(-1) errD_fake = criterion(output, label) errD_fake.backward() D_G_z1 = output.mean().item() errD = errD_real + errD_fake optimizerD.step() # 更新G网络:最大化 log(D(G(z))) netG.zero_grad() label.fill_(real_label) output = netD(fake).view(-1) errG = criterion(output, label) errG.backward() D_G_z2 = output.mean().item() optimizerG.step() if i % 50 == 0: print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) G_losses.append(errG.item()) D_losses.append(errD.item()) if iters % 500 == 0 or ((epoch == num_epochs-1) and (i == len(dataloader)-1)): with torch.no_grad(): fake = netG(fixed_noise).detach().cpu() img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) iters += 1

如所见,在40个迭代后,所有假批次的平均鉴别器输出D(G(Z))降低到了一个非常吸引人的值。有了这个,GAN已经足够熟练地生成类似于数据集中的图像。如果想要更好的图像,需要增加迭代次数并重新训练。

还可以绘制训练期间生成器和鉴别器损失的图表。 plt.figure(figsize=(10, 5)) plt.title("Generator and Discriminator Loss During Training") plt.plot(G_losses, label="G") plt.plot(D_losses, label="D") plt.xlabel("iterations") plt.ylabel("Loss") plt.legend() plt.show()

在训练期间可视化生成的图像,PyTorch提供了一个函数,可以将生成的图像作为动画视频进行可视化。 %%capture fig = plt.figure(figsize=(8, 8)) plt.axis("off") ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list] ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True) HTML(ani.to_jshtml())

从训练好的GAN生成时尚图像

在GAN训练完成后,可以通过以下代码获取它生成的一批时尚图像。 real_batch = next(iter(dataloader)) plt.figure(figsize=(15, 15)) plt.subplot(1, 2, 1) plt.axis("off") plt.title("Real Images") plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0))) plt.subplot(1, 2, 2) plt.axis("off") plt.title("Fake Images") plt.imshow(np.transpose(img_list[-1], (1, 2, 0))) plt.show()

看起来GAN能够生成一些与训练数据集中的图像相似的时尚图像。

以下是一些快速简单的方法来进一步提高GAN的性能:

  • 使用转置卷积或上采样层构建更深层次的生成器
  • 将生成器输入噪声的类型更改为高斯
  • 构建更深层次的鉴别器以提高其预测性能
  • 使用更多的迭代次数和更多的图像进行训练

沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485