基于循环一致对抗网络的图像翻译系统

在本文中,将构建一个能够执行非成对图像到图像翻译的CycleGAN,并展示一些既有趣又具有学术深度的例子。还将讨论如何将使用TensorFlow和Keras构建的训练网络转换为TensorFlow Lite,并在移动设备上作为应用程序使用。

假设熟悉深度学习的概念,以及Jupyter Notebooks和TensorFlow。欢迎下载项目代码。

CycleGAN将使用马到斑马的数据集执行非成对图像到图像的翻译,可以下载该数据集。将使用TensorFlow和Keras实现网络,使用来自Pix2Pix库的生成器和鉴别器。将通过tensorflow_examples包导入生成器和鉴别器,以简化实现。然而,在随后的文章中,还将向展示如何从头开始构建新的生成器和鉴别器。

重要的是要提到CycleGAN是一个非常消耗计算能力和内存的网络。系统必须至少有8GB的RAM和至少与GTX 1660 Ti一样好的GPU,以训练和运行CycleGAN,而不会遇到内存不足错误或超时。

将使用Google Colab训练网络,这是一个托管的Jupyter Notebook服务,提供免费的计算资源访问,包括GPU。最重要的是,它是免费的,不像其他一些云计算服务。

处理数据集

让加载数据集并应用一些预处理技术,如裁剪、抖动和镜像,这将帮助避免网络的过拟合:

  • 图像抖动将图像大小调整为286x286像素,然后从随机选定的起点裁剪为256x256像素。
  • 图像镜像将图像水平翻转,从左到右。

上述技术在原始的CycleGAN论文中有描述。

将上传数据到Google Drive,使其可以被Google Colab访问。数据上传后,可以开始读取数据。或者,可以直接在代码中使用tfds.load来直接从TensorFlow数据集包中加载数据集,如下所示。

首先,让导入一些所需的依赖项:

Python import tensorflow as tf import tensorflow_datasets as tfds from tensorflow_examples.models.pix2pix import pix2pix import os import time import matplotlib.pyplot as plt from IPython.display import clear_output

现在将下载数据集并应用上述讨论的增强技术:

dataset, metadata = tfds.load( 'cycle_gan/horse2zebra', with_info=True, as_supervised=True ) train_horses, train_zebras = dataset['trainA'], dataset['trainB'] test_horses, test_zebras = dataset['testA'], dataset['testB']

数据加载后,让添加一些预处理函数:

def random_crop(image): cropped_image = tf.image.random_crop( image, size=[IMG_HEIGHT, IMG_WIDTH, 3] ) return cropped_image def normalize(image): image = tf.cast(image, tf.float32) image = (image / 127.5) - 1 return image def random_jitter(image): image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) image = random_crop(image) image = tf.image.random_flip_left_right(image) return image def preprocess_image_train(image, label): image = random_jitter(image) image = normalize(image) return image def preprocess_image_test(image, label): image = normalize(image) return image

现在,将读取图像:

train_horses = train_horses.map( preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1) train_zebras = train_zebras.map( preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1) test_horses = test_horses.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1) test_zebras = test_zebras.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)

以下是抖动图像的一个示例:

sample_horse = next(iter(train_horses)) sample_zebra = next(iter(train_zebras)) plt.subplot(121) plt.title('Horse') plt.imshow(sample_horse[0] * 0.5 + 0.5) plt.subplot(122) plt.title('Horse with random mirroring') plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5) plt.subplot(121) plt.title('Zebra') plt.imshow(sample_horse[0] * 0.5 + 0.5) plt.subplot(122) plt.title('Zebra with random jitter') plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

构建生成器和鉴别器

现在,从pix2pix模型中导入生成器和鉴别器。将使用基于U-Net的生成器,而不是CycleGAN论文中使用的残差块生成器。将使用U-Net,因为它的结构更简单,计算量比残差块少。然而,将在另一篇文章中发现基于残差块的生成器。

OUTPUT_CHANNELS = 3 generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm') generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm') discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False) discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

有了生成器和鉴别器,可以开始设置损失。由于CycleGAN是非成对图像到图像的翻译,因此在训练网络时不需要成对的数据。因此,没有人可以保证输入和目标图像在训练期间构成有意义的配对。这就是为什么计算循环一致性损失以确保网络正确映射是很重要的:

LAMBDA = 10 loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True) def discriminator_loss(real, generated): real_loss = loss_obj(tf.ones_like(real), real) generated_loss = loss_obj(tf.zeros_like(generated), generated) total_disc_loss = real_loss + generated_loss return total_disc_loss * 0.5 def generator_loss(generated): return loss_obj(tf.ones_like(generated), generated)

现在,计算循环一致性损失,以确保翻译结果接近原始图像:

def calc_cycle_loss(real_image, cycled_image): loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image)) return LAMBDA * loss1 def identity_loss(real_image, same_image): loss = tf.reduce_mean(tf.abs(real_image - same_image)) return LAMBDA * 0.5 * loss generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5) generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5) discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5) discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485