{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1CUZ0dkOo_F" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "qmkj-80IHxnd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "_xnMOsbqHz61" }, "source": [ "# CycleGAN" ] }, { "cell_type": "markdown", "metadata": { "id": "Ds4o1h4WHz9U" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "ITZuApL56Mny" }, "source": [ "This notebook demonstrates unpaired image-to-image translation using conditional GANs, as described in [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593), also known as CycleGAN. The paper proposes a method that can capture the characteristics of one image domain and figure out how to translate them into another image domain, without any paired training examples.\n", "\n", "This notebook assumes you are familiar with Pix2Pix, which you can learn about in the [Pix2Pix tutorial](https://tensorflow.google.cn/tutorials/generative/pix2pix). The code for CycleGAN is similar; the main differences are an additional loss function and the use of unpaired training data.\n", "\n", "CycleGAN uses a cycle consistency loss to enable training without paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domains.\n", "\n", "This opens up the possibility of many interesting tasks, such as photo enhancement, image colorization, and style transfer. All you need is a source dataset and a target dataset (each is simply a directory of images).\n", "\n", "![Output image 1](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/horse2zebra_1.png?raw=1) ![Output image 2](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/horse2zebra_2.png?raw=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "e1_Y75QXJS6h" }, "source": [ "## Set up the input pipeline" ] }, { "cell_type": "markdown", "metadata": { "id": "5fGHWOKPX4ta" }, "source": [ "Install the [tensorflow_examples](https://github.com/tensorflow/examples) package, which provides the generator and the discriminator." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bJ1ROiQxJ-vY", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!pip install git+https://github.com/tensorflow/examples.git" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lhSsUx9Nyb3t", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YfIk2es3hJEd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow_datasets as tfds\n", "from tensorflow_examples.models.pix2pix import pix2pix\n", "\n", "import os\n", "import time\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output\n", "\n", "AUTOTUNE = tf.data.AUTOTUNE" ] }, { "cell_type": "markdown", "metadata": { "id": "iYn4MdZnKCey" }, "source": [ "## Input pipeline\n", "\n", 
"This tutorial trains a model to translate images of horses into images of zebras. You can find this dataset and similar ones [here](https://tensorflow.google.cn/datasets/datasets#cycle_gan).\n", "\n", "As described in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.\n", "\n", "This is similar to what is done in [pix2pix](https://tensorflow.google.cn/tutorials/generative/pix2pix#load_the_dataset).\n", "\n", "- In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.\n", "- In random mirroring, the image is randomly flipped horizontally (that is, left to right)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iuGVPOo7Cce0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dataset, metadata = tfds.load('cycle_gan/horse2zebra',\n", "                              with_info=True, as_supervised=True)\n", "\n", "train_horses, train_zebras = dataset['trainA'], dataset['trainB']\n", "test_horses, test_zebras = dataset['testA'], dataset['testB']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2CbTEt448b4R", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "BUFFER_SIZE = 1000\n", "BATCH_SIZE = 1\n", "IMG_WIDTH = 256\n", "IMG_HEIGHT = 256" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yn3IwqhiIszt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def random_crop(image):\n", "  cropped_image = tf.image.random_crop(\n", "      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])\n", "\n", "  return cropped_image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "muhR2cgbLKWW", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# normalizing the images to [-1, 1]\n", "def normalize(image):\n", "  image = tf.cast(image, tf.float32)\n", "  image = (image / 127.5) - 1\n", "  return image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fVQOjcPVLrUc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def random_jitter(image):\n", "  # resizing to 286 x 286 x 3\n", "  image = tf.image.resize(image, [286, 286],\n", "                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", "\n", "  # randomly cropping to 256 x 256 x 3\n", "  image = random_crop(image)\n", "\n", "  # random mirroring\n", "  image = tf.image.random_flip_left_right(image)\n", "\n", "  return image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tyaP4hLJ8b4W", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def preprocess_image_train(image, label):\n", "  image = random_jitter(image)\n", "  image = normalize(image)\n", "  return image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VB3Z6D_zKSru", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def preprocess_image_test(image, label):\n", "  image = normalize(image)\n", "  return image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RsajGXxd5JkZ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_horses = train_horses.cache().map(\n", "    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(\n", "    BUFFER_SIZE).batch(BATCH_SIZE)\n", "\n", "train_zebras = train_zebras.cache().map(\n", "    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(\n", "    BUFFER_SIZE).batch(BATCH_SIZE)\n", "\n", "test_horses = test_horses.map(\n", "    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n", "    BUFFER_SIZE).batch(BATCH_SIZE)\n", "\n", "test_zebras = test_zebras.map(\n", "    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n", "    BUFFER_SIZE).batch(BATCH_SIZE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e3MhJ3zVLPan", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "sample_horse = next(iter(train_horses))\n", "sample_zebra = next(iter(train_zebras))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4pOYjMk_KfIB", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.subplot(121)\n", "plt.title('Horse')\n", "plt.imshow(sample_horse[0] * 0.5 + 0.5)\n", "\n", "plt.subplot(122)\n", 
"plt.title('Horse with random jitter')\n", "plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0KJyB9ENLb2y", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.subplot(121)\n", "plt.title('Zebra')\n", "plt.imshow(sample_zebra[0] * 0.5 + 0.5)\n", "\n", "plt.subplot(122)\n", "plt.title('Zebra with random jitter')\n", "plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)" ] }, { "cell_type": "markdown", "metadata": { "id": "hvX8sKsfMaio" }, "source": [ "## Import and reuse the Pix2Pix models" ] }, { "cell_type": "markdown", "metadata": { "id": "cGrL73uCd-_M" }, "source": [ "Import the generator and the discriminator used in [Pix2Pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) via the installed [tensorflow_examples](https://github.com/tensorflow/examples) package.\n", "\n", "The model architecture used in this tutorial is very similar to the one used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). Some of the differences are:\n", "\n", "- CycleGAN uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).\n", "- The [CycleGAN paper](https://arxiv.org/abs/1703.10593) uses a modified `resnet`-based generator. For simplicity, this tutorial uses a modified `unet` generator.\n", "\n", "Two generators (`G` and `F`) and two discriminators (`D_X` and `D_Y`) are trained here.\n", "\n", "- Generator `G` learns to transform image `X` into image `Y`. $(G: X -> Y)$\n", "- Generator `F` learns to transform image `Y` into image `X`. $(F: Y -> X)$\n", "- Discriminator `D_X` learns to differentiate between real images `X` and generated images `X` (`F(Y)`).\n", "- Discriminator `D_Y` learns to differentiate between real images `Y` and generated images `Y` (`G(X)`).\n", "\n", "![CycleGAN model](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cyclegan_model.png?raw=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8ju9Wyw87MRW", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "OUTPUT_CHANNELS = 3\n", "\n", "generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')\n", "generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, 
norm_type='instancenorm')\n", "\n", "discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)\n", "discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wDaGZ3WpZUyw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "to_zebra = generator_g(sample_horse)\n", "to_horse = generator_f(sample_zebra)\n", "plt.figure(figsize=(8, 8))\n", "contrast = 8\n", "\n", "imgs = [sample_horse, to_zebra, sample_zebra, to_horse]\n", "title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']\n", "\n", "for i in range(len(imgs)):\n", "  plt.subplot(2, 2, i+1)\n", "  plt.title(title[i])\n", "  if i % 2 == 0:\n", "    plt.imshow(imgs[i][0] * 0.5 + 0.5)\n", "  else:\n", "    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O5MhJmxyZiy9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.figure(figsize=(8, 8))\n", "\n", "plt.subplot(121)\n", "plt.title('Is a real zebra?')\n", "plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')\n", "\n", "plt.subplot(122)\n", "plt.title('Is a real horse?')\n", "plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "0FMYgY_mPfTi" }, "source": [ "## Loss functions" ] }, { "cell_type": "markdown", "metadata": { "id": "JRqt02lupRn8" }, "source": [ "In CycleGAN, there is no paired data to train on, so there is no guarantee that the input `x` and the target `y` form a meaningful pair during training. To force the network to learn the correct mapping, the authors propose the cycle consistency loss.\n", "\n", "The discriminator loss and the generator loss are similar to the ones used in [pix2pix](https://tensorflow.google.cn/tutorials/generative/pix2pix#define_the_loss_functions_and_the_optimizer)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cyhxTuvJyIHV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "LAMBDA = 10" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Q1Xbz5OaLj5C", 
"vscode": { "languageId": "python" } }, "outputs": [], "source": [ "loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wkMNfBWlT-PV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def discriminator_loss(real, generated):\n", "  real_loss = loss_obj(tf.ones_like(real), real)\n", "\n", "  generated_loss = loss_obj(tf.zeros_like(generated), generated)\n", "\n", "  total_disc_loss = real_loss + generated_loss\n", "\n", "  return total_disc_loss * 0.5" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "90BIcCKcDMxz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def generator_loss(generated):\n", "  return loss_obj(tf.ones_like(generated), generated)" ] }, { "cell_type": "markdown", "metadata": { "id": "5iIWQzVF7f9e" }, "source": [ "Cycle consistency means the result should be close to the original input. For example, if you translate a sentence from English to French and then translate it back from French to English, the resulting sentence should be the same as the original sentence.\n", "\n", "In cycle consistency loss,\n", "\n", "- Image $X$ is passed through generator $G$, which yields the generated image $\\hat{Y}$.\n", "- The generated image $\\hat{Y}$ is passed through generator $F$, which yields the cycled image $\\hat{X}$.\n", "- The mean absolute error is calculated between $X$ and $\\hat{X}$.\n", "\n", "$$forward\\ cycle\\ consistency\\ loss: X -> G(X) -> F(G(X)) \\sim \\hat{X}$$\n", "\n", "$$backward\\ cycle\\ consistency\\ loss: Y -> F(Y) -> G(F(Y)) \\sim \\hat{Y}$$\n", "\n", "![Cycle loss](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cycle_loss.png?raw=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMpVGj_sW6Vo", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def calc_cycle_loss(real_image, cycled_image):\n", "  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))\n", "\n", "  return LAMBDA * loss1" ] }, { "cell_type": "markdown", "metadata": { "id": "U-tJL-fX0Mq7" }, "source": [ "As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that if you feed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to it.\n", "\n", 
"If you run the zebra-to-horse model on a horse, or the horse-to-zebra model on a zebra, it should not modify the image much, since the image already contains the target class.\n", "\n", "$$Identity\\ loss = |G(Y) - Y| + |F(X) - X|$$" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "05ywEH680Aud", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def identity_loss(real_image, same_image):\n", "  loss = tf.reduce_mean(tf.abs(real_image - same_image))\n", "  return LAMBDA * 0.5 * loss" ] }, { "cell_type": "markdown", "metadata": { "id": "G-vjRM7IffTT" }, "source": [ "Initialize the optimizers for all the generators and the discriminators." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWCn_PVdEJZ7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", "generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", "\n", "discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", "discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)" ] }, { "cell_type": "markdown", "metadata": { "id": "aKUZnDiqQrAh" }, "source": [ "## Checkpoints" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WJnftd5sQsv6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "checkpoint_path = \"./checkpoints/train\"\n", "\n", "ckpt = tf.train.Checkpoint(generator_g=generator_g,\n", "                           generator_f=generator_f,\n", "                           discriminator_x=discriminator_x,\n", "                           discriminator_y=discriminator_y,\n", "                           generator_g_optimizer=generator_g_optimizer,\n", "                           generator_f_optimizer=generator_f_optimizer,\n", "                           discriminator_x_optimizer=discriminator_x_optimizer,\n", "                           discriminator_y_optimizer=discriminator_y_optimizer)\n", "\n", "ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)\n", "\n", "# if a checkpoint exists, restore the latest checkpoint.\n", "if ckpt_manager.latest_checkpoint:\n", "  ckpt.restore(ckpt_manager.latest_checkpoint)\n", "  print ('Latest checkpoint restored!!')" ] }, { "cell_type": 
"markdown", "metadata": { "id": "Rw1fkAczTQYh" }, "source": [ "## Training\n", "\n", "Note: This example model is trained for fewer epochs (10) than in the paper (200) to keep the training time reasonable for this tutorial. The generated images will be of much lower quality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NS2GWywBbAWo", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "EPOCHS = 10" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RmdVsmvhPxyy", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def generate_images(model, test_input):\n", "  prediction = model(test_input)\n", "\n", "  plt.figure(figsize=(12, 12))\n", "\n", "  display_list = [test_input[0], prediction[0]]\n", "  title = ['Input Image', 'Predicted Image']\n", "\n", "  for i in range(2):\n", "    plt.subplot(1, 2, i+1)\n", "    plt.title(title[i])\n", "    # getting the pixel values between [0, 1] to plot it.\n", "    plt.imshow(display_list[i] * 0.5 + 0.5)\n", "    plt.axis('off')\n", "  plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "kE47ERn5fyLC" }, "source": [ "Even though the training loop looks complicated, it consists of four basic steps:\n", "\n", "- Get the predictions.\n", "- Calculate the loss.\n", "- Calculate the gradients using backpropagation.\n", "- Apply the gradients using the optimizer." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KBKUV2sKXDbY", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "@tf.function\n", "def train_step(real_x, real_y):\n", "  # persistent is set to True because the tape is used more than\n", "  # once to calculate the gradients.\n", "  with tf.GradientTape(persistent=True) as tape:\n", "    # Generator G translates X -> Y\n", "    # Generator F translates Y -> X.\n", "\n", "    fake_y = generator_g(real_x, training=True)\n", "    cycled_x = generator_f(fake_y, training=True)\n", "\n", "    fake_x = generator_f(real_y, training=True)\n", "    cycled_y = generator_g(fake_x, training=True)\n", "\n", "    # same_x and same_y are used for identity loss.\n", "    same_x = generator_f(real_x, training=True)\n", "    same_y = generator_g(real_y, training=True)\n", "\n", "    disc_real_x = discriminator_x(real_x, training=True)\n", "    disc_real_y = discriminator_y(real_y, training=True)\n", "\n", "    disc_fake_x = discriminator_x(fake_x, training=True)\n", "    disc_fake_y = discriminator_y(fake_y, training=True)\n", "\n", "    # calculate the loss\n", "    gen_g_loss = generator_loss(disc_fake_y)\n", "    gen_f_loss = generator_loss(disc_fake_x)\n", "\n", "    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)\n", "\n", "    # Total generator loss = adversarial loss + cycle loss + identity loss\n", "    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)\n", "    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)\n", "\n", "    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)\n", "    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)\n", "\n", "  # Calculate the gradients for generator and discriminator\n", "  generator_g_gradients = tape.gradient(total_gen_g_loss,\n", "                                        generator_g.trainable_variables)\n", "  generator_f_gradients = tape.gradient(total_gen_f_loss,\n", "                                        generator_f.trainable_variables)\n", "\n", "  discriminator_x_gradients = tape.gradient(disc_x_loss,\n", "                                            discriminator_x.trainable_variables)\n", "  discriminator_y_gradients = tape.gradient(disc_y_loss,\n", "                                            discriminator_y.trainable_variables)\n", "\n", "  # Apply the gradients using the optimizers\n", "  generator_g_optimizer.apply_gradients(zip(generator_g_gradients,\n", "                                            generator_g.trainable_variables))\n", "\n", "  generator_f_optimizer.apply_gradients(zip(generator_f_gradients,\n", "                                            generator_f.trainable_variables))\n", "\n", "  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,\n", "                                                discriminator_x.trainable_variables))\n", "\n", "  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,\n", "                                                discriminator_y.trainable_variables))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2M7LmLtGEMQJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for epoch in 
range(EPOCHS):\n", "  start = time.time()\n", "\n", "  n = 0\n", "  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):\n", "    train_step(image_x, image_y)\n", "    if n % 10 == 0:\n", "      print ('.', end='')\n", "    n += 1\n", "\n", "  clear_output(wait=True)\n", "  # Using a consistent image (sample_horse) so that the progress of the model\n", "  # is clearly visible.\n", "  generate_images(generator_g, sample_horse)\n", "\n", "  if (epoch + 1) % 5 == 0:\n", "    ckpt_save_path = ckpt_manager.save()\n", "    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,\n", "                                                         ckpt_save_path))\n", "\n", "  print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", "                                                      time.time()-start))" ] }, { "cell_type": "markdown", "metadata": { "id": "1RGysMU_BZhx" }, "source": [ "## Generate using the test dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KUgSnmy2nqSP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Run the trained model on the test dataset\n", "for inp in test_horses.take(5):\n", "  generate_images(generator_g, inp)" ] }, { "cell_type": "markdown", "metadata": { "id": "ABGiHY6fE02b" }, "source": [ "## Next steps\n", "\n", "This tutorial has shown how to implement CycleGAN starting from the generator and discriminator implemented in the [Pix2Pix](https://tensorflow.google.cn/tutorials/generative/pix2pix) tutorial. As a next step, you could try a different dataset from [TensorFlow Datasets](https://tensorflow.google.cn/datasets/datasets#cycle_gan).\n", "\n", "You could also train for more epochs to improve the results, or implement the modified ResNet generator used in the [paper](https://arxiv.org/abs/1703.10593) instead of the U-Net generator used here." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "cyclegan.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }