{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EFY9e5wRn7v" }, "outputs": [], "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "pkTRazeVRwDe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "VyOckJu6Rs-i" }, "source": [ "# 数据增强" ] }, { "cell_type": "markdown", "metadata": { "id": "0HEsULqDR7AH" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "PxIOE5RnSQtj" }, "source": [ "## 概述\n", "\n", "本教程演示了数据增强:一种通过应用随机(但真实)的变换(例如图像旋转)来增加训练集多样性的技术。\n", "\n", "您将学习如何通过两种方式应用数据增强:\n", "\n", "- 使用 Keras 预处理层,例如 `tf.keras.layers.Resizing`、`tf.keras.layers.Rescaling`、`tf.keras.layers.RandomFlip` 和 `tf.keras.layers.RandomRotation`。\n", "- 使用 `tf.image` 方法,例如 `tf.image.flip_left_right`、`tf.image.rgb_to_grayscale`、`tf.image.adjust_brightness`、`tf.image.central_crop` 和 `tf.image.stateless_random*`。" ] }, { "cell_type": "markdown", "metadata": { "id": "-UxHAqXmSXN5" }, "source": [ "## 设置" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C2Q5rPenTAJP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "\n", "from tensorflow.keras import layers" ] }, { "cell_type": "markdown", "metadata": { "id": "Ydx3SSoF4wpG" }, "source": [ "## 下载数据集\n", "\n", "本教程使用 [tf_flowers](https://tensorflow.google.cn/datasets/catalog/tf_flowers) 数据集。为了方便起见,请使用 [TensorFlow Datasets](https://tensorflow.google.cn/datasets) 下载数据集。如果您想了解导入数据的其他方式,请参阅[加载图像](https://tensorflow.google.cn/tutorials/load_data/images)教程。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ytHhsYmO52zy", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "(train_ds, val_ds, test_ds), metadata = tfds.load(\n", " 'tf_flowers',\n", " split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n", " with_info=True,\n", " as_supervised=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "MjxEJtCwsnmm" }, "source": [ "花卉数据集有五个类。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wKwx7vQuspxz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "num_classes = metadata.features['label'].num_classes\n", "print(num_classes)" ] }, { "cell_type": "markdown", "metadata": { "id": 
"zZAQW44949uw" }, "source": [ "我们从数据集中检索一个图像,然后使用它来演示数据增强。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kXlx1lCr5Bip", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "get_label_name = metadata.features['label'].int2str\n", "\n", "image, label = next(iter(train_ds))\n", "_ = plt.imshow(image)\n", "_ = plt.title(get_label_name(label))" ] }, { "cell_type": "markdown", "metadata": { "id": "vdJ6XA4q2nqK" }, "source": [ "## 使用 Keras 预处理层" ] }, { "cell_type": "markdown", "metadata": { "id": "GRMPnfzBB2hw" }, "source": [ "### 调整大小和重新缩放\n" ] }, { "cell_type": "markdown", "metadata": { "id": "jhG7gSWmUMJx" }, "source": [ "您可以使用 Keras 预处理层将图像大小调整为一致的形状(使用 `tf.keras.layers.Resizing`),并重新调整像素值(使用 `tf.keras.layers.Rescaling`)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jMM3b85e3yhd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "IMG_SIZE = 180\n", "\n", "resize_and_rescale = tf.keras.Sequential([\n", " layers.Resizing(IMG_SIZE, IMG_SIZE),\n", " layers.Rescaling(1./255)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "4z8AV1WgnYNW" }, "source": [ "注:上面的重新缩放层将像素值标准化到 `[0,1]` 范围。如果想要 `[-1,1]`,可以编写 `tf.keras.layers.Rescaling(1./127.5, offset=-1)`。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MQiTwsHJDHAD" }, "source": [ "您可以看到将这些层应用于图像的结果。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X9OLuR1bC1Pd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "result = resize_and_rescale(image)\n", "_ = plt.imshow(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "yxAMg8Zql5lw" }, "source": [ "验证像素是否在 `[0, 1]` 范围内:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DPTB8IQmSeKM", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(\"Min and max pixel values:\", result.numpy().min(), result.numpy().max())" ] }, { "cell_type": "markdown", "metadata": { "id": 
"fL6M7fuivAw4" }, "source": [ "### 数据增强" ] }, { "cell_type": "markdown", "metadata": { "id": "SL4Suj46ScfU" }, "source": [ "您也可以使用 Keras 预处理层进行数据增强,例如 `tf.keras.layers.RandomFlip` 和 `tf.keras.layers.RandomRotation`。" ] }, { "cell_type": "markdown", "metadata": { "id": "V-4PugTE-4sl" }, "source": [ "我们来创建一些预处理层,然后将它们重复应用于同一图像。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Svu_5yfa_Jb7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "data_augmentation = tf.keras.Sequential([\n", " layers.RandomFlip(\"horizontal_and_vertical\"),\n", " layers.RandomRotation(0.2),\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kfzEuaNg69iU", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Add the image to a batch.\n", "image = tf.cast(tf.expand_dims(image, 0), tf.float32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eR4wwi5Q_UZK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for i in range(9):\n", " augmented_image = data_augmentation(image)\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(augmented_image[0])\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "jA17pEeS_2_-" }, "source": [ "有多种预处理层可用于数据增强,包括 `tf.keras.layers.RandomContrast`、`tf.keras.layers.RandomCrop`、`tf.keras.layers.RandomZoom` 等。" ] }, { "cell_type": "markdown", "metadata": { "id": "GG5RhIJtE0ng" }, "source": [ "### 使用 Keras 预处理层的两个选项\n", "\n", "您可以通过两种方式使用这些预处理层,但需进行重要的权衡。" ] }, { "cell_type": "markdown", "metadata": { "id": "MxGvUT727Po6" }, "source": [ "#### 选项 1:使预处理层成为模型的一部分" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ULGJQjP6hHvu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", " # Add the preprocessing layers you created earlier.\n", " resize_and_rescale,\n", " data_augmentation,\n", " layers.Conv2D(16, 
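3, padding='same', activation='relu'),\n", "  layers.MaxPooling2D(),\n", "  # Rest of your model.\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "pc6ELneyhJN9" }, "source": [ "There are two important points to be aware of in this case:\n", "\n", "- Data augmentation will run on-device, synchronously with the rest of your layers, and benefit from GPU acceleration.\n", "\n", "- When you export your model using `model.save`, the preprocessing layers will be saved along with the rest of your model. If you later deploy this model, it will automatically standardize images (according to the configuration of your layers). This can save you the effort of reimplementing that logic server-side." ] }, { "cell_type": "markdown", "metadata": { "id": "syZwDSpiRXZP" }, "source": [ "Note: Data augmentation is inactive at test time, so input images will only be augmented during calls to `Model.fit` (not `Model.evaluate` or `Model.predict`)." ] }, { "cell_type": "markdown", "metadata": { "id": "B2X3JTeY_vfv" }, "source": [ "#### Option 2: Apply the preprocessing layers to your dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "r1Bt7w5VhVDY", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "aug_ds = train_ds.map(\n", "  lambda x, y: (resize_and_rescale(x, training=True), y))" ] }, { "cell_type": "markdown", "metadata": { "id": "HKqeahG2hVdV" }, "source": [ "With this approach, you use `Dataset.map` to create a dataset that yields batches of augmented images. In this case:\n", "\n", "- Data augmentation will happen asynchronously on the CPU and is non-blocking. You can overlap the training of your model on the GPU with data preprocessing by using `Dataset.prefetch`, as shown below.\n", "- The preprocessing layers will not be exported with the model when you call `Model.save`. You will need to attach them to your model before saving it, or reimplement them server-side. After training, you can attach the preprocessing layers before export.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "cgj51k9J7jfc" }, "source": [ "You can find an example of the first option in the [Image classification](classification.ipynb) tutorial. Let's demonstrate the second option here." ] }, { "cell_type": "markdown", "metadata": { "id": "31YwMQdrXKBP" }, "source": [ "### Apply the preprocessing layers to the datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "WUgW-2LOGiOT" }, "source": [ "Configure the training, validation, and test datasets with the Keras preprocessing layers you created earlier. You will also configure the datasets for performance, using parallel reads and buffered prefetching to yield batches from disk without I/O becoming blocking. (Learn more about dataset performance in the [Better performance with the tf.data API](https://tensorflow.google.cn/guide/data_performance) guide.)" ] }, { "cell_type": "markdown", "metadata": { "id": "eI7VdyqK767y" }, "source": [ "Note: Data augmentation should only be applied to the training set." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R5fGVMqlFxF7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "batch_size = 32\n", 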
"AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "def prepare(ds, shuffle=False, augment=False):\n", " # Resize and rescale all datasets.\n", " ds = ds.map(lambda x, y: (resize_and_rescale(x), y), \n", " num_parallel_calls=AUTOTUNE)\n", "\n", " if shuffle:\n", " ds = ds.shuffle(1000)\n", "\n", " # Batch all datasets.\n", " ds = ds.batch(batch_size)\n", "\n", " # Use data augmentation only on the training set.\n", " if augment:\n", " ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), \n", " num_parallel_calls=AUTOTUNE)\n", "\n", " # Use buffered prefetching on all datasets.\n", " return ds.prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N86SFGMBHcx-", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = prepare(train_ds, shuffle=True, augment=True)\n", "val_ds = prepare(val_ds)\n", "test_ds = prepare(test_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "9gplDz4ZV6kk" }, "source": [ "### 训练模型\n", "\n", "为了完整起见,您现在将使用刚刚准备的数据集训练模型。\n", "\n", "[序贯](https://tensorflow.google.cn/guide/keras/sequential_model)模型由三个卷积块 (`tf.keras.layers.Conv2D`) 组成,每个卷积块都有一个最大池化层 (`tf.keras.layers.MaxPooling2D`)。有一个全连接层 (`tf.keras.layers.Dense`),上面有 128 个单元,由 ReLU 激活函数 (`'relu'`) 激活。此模型尚未针对准确率进行调整(目标是展示机制)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IODSymGhq9N6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", " layers.Conv2D(16, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(32, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(64, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Flatten(),\n", " layers.Dense(128, activation='relu'),\n", " layers.Dense(num_classes)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "86454855f7d9" }, "source": [ "选择 `tf.keras.optimizers.Adam` 优化器和 
`tf.keras.losses.SparseCategoricalCrossentropy` loss function. To view training and validation accuracy for each training epoch, pass the `metrics` argument to `Model.compile`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZnRJr95WY68k", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "              metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "976f718cabc8" }, "source": [ "Train for a few epochs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i_sDl9uZY9Mh", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "epochs=5\n", "history = model.fit(\n", "  train_ds,\n", "  validation_data=val_ds,\n", "  epochs=epochs\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V9PSf4qgiQJG", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "loss, acc = model.evaluate(test_ds)\n", "print(\"Accuracy\", acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "0BkRvvsXb6SI" }, "source": [ "### Custom data augmentation\n", "\n", "You can also create custom data augmentation layers.\n", "\n", "This section of the tutorial shows two ways of doing so:\n", "\n", "- First, you will create a `tf.keras.layers.Lambda` layer. This is a good way to write concise code.\n", "- Next, you will write a new layer via [subclassing](https://tensorflow.google.cn/guide/keras/custom_layers_and_models), which gives you more control.\n", "\n", "Both layers will randomly invert the colors in an image, according to some probability." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nMxEhIVXmAH0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def random_invert_img(x, p=0.5):\n", "  # Invert the colors with probability p; otherwise return x unchanged.\n", "  if tf.random.uniform([]) < p:\n", "    x = 255 - x\n", "  return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C0huNpxdmDKu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def random_invert(factor=0.5):\n", "  return layers.Lambda(lambda x: random_invert_img(x, factor))\n", "\n", "random_invert = random_invert()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"wAcOluP0TNG6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for i in range(9):\n", " augmented_image = random_invert(image)\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(augmented_image[0].numpy().astype(\"uint8\"))\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Xd9XG2PLM5ZJ" }, "source": [ "接下来,通过[子类化](https://tensorflow.google.cn/guide/keras/custom_layers_and_models)实现自定义层:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d11eExc-Ke-7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class RandomInvert(layers.Layer):\n", " def __init__(self, factor=0.5, **kwargs):\n", " super().__init__(**kwargs)\n", " self.factor = factor\n", "\n", " def call(self, x):\n", " return random_invert_img(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qX-VQgkRL6fc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "_ = plt.imshow(RandomInvert()(image)[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "B0nmllnXZO6T" }, "source": [ "可以按照上述选项 1 和 2 中的描述使用这两个层。" ] }, { "cell_type": "markdown", "metadata": { "id": "j7-k__2dAfX6" }, "source": [ "## 使用 tf.image" ] }, { "cell_type": "markdown", "metadata": { "id": "NJco2x35EAMs" }, "source": [ "上述 Keras 预训练实用工具十分方便。但为了更精细的控制,您可以使用 `tf.data` 和 tf.image 编写自己的数据增强流水线或数据增强层。您还可以查看 [TensorFlow Addons 图像:运算](https://tensorflow.google.cn/io/tutorials/colorspace)和 TensorFlow I/O:色彩空间转换。" ] }, { "cell_type": "markdown", "metadata": { "id": "xR1RvjYkdd_i" }, "source": [ "由于花卉数据集之前已经配置了数据增强,因此我们将其重新导入以重新开始。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JB-lAS0z9ZJY", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "(train_ds, val_ds, test_ds), metadata = tfds.load(\n", " 'tf_flowers',\n", " split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n", " with_info=True,\n", " as_supervised=True,\n", ")" ] }, { 
"cell_type": "markdown", "metadata": { "id": "rQ3pqBTS9hNj" }, "source": [ "检索一个图像以供使用:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dDsPaAi8de_j", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "image, label = next(iter(train_ds))\n", "_ = plt.imshow(image)\n", "_ = plt.title(get_label_name(label))" ] }, { "cell_type": "markdown", "metadata": { "id": "chelxcPtFiTF" }, "source": [ "我们来使用以下函数呈现原始图像和增强图像,然后并排比较。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sN1ykjJCHikc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def visualize(original, augmented):\n", " fig = plt.figure()\n", " plt.subplot(1,2,1)\n", " plt.title('Original image')\n", " plt.imshow(original)\n", "\n", " plt.subplot(1,2,2)\n", " plt.title('Augmented image')\n", " plt.imshow(augmented)" ] }, { "cell_type": "markdown", "metadata": { "id": "C5X4ijQYHmlt" }, "source": [ "### 数据增强" ] }, { "cell_type": "markdown", "metadata": { "id": "RRD9oujLHo6c" }, "source": [ "#### 翻转图像\n", "\n", "使用 `tf.image.flip_left_right` 垂直或水平翻转图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1ZjVI24nIH0S", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "flipped = tf.image.flip_left_right(image)\n", "visualize(image, flipped)" ] }, { "cell_type": "markdown", "metadata": { "id": "6iD_lLibIL9q" }, "source": [ "#### 对图像进行灰度处理\n", "\n", "您可以使用 `tf.image.rgb_to_grayscale` 对图像进行灰度处理:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ikaMj0guIRtL", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "grayscaled = tf.image.rgb_to_grayscale(image)\n", "visualize(image, tf.squeeze(grayscaled))\n", "_ = plt.colorbar()" ] }, { "cell_type": "markdown", "metadata": { "id": "f-5yjIs4IZ7v" }, "source": [ "#### 调整图像饱和度\n", "\n", "使用 `tf.image.adjust_saturation`,通过提供饱和度系数来调整图像饱和度:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"PHz-NosiInmz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "saturated = tf.image.adjust_saturation(image, 3)\n", "visualize(image, saturated)" ] }, { "cell_type": "markdown", "metadata": { "id": "FWXiy8qfIqdC" }, "source": [ "#### 更改图像亮度\n", "\n", "使用 `tf.image.adjust_brightness`,通过提供亮度系数来更改图像的亮度:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1hdG-j46I0nJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "bright = tf.image.adjust_brightness(image, 0.4)\n", "visualize(image, bright)" ] }, { "cell_type": "markdown", "metadata": { "id": "vjEOFEITJOr2" }, "source": [ "#### 对图像进行中心裁剪\n", "\n", "使用 `tf.image.central_crop` 将图像从中心裁剪到所需部分:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RWkK5GFHJUKT", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "cropped = tf.image.central_crop(image, central_fraction=0.5)\n", "visualize(image, cropped)" ] }, { "cell_type": "markdown", "metadata": { "id": "unt76GebI3Gc" }, "source": [ "#### 旋转图像\n", "\n", "使用 `tf.image.rot90` 将图像旋转 90 度:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "b19KuAhkJKR-", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "rotated = tf.image.rot90(image)\n", "visualize(image, rotated)" ] }, { "cell_type": "markdown", "metadata": { "id": "5CPP0vEKB56X" }, "source": [ "### 随机变换\n", "\n", "警告:有两组随机图像运算:`tf.image.random*` 和 `tf.image.stateless_random*`。强烈不建议使用 `tf.image.random*` 运算,因为它们使用的是 TF 1.x 中的旧 RNG。请改用本教程中介绍的随机图像运算。有关详情,请参阅[随机数生成](../../guide/random_numbers.ipynb)。\n", "\n", "对图像应用随机变换可以进一步帮助泛化和扩展数据集。当前的 `tf.image` API 提供了 8 个这样的随机图像运算 (op):\n", "\n", "- [`tf.image.stateless_random_brightness`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_brightness)\n", "- [`tf.image.stateless_random_contrast`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_contrast)\n", "- 
[`tf.image.stateless_random_crop`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_crop)\n", "- [`tf.image.stateless_random_flip_left_right`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_flip_left_right)\n", "- [`tf.image.stateless_random_flip_up_down`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_flip_up_down)\n", "- [`tf.image.stateless_random_hue`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_hue)\n", "- [`tf.image.stateless_random_jpeg_quality`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_jpeg_quality)\n", "- [`tf.image.stateless_random_saturation`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_saturation)\n", "\n", "These random image ops are purely functional: the output depends only on the input. This makes them simple to use in high-performance, deterministic input pipelines. They require a `seed` value to be supplied at each step. Given the same `seed`, they return the same results no matter how many times they are called.\n", "\n", "Note: `seed` is a `Tensor` of shape `(2,)` whose values are any integers.\n", "\n", "In the following sections, you will:\n", "\n", "1. Go over examples of using random image operations to transform an image.\n", "2. 
Demonstrate how to apply random transformations to a training dataset." ] }, { "cell_type": "markdown", "metadata": { "id": "251Wy-MqE4La" }, "source": [ "#### Randomly change image brightness\n", "\n", "Randomly change the brightness of `image` using `tf.image.stateless_random_brightness` by providing a brightness factor and `seed`. The brightness factor is chosen randomly in the range `[-max_delta, max_delta)` and is associated with the given `seed`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-fFd1kh7Fr-_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for i in range(3):\n", "  seed = (i, 0)  # tuple of size (2,)\n", "  stateless_random_brightness = tf.image.stateless_random_brightness(\n", "      image, max_delta=0.95, seed=seed)\n", "  visualize(image, stateless_random_brightness)" ] }, { "cell_type": "markdown", "metadata": { "id": "uLaDEmooUfYJ" }, "source": [ "#### Randomly change image contrast\n", "\n", "Randomly change the contrast of `image` using `tf.image.stateless_random_contrast` by providing a contrast range and `seed`. The contrast factor is chosen randomly in the interval `[lower, upper]` and is associated with the given `seed`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GmcYoQHaUoke", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for i in range(3):\n", "  seed = (i, 0)  # tuple of size (2,)\n", "  stateless_random_contrast = tf.image.stateless_random_contrast(\n", "      image, lower=0.1, upper=0.9, seed=seed)\n", "  visualize(image, stateless_random_contrast)" ] }, { "cell_type": "markdown", "metadata": { "id": "wxb-MP-KVPNz" }, "source": [ "#### Randomly crop an image\n", "\n", "Randomly crop `image` using `tf.image.stateless_random_crop` by providing a target `size` and `seed`. The portion that gets cropped out of `image` is at a randomly chosen offset and is associated with the given `seed`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vtZQbUw0VOm5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for i in range(3):\n", "  seed = (i, 0)  # tuple of size (2,)\n", "  stateless_random_crop = tf.image.stateless_random_crop(\n", "      image, size=[210, 300, 3], seed=seed)\n", "  visualize(image, stateless_random_crop)" ] }, { "cell_type": "markdown", "metadata": { "id": "isrM-MZtpxTq" }, "source": [ "### Apply augmentation to a dataset\n", "\n", 
"我们首先再次下载图像数据集,以防它们在之前的部分中被修改。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xC80NQP809Uo", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "(train_datasets, val_ds, test_ds), metadata = tfds.load(\n", " 'tf_flowers',\n", " split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n", " with_info=True,\n", " as_supervised=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SMo9HTDV0Gaz" }, "source": [ "接下来,定义一个用于调整图像大小和重新缩放图像的效用函数。此函数将用于统一数据集中图像的大小和比例:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1JKmx06lfcFr", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def resize_and_rescale(image, label):\n", " image = tf.cast(image, tf.float32)\n", " image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])\n", " image = (image / 255.0)\n", " return image, label" ] }, { "cell_type": "markdown", "metadata": { "id": "M7OpE_-jWq-I" }, "source": [ "我们同时定义 `augment` 函数,该函数可以将随机变换应用于图像。此函数将在下一步中用于数据集。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KitLdvlpVxPa", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def augment(image_label, seed):\n", " image, label = image_label\n", " image, label = resize_and_rescale(image, label)\n", " image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)\n", " # Make a new seed.\n", " new_seed = tf.random.split(seed, num=1)[0, :]\n", " # Random crop back to the original size.\n", " image = tf.image.stateless_random_crop(\n", " image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)\n", " # Random brightness.\n", " image = tf.image.stateless_random_brightness(\n", " image, max_delta=0.5, seed=new_seed)\n", " image = tf.clip_by_value(image, 0, 1)\n", " return image, label" ] }, { "cell_type": "markdown", "metadata": { "id": "SlXRsVp70hg8" }, "source": [ "#### 选项 1:使用 tf.data.experimental.Counter\n", "\n", "创建一个 `tf.data.experimental.Counter()` 对象(我们称之为 `counter`),并使用 `(counter, 
counter)` using `Dataset.zip`. This ensures that each image in the dataset is associated with a unique value (of shape `(2,)`) based on `counter`, which can later be passed into the `augment` function as the `seed` value for random transformations." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SZ6Qq0IWznfi", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create a `Counter` object and `Dataset.zip` it together with the training set.\n", "counter = tf.data.experimental.Counter()\n", "train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))" ] }, { "cell_type": "markdown", "metadata": { "id": "eF9ybVQ94X9f" }, "source": [ "Map the `augment` function to the training dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wQK9BDKk1_3N", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = (\n", "    train_ds\n", "    .shuffle(1000)\n", "    .map(augment, num_parallel_calls=AUTOTUNE)\n", "    .batch(batch_size)\n", "    .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3AQoyA-k3ELk", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "val_ds = (\n", "    val_ds\n", "    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", "    .batch(batch_size)\n", "    .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p2IQN3NN3G_M", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "test_ds = (\n", "    test_ds\n", "    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", "    .batch(batch_size)\n", "    .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "pvTVY8BY2LpD" }, "source": [ "#### Option 2: Using `tf.random.Generator`\n", "\n", "- Create a `tf.random.Generator` object with an initial `seed` value. Calling the `make_seeds` function on the same generator object always returns a new, unique `seed` value.\n", "- Define a wrapper function that: 1) calls the `make_seeds` function; and 2) passes the newly generated `seed` value into the `augment` function for random transformations.\n", "\n", "Note: `tf.random.Generator` objects store RNG state in a `tf.Variable`, which means it can be saved as a [checkpoint](../../guide/checkpoint.ipynb) or in the [SavedModel](../../guide/saved_model.ipynb) 
format. For more details, refer to [Random number generation](../../guide/random_numbers.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BQDvedZ33eAy", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create a generator.\n", "rng = tf.random.Generator.from_seed(123, alg='philox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eDEkO1nt2ta0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create a wrapper function for updating seeds.\n", "def f(x, y):\n", "  seed = rng.make_seeds(2)[0]\n", "  image, label = augment((x, y), seed)\n", "  return image, label" ] }, { "cell_type": "markdown", "metadata": { "id": "PyPC4vUM4MT0" }, "source": [ "Map the wrapper function `f` to the training dataset, and the `resize_and_rescale` function to the validation and test sets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pu2uB7k12xKw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = (\n", "    train_datasets\n", "    .shuffle(1000)\n", "    .map(f, num_parallel_calls=AUTOTUNE)\n", "    .batch(batch_size)\n", "    .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e6caldPi2HAP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "val_ds = (\n", "    val_ds\n", "    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", "    .batch(batch_size)\n", "    .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ceaCdJnh2I-r", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "test_ds = (\n", "    test_ds\n", "    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", "    .batch(batch_size)\n", "    .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "hKwCA6AOjTrc" }, "source": [ "These datasets can now be used to train a model as shown previously." ] }, { "cell_type": "markdown", "metadata": { "id": "YypDihDlj0no" }, "source": [ "## Next steps\n", "\n", "This tutorial demonstrated data augmentation using Keras preprocessing layers and `tf.image`.\n", "\n", "- 
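To learn how to include preprocessing layers inside your model, refer to the [Image classification](classification.ipynb) tutorial.\n", "- You may also be interested in learning how preprocessing layers can help you classify text, as shown in the [Basic text classification](../keras/text_classification.ipynb) tutorial.\n", "- You can learn more about `tf.data` in the tf.data guide, and you can learn how to configure your input pipelines for performance [here](../../guide/data_performance.ipynb)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "data_augmentation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }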