{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1CUZ0dkOo_F" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "qmkj-80IHxnd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "_xnMOsbqHz61" }, "source": [ "# pix2pix:使用条件 GAN 进行图像到图像的转换" ] }, { "cell_type": "markdown", "metadata": { "id": "Ds4o1h4WHz9U" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "ITZuApL56Mny" }, "source": [ "本教程演示了如何构建和训练一个名为 pix2pix 的条件生成对抗网络 (cGAN),该网络学习从输入图像到输出图像的映射,如 Isola 等人在 [Image-to-image translation with conditional adversarial networks](https://arxiv.org/abs/1611.07004){:.external} (2017 年)中所述 。pix2pix 非特定于应用,它可以应用于多种任务,包括从标签地图合成照片,从黑白图像生成彩色照片,将 Google Maps 照片转换为航拍图像,甚至将草图转换为照片。\n", "\n", "在此示例中,您的网络将使用[布拉格捷克理工大学](http://cmp.felk.cvut.cz/~tylecr1/facade/){:.external}的[机器感知中心](http://cmp.felk.cvut.cz/){:.external}提供的 [CMP Facade Database](https://www.cvut.cz/) 来生成建筑立面。为了简化示例,您将使用由 pix2pix 作者创建的此数据集的[预处理副本](https://efrosgans.eecs.berkeley.edu/pix2pix/datasets/){:.external}。\n", "\n", "在 pix2pix cGAN 中,您可以对输入图像进行调节并生成相应的输出图像。cGAN 最初在 [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784) (Mirza and Osindero, 2014) 中提出。\n", "\n", "您的网络架构将包含:\n", "\n", "- 基于 [U-Net](https://arxiv.org/abs/1505.04597){:.external} 架构的生成器。\n", "- 由卷积 PatchGAN 分类器表示的判别器(在 [pix2pix 论文](https://arxiv.org/abs/1611.07004){:.external}中提出)。\n", "\n", "请注意,在单个 V100 GPU 上,每个周期可能需要大约 15 秒。\n", "\n", "以下是 pix2pix cGAN 在 Facade Database(8 万步)上训练 200 个周期后生成的一些输出示例。\n", "\n", "![sample output_1](https://tensorflow.google.cn/images/gan/pix2pix_1.png) ![sample output_2](https://tensorflow.google.cn/images/gan/pix2pix_2.png)" ] }, { "cell_type": "markdown", "metadata": { "id": "e1_Y75QXJS6h" }, "source": [ "## 导入 TensorFlow 和其他库" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YfIk2es3hJEd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import os\n", "import pathlib\n", "import time\n", "import datetime\n", "\n", "from matplotlib import pyplot as plt\n", "from IPython import display" ] }, { "cell_type": "markdown", "metadata": { "id": "iYn4MdZnKCey" }, "source": [ "## 加载数据集\n", "\n", "下载 CMP Facade Database 数据 (30MB)。可在[这里](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/){:.external}以相同格式获得其他数据集。在 Colab 中,您可以从下拉菜单中选择其他数据集。请注意,其他一些数据集要大得多(`edges2handbags` 为 8GB)。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qp6IAZvEShNf", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dataset_name = \"facades\" #@param [\"cityscapes\", \"edges2handbags\", \"edges2shoes\", \"facades\", \"maps\", \"night2day\"]\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Kn-k8kTXuAlv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "_URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'\n", "\n", "path_to_zip = tf.keras.utils.get_file(\n", " fname=f\"{dataset_name}.tar.gz\",\n", " origin=_URL,\n", " extract=True)\n", "\n", "path_to_zip = pathlib.Path(path_to_zip)\n", "\n", "PATH = path_to_zip.parent/dataset_name" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V67lt3BFb2iN", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "list(PATH.parent.iterdir())" ] }, { "cell_type": "markdown", "metadata": { "id": "1fUzsnerj1P3" }, "source": [ "每个原始图像的大小为 `256 x 512`,包含两个 `256 x 256` 图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XGY1kiptguTQ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "sample_image = tf.io.read_file(str(PATH / 'train/1.jpg'))\n", "sample_image = tf.io.decode_jpeg(sample_image)\n", "print(sample_image.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vJ2sO8Izg7QV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.figure()\n", "plt.imshow(sample_image)" ] }, { "cell_type": "markdown", "metadata": { "id": "2A5SU-qxPAqd" }, "source": [ "您需要将真实的建筑立面图像与建筑标签图像分开,所有这些图像的大小都是 `256 x 256`。\n", "\n", "定义加载图像文件并输出两个图像张量的函数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aO9ZAGH5K3SY", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def load(image_file):\n", " # Read and decode an image file to a uint8 tensor\n", " image = tf.io.read_file(image_file)\n", " image = tf.io.decode_jpeg(image)\n", "\n", " # Split each image tensor into two tensors:\n", " # - one with a real building facade image\n", " # - one with an architecture label image \n", " w = tf.shape(image)[1]\n", " w = w // 2\n", " input_image = image[:, w:, :]\n", " real_image = image[:, :w, :]\n", "\n", " # Convert both images to float32 tensors\n", " input_image = tf.cast(input_image, tf.float32)\n", " real_image = tf.cast(real_image, tf.float32)\n", "\n", " return input_image, real_image" ] }, { "cell_type": "markdown", "metadata": { "id": "r5ByHTlfE06P" }, "source": [ "绘制输入图像(建筑标签图像)和真实(建筑立面照片)图像的样本:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4OLHMpsQ5aOv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "inp, re = load(str(PATH / 'train/100.jpg'))\n", "# Casting to int for matplotlib to display the images\n", "plt.figure()\n", "plt.imshow(inp / 255.0)\n", "plt.figure()\n", "plt.imshow(re / 255.0)" ] }, { "cell_type": "markdown", "metadata": { "id": "PVuZQTfI_c-s" }, "source": [ "如 [pix2pix 论文](https://arxiv.org/abs/1611.07004){:.external}中所述,您需要应用随机抖动和镜像来预处理训练集。\n", "\n", "定义几个具有以下功能的函数:\n", "\n", "1. 将每个 `256 x 256` 图像调整为更大的高度和宽度,`286 x 286`。\n", "2. 将其随机裁剪回 `256 x 256`。\n", "3. 随机水平翻转图像,即从左到右(随机镜像)。\n", "4. 将图像归一化到 `[-1, 1]` 范围。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2CbTEt448b4R", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# The facade training set consist of 400 images\n", "BUFFER_SIZE = 400\n", "# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment\n", "BATCH_SIZE = 1\n", "# Each image is 256x256 in size\n", "IMG_WIDTH = 256\n", "IMG_HEIGHT = 256" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rwwYQpu9FzDu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def resize(input_image, real_image, height, width):\n", " input_image = tf.image.resize(input_image, [height, width],\n", " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", " real_image = tf.image.resize(real_image, [height, width],\n", " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", "\n", " return input_image, real_image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yn3IwqhiIszt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def random_crop(input_image, real_image):\n", " stacked_image = tf.stack([input_image, real_image], axis=0)\n", " cropped_image = tf.image.random_crop(\n", " stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n", "\n", " return cropped_image[0], cropped_image[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "muhR2cgbLKWW", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Normalizing the images to [-1, 1]\n", "def normalize(input_image, real_image):\n", " input_image = (input_image / 127.5) - 1\n", " real_image = (real_image / 127.5) - 1\n", "\n", " return input_image, real_image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fVQOjcPVLrUc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "@tf.function()\n", "def random_jitter(input_image, real_image):\n", " # Resizing to 286x286\n", " input_image, real_image = resize(input_image, real_image, 286, 286)\n", "\n", " # Random cropping back to 256x256\n", " input_image, real_image = random_crop(input_image, real_image)\n", "\n", " if tf.random.uniform(()) > 0.5:\n", " # Random mirroring\n", " input_image = tf.image.flip_left_right(input_image)\n", " real_image = tf.image.flip_left_right(real_image)\n", "\n", " return input_image, real_image" ] }, { "cell_type": "markdown", "metadata": { "id": "wfAQbzy799UV" }, "source": [ "您可以检查部分预处理输出:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "n0OGdi6D92kM", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.figure(figsize=(6, 6))\n", "for i in range(4):\n", " rj_inp, rj_re = random_jitter(inp, re)\n", " plt.subplot(2, 2, i + 1)\n", " plt.imshow(rj_inp / 255.0)\n", " plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "3E9LGq3WBmsh" }, "source": [ "检查加载和预处理能够正常工作后,我们来定义两个辅助函数来加载和预处理训练集和测试集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tyaP4hLJ8b4W", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def load_image_train(image_file):\n", " input_image, real_image = load(image_file)\n", " input_image, real_image = random_jitter(input_image, real_image)\n", " input_image, real_image = normalize(input_image, real_image)\n", "\n", " return input_image, real_image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VB3Z6D_zKSru", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def load_image_test(image_file):\n", " input_image, real_image = load(image_file)\n", " input_image, real_image = resize(input_image, real_image,\n", " IMG_HEIGHT, IMG_WIDTH)\n", " input_image, real_image = normalize(input_image, real_image)\n", "\n", " return input_image, real_image" ] }, { "cell_type": "markdown", "metadata": { "id": "PIGN6ouoQxt3" }, "source": [ "## 使用 `tf.data` 构建输入流水线" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SQHmYSmk8b4b", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_dataset = tf.data.Dataset.list_files(str(PATH / 'train/*.jpg'))\n", "train_dataset = train_dataset.map(load_image_train,\n", " num_parallel_calls=tf.data.AUTOTUNE)\n", "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", "train_dataset = train_dataset.batch(BATCH_SIZE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MS9J0yA58b4g", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "try:\n", " test_dataset = tf.data.Dataset.list_files(str(PATH / 'test/*.jpg'))\n", "except tf.errors.InvalidArgumentError:\n", " test_dataset = tf.data.Dataset.list_files(str(PATH / 'val/*.jpg'))\n", "test_dataset = test_dataset.map(load_image_test)\n", "test_dataset = test_dataset.batch(BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": { "id": "THY-sZMiQ4UV" }, "source": [ "## 构建生成器\n", "\n", "您的 pix2pix cGAN 是*经过修改的* [U-Net](https://arxiv.org/abs/1505.04597){:.external}。U-Net 由编码器(下采样器)和解码器(上采样器)。(有关详细信息,请参阅[图像分割](../images/segmentation.ipynb)教程和 [U-Net 项目网站](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/){:.external}。)\n", "\n", "- 编码器中的每个块为:Convolution -> Batch normalization -> Leaky ReLU\n", "- 解码器中的每个块为:Transposed convolution -> Batch normalization -> Dropout(应用于前三个块)-> ReLU\n", "- 编码器和解码器之间存在跳跃连接(如在 U-Net 中)。" ] }, { "cell_type": "markdown", "metadata": { "id": "4MQPuBCgtldI" }, "source": [ "定义下采样器(编码器):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tqqvWxlw8b4l", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "OUTPUT_CHANNELS = 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3R09ATE_SH9P", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def downsample(filters, size, apply_batchnorm=True):\n", " initializer = tf.random_normal_initializer(0., 0.02)\n", "\n", " result = tf.keras.Sequential()\n", " result.add(\n", " tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n", " kernel_initializer=initializer, use_bias=False))\n", "\n", " if apply_batchnorm:\n", " result.add(tf.keras.layers.BatchNormalization())\n", "\n", " result.add(tf.keras.layers.LeakyReLU())\n", "\n", " return result" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a6_uCZCppTh7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "down_model = downsample(3, 4)\n", "down_result = down_model(tf.expand_dims(inp, 0))\n", "print (down_result.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "aFI_Pa52tjLl" }, "source": [ "定义上采样器(解码器):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nhgDsHClSQzP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def upsample(filters, size, apply_dropout=False):\n", " initializer = tf.random_normal_initializer(0., 0.02)\n", "\n", " result = tf.keras.Sequential()\n", " result.add(\n", " tf.keras.layers.Conv2DTranspose(filters, size, strides=2,\n", " padding='same',\n", " kernel_initializer=initializer,\n", " use_bias=False))\n", "\n", " result.add(tf.keras.layers.BatchNormalization())\n", "\n", " if apply_dropout:\n", " result.add(tf.keras.layers.Dropout(0.5))\n", "\n", " result.add(tf.keras.layers.ReLU())\n", "\n", " return result" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mz-ahSdsq0Oc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "up_model = upsample(3, 4)\n", "up_result = up_model(down_result)\n", "print (up_result.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "ueEJyRVrtZ-p" }, "source": [ "使用下采样器和上采样器定义生成器:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lFPI4Nu-8b4q", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def Generator():\n", " inputs = tf.keras.layers.Input(shape=[256, 256, 3])\n", "\n", " down_stack = [\n", " downsample(64, 4, apply_batchnorm=False), # (batch_size, 128, 128, 64)\n", " downsample(128, 4), # (batch_size, 64, 64, 128)\n", " downsample(256, 4), # (batch_size, 32, 32, 256)\n", " downsample(512, 4), # (batch_size, 16, 16, 512)\n", " downsample(512, 4), # (batch_size, 8, 8, 512)\n", " downsample(512, 4), # (batch_size, 4, 4, 512)\n", " downsample(512, 4), # (batch_size, 2, 2, 512)\n", " downsample(512, 4), # (batch_size, 1, 1, 512)\n", " ]\n", "\n", " up_stack = [\n", " upsample(512, 4, apply_dropout=True), # (batch_size, 2, 2, 1024)\n", " upsample(512, 4, apply_dropout=True), # (batch_size, 4, 4, 1024)\n", " upsample(512, 4, apply_dropout=True), # (batch_size, 8, 8, 1024)\n", " upsample(512, 4), # (batch_size, 16, 16, 1024)\n", " upsample(256, 4), # (batch_size, 32, 32, 512)\n", " upsample(128, 4), # (batch_size, 64, 64, 256)\n", " upsample(64, 4), # (batch_size, 128, 128, 128)\n", " ]\n", "\n", " initializer = tf.random_normal_initializer(0., 0.02)\n", " last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,\n", " strides=2,\n", " padding='same',\n", " kernel_initializer=initializer,\n", " activation='tanh') # (batch_size, 256, 256, 3)\n", "\n", " x = inputs\n", "\n", " # Downsampling through the model\n", " skips = []\n", " for down in down_stack:\n", " x = down(x)\n", " skips.append(x)\n", "\n", " skips = reversed(skips[:-1])\n", "\n", " # Upsampling and establishing the skip connections\n", " for up, skip in zip(up_stack, skips):\n", " x = up(x)\n", " x = tf.keras.layers.Concatenate()([x, skip])\n", "\n", " x = last(x)\n", "\n", " return tf.keras.Model(inputs=inputs, outputs=x)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z4PKwrcQFYvF" }, "source": [ "可视化生成器模型架构:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dIbRPFzjmV85", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "generator = Generator()\n", "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z8kbgTK8FcPo" }, "source": [ "测试生成器:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U1N1_obwtdQH", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "gen_output = generator(inp[tf.newaxis, ...], training=False)\n", "plt.imshow(gen_output[0, ...])" ] }, { "cell_type": "markdown", "metadata": { "id": "dpDPEQXIAiQO" }, "source": [ "### 定义生成器损失\n", "\n", "GAN 学习适应数据的损失,而 cGAN 学习结构化损失,该损失会惩罚与网络输出和目标图像不同的可能结构,如 [pix2pix 论文](https://arxiv.org/abs/1611.07004){:.external}中所述。\n", "\n", "- 生成器损失是生成图像和**一数组**的 sigmoid 交叉熵损失。\n", "- 论文还提到了 L1 损失,它是生成图像与目标图像之间的 MAE(平均绝对误差)。\n", "- 这样可使生成的图像在结构上与目标图像相似。\n", "- 计算总生成器损失的公式为:`gan_loss + LAMBDA * l1_loss`,其中 `LAMBDA = 100`。该值由论文作者决定。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cyhxTuvJyIHV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "LAMBDA = 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Q1Xbz5OaLj5C", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "90BIcCKcDMxz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def generator_loss(disc_generated_output, gen_output, target):\n", " gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)\n", "\n", " # Mean absolute error\n", " l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", "\n", " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n", "\n", " return total_gen_loss, gan_loss, l1_loss" ] }, { "cell_type": "markdown", "metadata": { "id": "fSZbDgESHIV6" }, "source": [ "生成器的训练过程如下:" ] }, { "cell_type": "markdown", "metadata": { "id": "TlB-XMY5Awj9" }, "source": [ "![生成器更新图像](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ZTKZfoaoEF22" }, "source": [ "## 构建判别器\n", "\n", "pix2pix cGAN 中的判别器是一个卷积 PatchGAN 分类器,它会尝试对每个图像*分块*的真实与否进行分类,如 [pix2pix 论文](https://arxiv.org/abs/1611.07004){:.external}中所述。\n", "\n", "- 判别器中的每个块为:Convolution -> Batch normalization -> Leaky ReLU。\n", "- 最后一层之后的输出形状为 `(batch_size, 30, 30, 1)`。\n", "- 输出的每个 `30 x 30` 图像分块会对输入图像的 `70 x 70` 部分进行分类。\n", "- 判别器接收 2 个输入:\n", " - 输入图像和目标图像,应分类为真实图像。\n", " - 输入图像和生成图像(生成器的输出),应分类为伪图像。\n", " - 使用`tf.concat([inp, tar], axis=-1)` 将这 2 个输入连接在一起。" ] }, { "cell_type": "markdown", "metadata": { "id": "XIuTeGL5v45m" }, "source": [ "我们来定义判别器:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ll6aNeQx8b4v", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def Discriminator():\n", " initializer = tf.random_normal_initializer(0., 0.02)\n", "\n", " inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')\n", " tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')\n", "\n", " x = tf.keras.layers.concatenate([inp, tar]) # (batch_size, 256, 256, channels*2)\n", "\n", " down1 = downsample(64, 4, False)(x) # (batch_size, 128, 128, 64)\n", " down2 = downsample(128, 4)(down1) # (batch_size, 64, 64, 128)\n", " down3 = downsample(256, 4)(down2) # (batch_size, 32, 32, 256)\n", "\n", " zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (batch_size, 34, 34, 256)\n", " conv = tf.keras.layers.Conv2D(512, 4, strides=1,\n", " kernel_initializer=initializer,\n", " use_bias=False)(zero_pad1) # (batch_size, 31, 31, 512)\n", "\n", " batchnorm1 = tf.keras.layers.BatchNormalization()(conv)\n", "\n", " leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)\n", "\n", " zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (batch_size, 33, 33, 512)\n", "\n", " last = tf.keras.layers.Conv2D(1, 4, strides=1,\n", " kernel_initializer=initializer)(zero_pad2) # (batch_size, 30, 30, 1)\n", "\n", " return tf.keras.Model(inputs=[inp, tar], outputs=last)" ] }, { "cell_type": "markdown", "metadata": { "id": "HdV9yAbBHNkg" }, "source": [ "可视化判别器模型架构:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YHoUui4om-Ev", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "discriminator = Discriminator()\n", "tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)" ] }, { "cell_type": "markdown", "metadata": { "id": "ps7nIHigHYc7" }, "source": [ "测试判别器:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gDkA05NE6QMs", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)\n", "plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')\n", "plt.colorbar()" ] }, { "cell_type": "markdown", "metadata": { "id": "AOqg1dhUAWoD" }, "source": [ "### 定义判别器损失\n", "\n", "- `discriminator_loss` 函数接收 2 个输入:**真实图像**和**生成图像**。\n", "- `real_loss` 是**真实图像**和**一组 1的 sigmoid 的交叉熵损失(因为这些是真实图像)**。\n", "- `generated_loss` 是**生成图像**和**一组 0 的 sigmoid 交叉熵损失(因为这些是伪图像)**。\n", "- `total_loss` 是 `real_loss` 和 `generated_loss` 的总和" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wkMNfBWlT-PV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def discriminator_loss(disc_real_output, disc_generated_output):\n", " real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)\n", "\n", " generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)\n", "\n", " total_disc_loss = real_loss + generated_loss\n", "\n", " return total_disc_loss" ] }, { "cell_type": "markdown", "metadata": { "id": "-ede4p2YELFa" }, "source": [ "判别器的训练过程如下所示。\n", "\n", "要详细了解架构和超参数,请参阅 [pix2pix 论文](https://arxiv.org/abs/1611.07004){:.external}。" ] }, { "cell_type": "markdown", "metadata": { "id": "IS9sHa-1BoAF" }, "source": [ "![判别器更新图像](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/dis.png?raw=1)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "0FMYgY_mPfTi" }, "source": [ "## 定义优化器和检查点 saver\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lbHFNexF0x6O", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", "discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WJnftd5sQsv6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", " discriminator_optimizer=discriminator_optimizer,\n", " generator=generator,\n", " discriminator=discriminator)" ] }, { "cell_type": "markdown", "metadata": { "id": "Rw1fkAczTQYh" }, "source": [ "## 生成图像\n", "\n", "编写函数以在训练期间绘制一些图像。\n", "\n", "- 将图像从测试集传递到生成器。\n", "- 然后,生成器会将输入图像转换为输出。\n", "- 最后一步是绘制预测,*瞧*!" ] }, { "cell_type": "markdown", "metadata": { "id": "Rb0QQFHF-JfS" }, "source": [ "注:在这里,`training=True` 是有意的,因为在基于测试数据集运行模型时,您需要批次统计信息。如果您使用 training = False,将获得从训练数据集中学习的累积统计信息(您不需要)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RmdVsmvhPxyy", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def generate_images(model, test_input, tar):\n", " prediction = model(test_input, training=True)\n", " plt.figure(figsize=(15, 15))\n", "\n", " display_list = [test_input[0], tar[0], prediction[0]]\n", " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", "\n", " for i in range(3):\n", " plt.subplot(1, 3, i+1)\n", " plt.title(title[i])\n", " # Getting the pixel values in the [0, 1] range to plot.\n", " plt.imshow(display_list[i] * 0.5 + 0.5)\n", " plt.axis('off')\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "gipsSEoZIG1a" }, "source": [ "测试该函数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Fc4NzT-DgEx", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for example_input, example_target in test_dataset.take(1):\n", " generate_images(generator, example_input, example_target)" ] }, { "cell_type": "markdown", "metadata": { "id": "NLKOG55MErD0" }, "source": [ "## 训练\n", "\n", "- 为每个样本输入生成一个输出。\n", "- 判别器接收 input_image 和生成的图像作为第一个输入。第二个输入为 input_image 和 target_image。\n", "- 接下来,计算生成器和判别器损失。\n", "- 随后,计算损失相对于生成器和判别器变量(输入)的梯度,并将其应用于优化器。\n", "- 最后,将损失记录到 TensorBoard。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xNNMDBNH12q-", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "log_dir=\"logs/\"\n", "\n", "summary_writer = tf.summary.create_file_writer(\n", " log_dir + \"fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KBKUV2sKXDbY", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "@tf.function\n", "def train_step(input_image, target, step):\n", " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", " gen_output = generator(input_image, training=True)\n", "\n", " disc_real_output = discriminator([input_image, target], training=True)\n", " disc_generated_output = discriminator([input_image, gen_output], training=True)\n", "\n", " gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)\n", " disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n", "\n", " generator_gradients = gen_tape.gradient(gen_total_loss,\n", " generator.trainable_variables)\n", " discriminator_gradients = disc_tape.gradient(disc_loss,\n", " discriminator.trainable_variables)\n", "\n", " generator_optimizer.apply_gradients(zip(generator_gradients,\n", " generator.trainable_variables))\n", " discriminator_optimizer.apply_gradients(zip(discriminator_gradients,\n", " discriminator.trainable_variables))\n", "\n", " with summary_writer.as_default():\n", " tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)\n", " tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)\n", " tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)\n", " tf.summary.scalar('disc_loss', disc_loss, step=step//1000)" ] }, { "cell_type": "markdown", "metadata": { "id": "hx7s-vBHFKdh" }, "source": [ "实际的训练循环。由于本教程可以运行多个数据集,并且数据集的大小差异很大,因此将训练循环设置为按步骤而非按周期工作。\n", "\n", "- 迭代步骤数。\n", "- 每 10 步打印一个点 (`.`)。\n", "- 每 1 千步:清除显示并运行 `generate_images` 以显示进度。\n", "- 每 5 千步:保存一个检查点。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GFyPlBWv1B5j", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def fit(train_ds, test_ds, steps):\n", " example_input, example_target = next(iter(test_ds.take(1)))\n", " start = time.time()\n", "\n", " for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():\n", " if (step) % 1000 == 0:\n", " display.clear_output(wait=True)\n", "\n", " if step != 0:\n", " print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\\n')\n", "\n", " start = time.time()\n", "\n", " generate_images(generator, example_input, example_target)\n", " print(f\"Step: {step//1000}k\")\n", "\n", " train_step(input_image, target, step)\n", "\n", " # Training step\n", " if (step+1) % 10 == 0:\n", " print('.', end='', flush=True)\n", "\n", "\n", " # Save (checkpoint) the model every 5k steps\n", " if (step + 1) % 5000 == 0:\n", " checkpoint.save(file_prefix=checkpoint_prefix)" ] }, { "cell_type": "markdown", "metadata": { "id": "wozqyTh2wmCu" }, "source": [ "此训练循环会保存日志,您可以在 TensorBoard 中查看这些日志以监控训练进度。\n", "\n", "如果您使用的是本地计算机,则需要启动一个单独的 TensorBoard 进程。在笔记本中工作时,请在开始训练之前启动查看器以使用 TensorBoard 进行监控。\n", "\n", "在笔记本中打开嵌入式 TensorBoard 查看器(抱歉,这不会在 tensorflow.org 上显示):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ot22ujrlLhOd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "%load_ext tensorboard\n", "%tensorboard --logdir {log_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "fyjixlMlBybN" }, "source": [ "您可以在 [TensorBoard.dev](https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw) 上查看此笔记本[先前运行的结果](https://tensorboard.dev/)。" ] }, { "cell_type": "markdown", "metadata": { "id": "Pe0-8Bzg22ox" }, "source": [ "最后,运行训练循环:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a1zZmKmvOH85", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "fit(train_dataset, test_dataset, steps=40000)" ] }, { "cell_type": "markdown", "metadata": { "id": "DMTm4peo3cem" }, "source": [ "与简单的分类或回归模型相比,在训练 GAN(或像 pix2pix 这样的 cGAN)时,对日志的解释更加微妙。要检查的内容包括:\n", "\n", "- 检查生成器模型或判别器模型均未“获胜”。如果 `gen_gan_loss` 或 `disc_loss` 变得很低,则表明此模型正在支配另一个模型,并且您未能成功训练组合模型。\n", "- 值 `log(2) = 0.69` 是这些损失的一个良好参考点,因为它表示困惑度为 2:判别器对这两个选项的平均不确定性是相等的。\n", "- 对于 `disc_loss`,低于 `0.69` 的值意味着判别器在真实图像和生成图像的组合集上的表现要优于随机数。\n", "- 对于 `gen_gan_loss`,如果值小于 `0.69`,则表示生成器在欺骗判别器方面的表现要优于随机数。\n", "- 随着训练的进行,`gen_l1_loss` 应当下降。" ] }, { "cell_type": "markdown", "metadata": { "id": "kz80bY3aQ1VZ" }, "source": [ "## 恢复最新的检查点并测试网络" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HSSm4kfvJiqv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!ls {checkpoint_dir}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4t4x69adQ5xb", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Restoring the latest checkpoint in checkpoint_dir\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" ] }, { "cell_type": "markdown", "metadata": { "id": "1RGysMU_BZhx" }, "source": [ "## 使用测试集生成一些图像" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KUgSnmy2nqSP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Run the trained model on a few examples from the test set\n", "for inp, tar in test_dataset.take(5):\n", " generate_images(generator, inp, tar)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "pix2pix.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }