{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EFY9e5wRn7v" }, "outputs": [], "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "pkTRazeVRwDe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "VyOckJu6Rs-i" }, "source": [ "# Data augmentation" ] }, { "cell_type": "markdown", "metadata": { "id": "0HEsULqDR7AH" }, "source": [
"You can use `tf.image` to write your own data augmentation pipelines or layers. You can also check out TensorFlow Addons Image: Operations and [TensorFlow I/O: Color Space Conversions](https://tensorflow.google.cn/io/tutorials/colorspace)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xR1RvjYkdd_i"
},
"source": [
"由于花卉数据集之前已经配置了数据增强,因此我们将其重新导入以重新开始。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JB-lAS0z9ZJY",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"(train_ds, val_ds, test_ds), metadata = tfds.load(\n",
" 'tf_flowers',\n",
" split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n",
" with_info=True,\n",
" as_supervised=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rQ3pqBTS9hNj"
},
"source": [
"检索一个图像以供使用:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dDsPaAi8de_j",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"image, label = next(iter(train_ds))\n",
"_ = plt.imshow(image)\n",
"_ = plt.title(get_label_name(label))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "chelxcPtFiTF"
},
"source": [
"我们来使用以下函数呈现原始图像和增强图像,然后并排比较。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sN1ykjJCHikc",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def visualize(original, augmented):\n",
" fig = plt.figure()\n",
" plt.subplot(1,2,1)\n",
" plt.title('Original image')\n",
" plt.imshow(original)\n",
"\n",
" plt.subplot(1,2,2)\n",
" plt.title('Augmented image')\n",
" plt.imshow(augmented)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C5X4ijQYHmlt"
},
"source": [
"### 数据增强"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RRD9oujLHo6c"
},
"source": [
"#### 翻转图像\n",
"\n",
"使用 `tf.image.flip_left_right` 垂直或水平翻转图像:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1ZjVI24nIH0S",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"flipped = tf.image.flip_left_right(image)\n",
"visualize(image, flipped)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6iD_lLibIL9q"
},
"source": [
"#### 对图像进行灰度处理\n",
"\n",
"您可以使用 `tf.image.rgb_to_grayscale` 对图像进行灰度处理:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ikaMj0guIRtL",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"grayscaled = tf.image.rgb_to_grayscale(image)\n",
"visualize(image, tf.squeeze(grayscaled))\n",
"_ = plt.colorbar()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f-5yjIs4IZ7v"
},
"source": [
"#### 调整图像饱和度\n",
"\n",
"使用 `tf.image.adjust_saturation`,通过提供饱和度系数来调整图像饱和度:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PHz-NosiInmz",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"saturated = tf.image.adjust_saturation(image, 3)\n",
"visualize(image, saturated)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FWXiy8qfIqdC"
},
"source": [
"#### 更改图像亮度\n",
"\n",
"使用 `tf.image.adjust_brightness`,通过提供亮度系数来更改图像的亮度:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1hdG-j46I0nJ",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"bright = tf.image.adjust_brightness(image, 0.4)\n",
"visualize(image, bright)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vjEOFEITJOr2"
},
"source": [
"#### 对图像进行中心裁剪\n",
"\n",
"使用 `tf.image.central_crop` 将图像从中心裁剪到所需部分:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RWkK5GFHJUKT",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"cropped = tf.image.central_crop(image, central_fraction=0.5)\n",
"visualize(image, cropped)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "unt76GebI3Gc"
},
"source": [
"#### 旋转图像\n",
"\n",
"使用 `tf.image.rot90` 将图像旋转 90 度:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b19KuAhkJKR-",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"rotated = tf.image.rot90(image)\n",
"visualize(image, rotated)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5CPP0vEKB56X"
},
"source": [
"### 随机变换\n",
"\n",
"警告:有两组随机图像运算:`tf.image.random*` 和 `tf.image.stateless_random*`。强烈不建议使用 `tf.image.random*` 运算,因为它们使用的是 TF 1.x 中的旧 RNG。请改用本教程中介绍的随机图像运算。有关详情,请参阅[随机数生成](../../guide/random_numbers.ipynb)。\n",
"\n",
"对图像应用随机变换可以进一步帮助泛化和扩展数据集。当前的 `tf.image` API 提供了 8 个这样的随机图像运算 (op):\n",
"\n",
"- [`tf.image.stateless_random_brightness`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_brightness)\n",
"- [`tf.image.stateless_random_contrast`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_contrast)\n",
"- [`tf.image.stateless_random_crop`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_crop)\n",
"- [`tf.image.stateless_random_flip_left_right`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_flip_left_right)\n",
"- [`tf.image.stateless_random_flip_up_down`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_flip_up_down)\n",
"- [`tf.image.stateless_random_hue`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_hue)\n",
"- [`tf.image.stateless_random_jpeg_quality`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_jpeg_quality)\n",
"- [`tf.image.stateless_random_saturation`](https://tensorflow.google.cn/api_docs/python/tf/image/stateless_random_saturation)\n",
"\n",
"这些随机图像运算纯粹是功能性的:输出仅取决于输入。这使得它们易于在高性能、确定性的输入流水线中使用。它们要求每一步都输入一个 `seed` 值。给定相同的 `seed`,无论被调用多少次,它们都会返回相同的结果。\n",
"\n",
"注:`seed` 是形状为 `(2,)` 的 `Tensor`,其值为任意整数。\n",
"\n",
"在以下部分中,您将:\n",
"\n",
"1. 回顾使用随机图像运算来变换图像的示例。\n",
"2. 演示如何将随机变换应用于训练数据集。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "251Wy-MqE4La"
},
"source": [
"#### 随机更改图像亮度\n",
"\n",
"通过提供亮度系数和 `seed`,使用 `tf.image.stateless_random_brightness` 随机更改 `image` 的亮度。亮度系数在 `[-max_delta, max_delta)` 范围内随机选择,并与给定的 `seed` 相关联。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-fFd1kh7Fr-_",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"for i in range(3):\n",
" seed = (i, 0) # tuple of size (2,)\n",
" stateless_random_brightness = tf.image.stateless_random_brightness(\n",
" image, max_delta=0.95, seed=seed)\n",
" visualize(image, stateless_random_brightness)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uLaDEmooUfYJ"
},
"source": [
"#### 随机更改图像对比度\n",
"\n",
"通过提供对比度范围和 `seed`,使用 `tf.image.stateless_random_contrast` 随机更改 `image` 的对比度。对比度范围在区间 `[lower, upper]` 中随机选择,并与给定的 `seed` 相关联。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GmcYoQHaUoke",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"for i in range(3):\n",
" seed = (i, 0) # tuple of size (2,)\n",
" stateless_random_contrast = tf.image.stateless_random_contrast(\n",
" image, lower=0.1, upper=0.9, seed=seed)\n",
" visualize(image, stateless_random_contrast)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wxb-MP-KVPNz"
},
"source": [
"#### 随机裁剪图像\n",
"\n",
"通过提供目标 `size` 和 `seed`,使用 `tf.image.stateless_random_crop` 随机裁剪 `image`。从 `image` 中裁剪出来的部分位于随机选择的偏移处,并与给定的 `seed` 相关联。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vtZQbUw0VOm5",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"for i in range(3):\n",
" seed = (i, 0) # tuple of size (2,)\n",
" stateless_random_crop = tf.image.stateless_random_crop(\n",
" image, size=[210, 300, 3], seed=seed)\n",
" visualize(image, stateless_random_crop)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "isrM-MZtpxTq"
},
"source": [
"### 对数据集应用增强\n",
"\n",
"我们首先再次下载图像数据集,以防它们在之前的部分中被修改。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xC80NQP809Uo",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"(train_datasets, val_ds, test_ds), metadata = tfds.load(\n",
" 'tf_flowers',\n",
" split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n",
" with_info=True,\n",
" as_supervised=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SMo9HTDV0Gaz"
},
"source": [
"接下来,定义一个用于调整图像大小和重新缩放图像的效用函数。此函数将用于统一数据集中图像的大小和比例:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1JKmx06lfcFr",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def resize_and_rescale(image, label):\n",
" image = tf.cast(image, tf.float32)\n",
" image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])\n",
" image = (image / 255.0)\n",
" return image, label"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M7OpE_-jWq-I"
},
"source": [
"我们同时定义 `augment` 函数,该函数可以将随机变换应用于图像。此函数将在下一步中用于数据集。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KitLdvlpVxPa",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def augment(image_label, seed):\n",
" image, label = image_label\n",
" image, label = resize_and_rescale(image, label)\n",
" image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)\n",
" # Make a new seed.\n",
" new_seed = tf.random.split(seed, num=1)[0, :]\n",
" # Random crop back to the original size.\n",
" image = tf.image.stateless_random_crop(\n",
" image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)\n",
" # Random brightness.\n",
" image = tf.image.stateless_random_brightness(\n",
" image, max_delta=0.5, seed=new_seed)\n",
" image = tf.clip_by_value(image, 0, 1)\n",
" return image, label"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SlXRsVp70hg8"
},
"source": [
"#### 选项 1:使用 tf.data.experimental.Counter\n",
"\n",
"创建一个 `tf.data.experimental.Counter()` 对象(我们称之为 `counter`),并使用 `(counter, counter)` `Dataset.zip` 数据集。这将确保数据集中的每个图像都与一个基于 `counter` 的唯一值(形状为 `(2,)`)相关联,稍后可以将其传递到 `augment` 函数,作为随机变换的 `seed` 值。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SZ6Qq0IWznfi",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Create a `Counter` object and `Dataset.zip` it together with the training set.\n",
"counter = tf.data.experimental.Counter()\n",
"train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eF9ybVQ94X9f"
},
"source": [
"将 `augment` 函数映射到训练数据集:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wQK9BDKk1_3N",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"train_ds = (\n",
" train_ds\n",
" .shuffle(1000)\n",
" .map(augment, num_parallel_calls=AUTOTUNE)\n",
" .batch(batch_size)\n",
" .prefetch(AUTOTUNE)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3AQoyA-k3ELk",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"val_ds = (\n",
" val_ds\n",
" .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n",
" .batch(batch_size)\n",
" .prefetch(AUTOTUNE)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p2IQN3NN3G_M",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"test_ds = (\n",
" test_ds\n",
" .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n",
" .batch(batch_size)\n",
" .prefetch(AUTOTUNE)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pvTVY8BY2LpD"
},
"source": [
"#### 选项 2:使用 `tf.random.Generator`\n",
"\n",
"- 创建一个具有初始 `seed` 值的 `tf.random.Generator` 对象。在同一个生成器对象上调用 `make_seeds` 函数会始终返回一个新的、唯一的 `seed` 值。\n",
"- 定义一个封装容器函数:1) 调用 `make_seeds` 函数;2) 将新生成的 `seed` 值传递给 `augment` 函数进行随机变换。\n",
"\n",
"注:`tf.random.Generator` 对象会将 RNG 状态存储在 `tf.Variable` 中,这意味着它可以保存为[检查点](../../guide/checkpoint.ipynb)或以 [SavedModel](../../guide/saved_model.ipynb) 格式保存。有关详情,请参阅[随机数生成](../../guide/random_numbers.ipynb)。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BQDvedZ33eAy",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Create a generator.\n",
"rng = tf.random.Generator.from_seed(123, alg='philox')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eDEkO1nt2ta0",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Create a wrapper function for updating seeds.\n",
"def f(x, y):\n",
" seed = rng.make_seeds(2)[0]\n",
" image, label = augment((x, y), seed)\n",
" return image, label"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PyPC4vUM4MT0"
},
"source": [
"将封装容器函数 `f` 映射到训练数据集,并将 `resize_and_rescale` 函数映射到验证集和测试集:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Pu2uB7k12xKw",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"train_ds = (\n",
" train_datasets\n",
" .shuffle(1000)\n",
" .map(f, num_parallel_calls=AUTOTUNE)\n",
" .batch(batch_size)\n",
" .prefetch(AUTOTUNE)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e6caldPi2HAP",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"val_ds = (\n",
" val_ds\n",
" .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n",
" .batch(batch_size)\n",
" .prefetch(AUTOTUNE)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ceaCdJnh2I-r",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"test_ds = (\n",
" test_ds\n",
" .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n",
" .batch(batch_size)\n",
" .prefetch(AUTOTUNE)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hKwCA6AOjTrc"
},
"source": [
"这些数据集现在可以用于训练模型了,如前文所述。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YypDihDlj0no"
},
"source": [
"## 后续步骤\n",
"\n",
"本教程演示了使用 Keras 预处理层和 `tf.image` 进行数据增强。\n",
"\n",
"- 要了解如何在模型中包含预处理层,请参阅[图像分类](classification.ipynb)教程。\n",
"- 您可能也有兴趣了解预处理层如何帮助您对文本进行分类,请参阅[基本文本分类](../keras/text_classification.ipynb)教程。\n",
"- 您可以在此指南中了解有关 tf.data
的更多信息,并且可以在[这里](../../guide/data_performance.ipynb)了解如何配置输入流水线以提高性能。"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "data_augmentation.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}