在本文中,将构建一个能够执行非成对图像到图像翻译的CycleGAN,并展示一些既有趣又具有学术深度的例子。还将讨论如何将使用TensorFlow和Keras构建的训练网络转换为TensorFlow Lite,并在移动设备上作为应用程序使用。
假设熟悉深度学习的概念,以及Jupyter Notebooks和TensorFlow。欢迎下载项目代码。
CycleGAN将使用马到斑马的数据集执行非成对图像到图像的翻译,可以下载该数据集。将使用TensorFlow和Keras实现网络,使用来自Pix2Pix库的生成器和鉴别器。将通过tensorflow_examples包导入生成器和鉴别器,以简化实现。然而,在随后的文章中,还将向展示如何从头开始构建新的生成器和鉴别器。
重要的是要提到CycleGAN是一个非常消耗计算能力和内存的网络。系统必须至少有8GB的RAM和至少与GTX 1660 Ti一样好的GPU,以训练和运行CycleGAN,而不会遇到内存不足错误或超时。
将使用Google Colab训练网络,这是一个托管的Jupyter Notebook服务,提供免费的计算资源访问,包括GPU。最重要的是,它是免费的,不像其他一些云计算服务。
让加载数据集并应用一些预处理技术,如裁剪、抖动和镜像,这将帮助避免网络的过拟合:
上述技术在原始的CycleGAN论文中有描述。
将上传数据到Google Drive,使其可以被Google Colab访问。数据上传后,可以开始读取数据。或者,可以直接在代码中使用tfds.load来直接从TensorFlow数据集包中加载数据集,如下所示。
首先,让导入一些所需的依赖项:
Python
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
现在将下载数据集并应用上述讨论的增强技术:
dataset, metadata = tfds.load(
'cycle_gan/horse2zebra',
with_info=True, as_supervised=True
)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
数据加载后,让添加一些预处理函数:
def random_crop(image):
cropped_image = tf.image.random_crop(
image, size=[IMG_HEIGHT, IMG_WIDTH, 3]
)
return cropped_image
def normalize(image):
image = tf.cast(image, tf.float32)
image = (image / 127.5) - 1
return image
def random_jitter(image):
image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
image = random_crop(image)
image = tf.image.random_flip_left_right(image)
return image
def preprocess_image_train(image, label):
image = random_jitter(image)
image = normalize(image)
return image
def preprocess_image_test(image, label):
image = normalize(image)
return image
现在,将读取图像:
train_horses = train_horses.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
train_zebras = train_zebras.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
test_horses = test_horses.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
test_zebras = test_zebras.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
以下是抖动图像的一个示例:
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Horse with random mirroring')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
现在,从pix2pix模型中导入生成器和鉴别器。将使用基于U-Net的生成器,而不是CycleGAN论文中使用的残差块生成器。将使用U-Net,因为它的结构更简单,计算量比残差块少。然而,将在另一篇文章中发现基于残差块的生成器。
OUTPUT_CHANNELS = 3
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
有了生成器和鉴别器,可以开始设置损失。由于CycleGAN是非成对图像到图像的翻译,因此在训练网络时不需要成对的数据。因此,没有人可以保证输入和目标图像在训练期间构成有意义的配对。这就是为什么计算循环一致性损失以确保网络正确映射是很重要的:
LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
real_loss = loss_obj(tf.ones_like(real), real)
generated_loss = loss_obj(tf.zeros_like(generated), generated)
total_disc_loss = real_loss + generated_loss
return total_disc_loss * 0.5
def generator_loss(generated):
return loss_obj(tf.ones_like(generated), generated)
现在,计算循环一致性损失,以确保翻译结果接近原始图像:
def calc_cycle_loss(real_image, cycled_image):
loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
return LAMBDA * loss1
def identity_loss(real_image, same_image):
loss = tf.reduce_mean(tf.abs(real_image - same_image))
return LAMBDA * 0.5 * loss
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)