{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "cZCM65CBt1CJ" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "JOgMcEajtkmg", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "rCSP-dbMw88x" }, "source": [ "# 图像分割" ] }, { "cell_type": "markdown", "metadata": { "id": "NEWs8JXRuGex" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "sMP7mglMuGT2" }, "source": [ "这篇教程将重点讨论图像分割任务,使用的是改进版的 [U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/)。\n", "\n", "## 什么是图像分割?\n", "\n", "在图像分类任务中,网络会为每个输入图像分配一个标签(或类)。但是,如何了解该对象的形状、哪个像素属于哪个对象等信息呢?在这种情况下,您需要为图像的每个像素分配一个类。此任务称为分割。分割模型会返回有关图像的更详细信息。图像分割在医学成像、自动驾驶汽车和卫星成像等方面有很多应用。\n", "\n", "本教程使用 [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) ([Parkhi et al, 2012](https://www.robots.ox.ac.uk/~vgg/publications/2012/parkhi12a/parkhi12a.pdf))。该数据集由 37 个宠物品种的图像组成,每个品种有 200 个图像(训练拆分和测试拆分各有 100 个)。每个图像都包含相应的标签和像素级掩码。掩码是每个像素的类标签。每个像素都会被划入以下三个类别之一:\n", "\n", "- 第 1 类:属于宠物的像素。\n", "- 第 2 类:宠物边缘的像素。\n", "- 第 3 类:以上都不是/周围的像素。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MQmKthrSBCld", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!pip install git+https://github.com/tensorflow/examples.git" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQX7R4bhZy5h", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g87--n2AtyO_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "from tensorflow_examples.models.pix2pix import pix2pix\n", "\n", "from IPython.display import clear_output\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "oWe0_rQM4JbC" }, "source": [ "## 下载 Oxford-IIIT Pets 数据集\n", "\n", "该数据集可从 [TensorFlow Datasets](https://tensorflow.google.cn/datasets/catalog/oxford_iiit_pet) 获得。分割掩码包含在版本 3 以上的版本中。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "40ITeStwDwZb", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "rJcVdj_U4vzf" }, "source": [ "此外,图像颜色值被归一化到 `[0,1]` 范围。最后,如上所述,分割掩码中的像素被标记为 {1, 2, 3}。为方便起见,从分割掩码中减去 1,得到的标签为:{0, 1, 2}。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FD60EbcAQqov", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def normalize(input_image, input_mask):\n", " input_image = tf.cast(input_image, tf.float32) / 255.0\n", " input_mask -= 1\n", " return input_image, input_mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zf0S67hJRp3D", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def load_image(datapoint):\n", " input_image = tf.image.resize(datapoint['image'], (128, 128))\n", " input_mask = tf.image.resize(\n", " datapoint['segmentation_mask'],\n", " (128, 128),\n", " method = tf.image.ResizeMethod.NEAREST_NEIGHBOR,\n", " )\n", "\n", " input_image, input_mask = normalize(input_image, input_mask)\n", "\n", " return input_image, input_mask" ] }, { "cell_type": "markdown", "metadata": { "id": "65-qHTjX5VZh" }, "source": [ "数据集已包含所需的训练拆分和测试拆分,因此请继续使用相同的拆分。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yHwj2-8SaQli", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "TRAIN_LENGTH = info.splits['train'].num_examples\n", "BATCH_SIZE = 64\n", "BUFFER_SIZE = 1000\n", "STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "39fYScNz9lmo", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)\n", "test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "T9hGHyg8L3Y1" }, "source": [ "下面的类通过随机翻转图像来执行简单的增强。请转到[图像增强](data_augmentation.ipynb)教程以了解更多信息。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fUWdDJRTL0PP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class Augment(tf.keras.layers.Layer):\n", " def __init__(self, seed=42):\n", " super().__init__()\n", " # both use the same seed, so they'll make the same random changes.\n", " self.augment_inputs = tf.keras.layers.RandomFlip(mode=\"horizontal\", seed=seed)\n", " self.augment_labels = tf.keras.layers.RandomFlip(mode=\"horizontal\", seed=seed)\n", " \n", " def call(self, inputs, labels):\n", " inputs = self.augment_inputs(inputs)\n", " labels = self.augment_labels(labels)\n", " return inputs, labels" ] }, { "cell_type": "markdown", "metadata": { "id": "xTIbNIBdcgL3" }, "source": [ "构建输入流水线,在对输入进行批处理后应用增强:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VPscskQcNCx4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_batches = (\n", " train_images\n", " .cache()\n", " .shuffle(BUFFER_SIZE)\n", " .batch(BATCH_SIZE)\n", " .repeat()\n", " .map(Augment())\n", " .prefetch(buffer_size=tf.data.AUTOTUNE))\n", "\n", "test_batches = test_images.batch(BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": { "id": "Xa3gMAE_9qNa" }, "source": [ "呈现数据集中的图像样本及其对应的掩码:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3N2RPAAW9q4W", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def display(display_list):\n", " plt.figure(figsize=(15, 15))\n", "\n", " title = ['Input Image', 'True Mask', 'Predicted Mask']\n", "\n", " for i in range(len(display_list)):\n", " plt.subplot(1, len(display_list), i+1)\n", " plt.title(title[i])\n", " plt.imshow(tf.keras.utils.array_to_img(display_list[i]))\n", " plt.axis('off')\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a6u_Rblkteqb", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for images, masks in train_batches.take(2):\n", " sample_image, sample_mask = images[0], masks[0]\n", " display([sample_image, sample_mask])" ] }, { "cell_type": "markdown", "metadata": { "id": "FAOe93FRMk3w" }, "source": [ "## 定义模型\n", "\n", "这里使用的模型是修改后的 [U-Net](https://arxiv.org/abs/1505.04597)。U-Net 由编码器(下采样器)和解码器(上采样器)组成。为了学习稳健的特征并减少可训练参数的数量,请使用预训练模型 [MobileNetV2](https://arxiv.org/abs/1801.04381) 作为编码器。对于解码器,您将使用上采样块,该块已在 TensorFlow Examples 仓库的 [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) 示例中实现。(请查看笔记本中的 [pix2pix:使用条件 GAN 进行图像到图像转换](../generative/pix2pix.ipynb)教程。)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "W4mQle3lthit" }, "source": [ "如前所述,编码器是一个预训练的 MobileNetV2 模型。您将使用来自 `tf.keras.applications` 的模型。编码器由模型中中间层的特定输出组成。请注意,在训练过程中不会训练编码器。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "liCeLH0ctjq7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)\n", "\n", "# Use the activations of these layers\n", "layer_names = [\n", " 'block_1_expand_relu', # 64x64\n", " 'block_3_expand_relu', # 32x32\n", " 'block_6_expand_relu', # 16x16\n", " 'block_13_expand_relu', # 8x8\n", " 'block_16_project', # 4x4\n", "]\n", "base_model_outputs = [base_model.get_layer(name).output for name in layer_names]\n", "\n", "# Create the feature extraction model\n", "down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)\n", "\n", "down_stack.trainable = False" ] }, { "cell_type": "markdown", "metadata": { "id": "KPw8Lzra5_T9" }, "source": [ "解码器/上采样器只是在 TensorFlow 示例中实现的一系列上采样块:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p0ZbfywEbZpJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "up_stack = [\n", " pix2pix.upsample(512, 3), # 4x4 -> 8x8\n", " pix2pix.upsample(256, 3), # 8x8 -> 16x16\n", " pix2pix.upsample(128, 3), # 16x16 -> 32x32\n", " pix2pix.upsample(64, 3), # 32x32 -> 64x64\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "45HByxpVtrPF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def unet_model(output_channels:int):\n", " inputs = tf.keras.layers.Input(shape=[128, 128, 3])\n", "\n", " # Downsampling through the model\n", " skips = down_stack(inputs)\n", " x = skips[-1]\n", " skips = reversed(skips[:-1])\n", "\n", " # Upsampling and establishing the skip connections\n", " for up, skip in zip(up_stack, skips):\n", " x = up(x)\n", " concat = tf.keras.layers.Concatenate()\n", " x = concat([x, skip])\n", "\n", " # This is the last layer of the model\n", " last = tf.keras.layers.Conv2DTranspose(\n", " filters=output_channels, kernel_size=3, strides=2,\n", " padding='same') #64x64 -> 128x128\n", "\n", " x = last(x)\n", "\n", " return tf.keras.Model(inputs=inputs, outputs=x)" ] }, { "cell_type": "markdown", "metadata": { "id": "LRsjdZuEnZfA" }, "source": [ "请注意,最后一层的筛选器数量设置为 `output_channels` 的数量。每个类将有一个输出通道。" ] }, { "cell_type": "markdown", "metadata": { "id": "j0DGH_4T0VYn" }, "source": [ "## 训练模型\n", "\n", "现在,剩下要做的是编译和训练模型。\n", "\n", "由于这是一个多类分类问题,请使用 `tf.keras.losses.SparseCategoricalCrossentropy` 损失函数,并将 `from_logits` 参数设置为 `True`,因为标签是标量整数,而不是每个类的每个像素的分数向量。\n", "\n", "运行推断时,分配给像素的标签是具有最高值的通道。这就是 `create_mask` 函数的作用。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6he36HK5uKAc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "OUTPUT_CLASSES = 3\n", "\n", "model = unet_model(output_channels=OUTPUT_CLASSES)\n", "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "xVMzbIZLcyEF" }, "source": [ "绘制最终的模型架构:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sw82qF1Gcovr", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tf.keras.utils.plot_model(model, show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "Tc3MiEO2twLS" }, "source": [ "在训练前试用一下该模型,以检查其预测结果:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UwvIKLZPtxV_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def create_mask(pred_mask):\n", " pred_mask = tf.math.argmax(pred_mask, axis=-1)\n", " pred_mask = pred_mask[..., tf.newaxis]\n", " return pred_mask[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLNsrynNtx4d", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def show_predictions(dataset=None, num=1):\n", " if dataset:\n", " for image, mask in dataset.take(num):\n", " pred_mask = model.predict(image)\n", " display([image[0], mask[0], create_mask(pred_mask)])\n", " else:\n", " display([sample_image, sample_mask,\n", " create_mask(model.predict(sample_image[tf.newaxis, ...]))])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X_1CC0T4dho3", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "show_predictions()" ] }, { "cell_type": "markdown", "metadata": { "id": "22AyVYWQdkgk" }, "source": [ "下面定义的回调用于观察模型在训练过程中的改进情况:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wHrHsqijdmL6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class DisplayCallback(tf.keras.callbacks.Callback):\n", " def on_epoch_end(self, epoch, logs=None):\n", " clear_output(wait=True)\n", " show_predictions()\n", " print ('\\nSample Prediction after epoch {}\\n'.format(epoch+1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "StKDH_B9t4SD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "EPOCHS = 20\n", "VAL_SUBSPLITS = 5\n", "VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n", "\n", "model_history = model.fit(train_batches, epochs=EPOCHS,\n", " steps_per_epoch=STEPS_PER_EPOCH,\n", " validation_steps=VALIDATION_STEPS,\n", " validation_data=test_batches,\n", " callbacks=[DisplayCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P_mu0SAbt40Q", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "loss = model_history.history['loss']\n", "val_loss = model_history.history['val_loss']\n", "\n", "plt.figure()\n", "plt.plot(model_history.epoch, loss, 'r', label='Training loss')\n", "plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')\n", "plt.title('Training and Validation Loss')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss Value')\n", "plt.ylim([0, 1])\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "unP3cnxo_N72" }, "source": [ "## 做出预测" ] }, { "cell_type": "markdown", "metadata": { "id": "7BVXldSo-0mW" }, "source": [ "接下来,进行一些预测。为了节省时间,保持较小周期数,但您也可以将其设置得更高以获得更准确的结果。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ikrzoG24qwf5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "show_predictions(test_batches, 3)" ] }, { "cell_type": "markdown", "metadata": { "id": "QAwvlgSNoK3o" }, "source": [ "## 可选:不平衡的类和类权重" ] }, { "cell_type": "markdown", "metadata": { "id": "eqtFPqqu2kxP" }, "source": [ "语义分割数据集可能会高度不平衡,这意味着特定类别的像素可以比其他类别的像素更多地出现在图像内部。由于分割问题可以被视为逐像素分类问题,因此您可以通过加权损失函数来解决不平衡问题。这是处理此问题的一种简单而优雅的方式。请参阅[不平衡数据分类](../structured_data/imbalanced_data.ipynb)教程了解更多信息。\n", "\n", "为[避免歧义](https://github.com/keras-team/keras/issues/3653#issuecomment-243939748),{code 1}Model.fit{/code 1} 不支持具有 3 个以上维度的目标的 `class_weight` 参数。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aHt90UEQsZDn", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "try:\n", " model_history = model.fit(train_batches, epochs=EPOCHS,\n", " steps_per_epoch=STEPS_PER_EPOCH,\n", " class_weight = {0:2.0, 1:2.0, 2:1.0})\n", " assert False\n", "except Exception as e:\n", " print(f\"Expected {type(e).__name__}: {e}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "brbhYODCsvbe" }, "source": [ "因此,在这种情况下,您需要自己实现加权。您将使用样本权重来执行此操作:除了 `(data, label)` 对之外,`Model.fit` 还接受 `(data, label, sample_weight)` 三元组。\n", "\n", "Keras `Model.fit` 将 `sample_weight` 传播给损失和指标,它们也接受 `sample_weight` 参数。在归约步骤之前,将样本权重乘以样本值。例如:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EmHtImJn5Kk-", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "label = [0,0]\n", "prediction = [[-3., 0], [-3, 0]] \n", "sample_weight = [1, 10] \n", "\n", "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,\n", " reduction=tf.keras.losses.Reduction.NONE)\n", "loss(label, prediction, sample_weight).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "Gbwo3DZ-9TxM" }, "source": [ "因此,要为本教程设置样本权重,您需要一个函数,该函数接受 `(data, label)` 对并返回 `(data, label, sample_weight)` 三元组,其中 `sample_weight` 是包含每个像素的类权重的单通道图像。\n", "\n", "最简单的可能实现是将标签用作 `class_weight` 列表的索引:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DlG-n2Ugo8Jc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def add_sample_weights(image, label):\n", " # The weights for each class, with the constraint that:\n", " # sum(class_weights) == 1.0\n", " class_weights = tf.constant([2.0, 2.0, 1.0])\n", " class_weights = class_weights/tf.reduce_sum(class_weights)\n", "\n", " # Create an image of `sample_weights` by using the label at each pixel as an \n", " # index into the `class weights` .\n", " sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))\n", "\n", " return image, label, sample_weights" ] }, { "cell_type": "markdown", "metadata": { "id": "hLH_NvH2UrXU" }, "source": [ "每个生成的数据集元素包含 3 个图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SE_ezRSFRCnE", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_batches.map(add_sample_weights).element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "Yc-EpIzaRbSL" }, "source": [ "现在,您可以在此加权数据集上训练模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QDWipedAoOQe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "weighted_model = unet_model(OUTPUT_CLASSES)\n", "weighted_model.compile(\n", " optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "btEFKc1xodGR", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "weighted_model.fit(\n", " train_batches.map(add_sample_weights),\n", " epochs=1,\n", " steps_per_epoch=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "R24tahEqmSCk" }, "source": [ "## 接下来\n", "\n", "现在您已经了解了什么是图像分割及其工作原理,您可以使用不同的中间层输出,甚至不同的预训练模型来尝试本教程。您也可以通过尝试在 Kaggle 上托管的 [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) 图像掩码挑战来挑战自己。\n", "\n", "您可能还想查看另一个可以根据自己的数据重新训练的模型的 [Tensorflow Object Detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/README.md)。[TensorFlow Hub](https://tensorflow.google.cn/hub/tutorials/tf2_object_detection#optional) 上提供了预训练模型。" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "segmentation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }