在本文中,将构建一个能够执行无配对图像到图像翻译的循环一致对抗网络(CycleGAN),并展示一些既有趣又学术深度的例子。此外,还将讨论如何将使用TensorFlow和Keras构建的这种训练网络转换为TensorFlow Lite,并在移动设备上作为应用程序使用。
假设熟悉深度学习的概念,以及Jupyter Notebooks和TensorFlow。欢迎下载项目代码。
在之前的文章中,从头开始实现了一个CycleGAN。在本文中,将在horse2zebra数据集上训练和测试网络,并评估其性能。
是时候训练CycleGAN进行一些有趣的翻译了,比如马变斑马,反之亦然。将从设置一个检查点路径来保存最佳模型开始:
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
generator_f=generator_f,
discriminator_x=discriminator_x,
discriminator_y=discriminator_y,
generator_g_optimizer=generator_g_optimizer,
generator_f_optimizer=generator_f_optimizer,
discriminator_x_optimizer=discriminator_x_optimizer,
discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print('Latest checkpoint restored!!')
首先,将训练20个周期,看看这是否足以获得可接受的结果。根据获得的结果,可能需要增加周期数。即使训练结果看起来不错,预测可能仍然不够准确。因此,80到100个周期更有可能获得完美的翻译,但这将需要超过3天的训练时间,除非使用的是具有非常高规格的系统或付费的云计算服务,如AWS或Microsoft Azure。
EPOCHS = 20
def generate_images(model, test_input):
prediction = model(test_input)
plt.figure(figsize=(12, 12))
display_list = [test_input[0], prediction[0]]
title = ['Input Image', 'Predicted Image']
for i in range(2):
plt.subplot(1, 2, i+1)
plt.title(title[i])
plt.imshow(display_list[i] * 0.5 + 0.5)
plt.axis('off')
plt.show()
上面的训练循环执行以下操作:
在训练过程中,网络将从训练集中随机选择一张图像,并显示其翻译版本,以便可视化每个周期后性能的变化,如下图所示。
一旦CycleGAN训练完成,就可以开始输入新图像并评估其在将马翻译成斑马以及反之亦然的性能。
让在数据集的图像上测试训练的CycleGAN,并可视化其泛化能力。将使用generate_images函数,它将挑选一些图像,通过训练好的网络,然后显示翻译结果。
def generate_images(model, test_input):
prediction = model(test_input)
plt.figure(figsize=(12, 12))
display_list = [test_input[0], prediction[0]]
title = ['Input Image', 'Predicted Image']
for i in range(2):
plt.subplot(1, 2, i+1)
plt.title(title[i])
plt.imshow(display_list[i] * 0.5 + 0.5)
plt.axis('off')
plt.show()
现在,可以选择任何测试图像并可视化翻译结果:
for inp in test_horses.take(5):
generate_images(generator_g, inp)
以下是在网络仅训练20个周期后获得的一些示例。对于如此短的训练来说,结果相当不错。可以通过增加更多周期来改善它们。
可以使用设计的网络执行不同的任务,比如白天到夜晚的转换或季节转换。为了训练网络进行季节转换,所需要做的就是将训练数据集更改为summer2winter。