{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "mt9dL5dIir8X" }, "outputs": [], "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ufPx7EiCiqgR", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ucMoYase6URl" }, "source": [ "# 加载和预处理图像" ] }, { "cell_type": "markdown", "metadata": { "id": "_Wwu5SXZmEkB" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "Oxw4WahM7DU9" }, "source": [ "本教程介绍如何以三种方式加载和预处理图像数据集:\n", "\n", "- 首先,您将使用高级 Keras 预处理效用函数(例如 `tf.keras.utils.image_dataset_from_directory`)和层(例如 `tf.keras.layers.Rescaling`)来读取磁盘上的图像目录。\n", "- 然后,您将[使用 tf.data](../../guide/data.ipynb) 从头编写自己的输入流水线。\n", "- 最后,您将从 [TensorFlow Datasets](https://tensorflow.google.cn/datasets) 中的大型[目录](https://tensorflow.google.cn/datasets/catalog/overview)下载数据集。" ] }, { "cell_type": "markdown", "metadata": { "id": "hoQQiZDB6URn" }, "source": [ "## 配置" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3vhAMaIOBIee", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import numpy as np\n", "import os\n", "import PIL\n", "import PIL.Image\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qnp9Z2sT5dWj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "wO0InzL66URu" }, "source": [ "### 检索图片\n", "\n", "本教程使用一个包含数千张花卉照片的数据集。该花卉数据集包含 5 个子目录,每个子目录对应一个类:\n", "\n", "```\n", "flowers_photos/\n", " daisy/\n", " dandelion/\n", " roses/\n", " sunflowers/\n", " tulips/\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "Ju2yXtdV5YaT" }, "source": [ "注:所有图像均获得 CC-BY 许可,创作者在 LICENSE.txt 文件中列出。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rN-Pc6Zd6awg", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import pathlib\n", "dataset_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n", "archive = tf.keras.utils.get_file(origin=dataset_url, extract=True)\n", "data_dir = pathlib.Path(archive).with_suffix('')" ] }, { "cell_type": "markdown", "metadata": { "id": "rFkFK74oO--g" }, "source": [ "下载 (218MB) 后,您现在应该拥有花卉照片的副本。总共有 3670 个图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QhewYCxhXQBX", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "image_count = len(list(data_dir.glob('*/*.jpg')))\n", "print(image_count)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUFusk44d9GW" }, "source": [ "每个目录都包含该类型花卉的图像。下面是一些玫瑰:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "crs7ZjEp60Ot", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "roses = list(data_dir.glob('roses/*'))\n", "PIL.Image.open(str(roses[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oV9PtjdKKWyI", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "roses = list(data_dir.glob('roses/*'))\n", "PIL.Image.open(str(roses[1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "9_kge08gSCan" }, "source": [ "## 使用 Keras 效用函数加载数据\n", "\n", "让我们使用实用的 `tf.keras.utils.image_dataset_from_directory` 效用函数从磁盘加载这些图像。" ] }, { "cell_type": "markdown", "metadata": { "id": "6jobDTUs8Wxu" }, "source": [ "### 创建数据集" ] }, { "cell_type": "markdown", "metadata": { "id": "lAmtzsnjDNhB" }, "source": [ "为加载器定义一些参数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qJdpyqK541ty", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "batch_size = 32\n", "img_height = 180\n", "img_width = 180" ] }, { "cell_type": "markdown", "metadata": { "id": "ehhW308g8soJ" }, "source": [ "开发模型时,最好使用验证拆分。您将使用 80% 的图像进行训练,20% 的图像进行验证。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "chqakIP14PDm", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=0.2,\n", " subset=\"training\",\n", " seed=123,\n", " image_size=(img_height, img_width),\n", " batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pb2Af2lsUShk", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "val_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=0.2,\n", " subset=\"validation\",\n", " seed=123,\n", " image_size=(img_height, img_width),\n", " batch_size=batch_size)" ] }, { "cell_type": "markdown", "metadata": { "id": "Ug3ITsz0b_cF" }, "source": [ "您可以在这些数据集的 `class_names` 特性中找到类名称。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R7z2yKt7VDPJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class_names = train_ds.class_names\n", "print(class_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "bK6CQCqIctCd" }, "source": [ "### 呈现数据\n", "\n", "下面是训练数据集中的前 9 个图像。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AAY3LJN28Kuy", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.figure(figsize=(10, 10))\n", "for images, labels in train_ds.take(1):\n", " for i in range(9):\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(images[i].numpy().astype(\"uint8\"))\n", " plt.title(class_names[labels[i]])\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "jUI0fr7igPtA" }, "source": [ "您可以使用这些数据集来训练模型,方法是将它们传递给 `model.fit`(在本教程后面展示)。如果愿意,您还可以手动迭代数据集并检索批量图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BdPHeHXt9sjA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for image_batch, labels_batch in train_ds:\n", " print(image_batch.shape)\n", " print(labels_batch.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "2ZgIZeXaDUsF" }, "source": [ "`image_batch` 是形状为 `(32, 180, 180, 3)` 的张量。这是由 32 个形状为 `180x180x3`(最后一个维度是指颜色通道 RGB)的图像组成的批次。`label_batch` 是形状为 `(32,)` 的张量,这些是 32 个图像的对应标签。\n", "\n", "您可以对这些张量中的任何一个调用 `.numpy()` 以将它们转换为 `numpy.ndarray`。" ] }, { "cell_type": "markdown", "metadata": { "id": "Ybl6a2YCg1rV" }, "source": [ "### 标准化数据\n" ] }, { "cell_type": "markdown", "metadata": { "id": "IdogGjM2K6OU" }, "source": [ "RGB 通道值在 `[0, 255]` 范围内。这对于神经网络来说并不理想;一般而言,您应当设法使您的输入值变小。\n", "\n", "在这里,我们通过使用 `tf.keras.layers.Rescaling` 将值标准化为在 `[0, 1]` 范围内。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "16yNdZXdExyM", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "normalization_layer = tf.keras.layers.Rescaling(1./255)" ] }, { "cell_type": "markdown", "metadata": { "id": "Nd0_enkb8uxZ" }, "source": [ "可以通过两种方式使用该层。您可以通过调用 `Dataset.map` 将其应用于数据集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QgOnza-U_z5Y", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))\n", "image_batch, labels_batch = next(iter(normalized_ds))\n", "first_image = image_batch[0]\n", "# Notice the pixel values are now in `[0,1]`.\n", "print(np.min(first_image), np.max(first_image))" ] }, { "cell_type": "markdown", "metadata": { "id": "z39nXayj9ioS" }, "source": [ "或者,您也可以在模型定义中包含该层以简化部署。在这里,您将使用第二种方式。" ] }, { "cell_type": "markdown", "metadata": { "id": "hXLd3wMpDIkp" }, "source": [ "注:如果您想将像素值缩放到 `[-1,1]`,则可以改为编写 `tf.keras.layers.Rescaling(1./127.5, offset=-1)`" ] }, { "cell_type": "markdown", "metadata": { "id": "LeNWVa8qRBGm" }, "source": [ "注:您之前使用 `tf.keras.utils.image_dataset_from_directory` 的 `image_size` 参数调整了图像大小。如果您还希望在模型中包括调整大小的逻辑,可以使用 `tf.keras.layers.Resizing` 层。" ] }, { "cell_type": "markdown", "metadata": { "id": "Ti8avTlLofoJ" }, "source": [ "### 配置数据集以提高性能\n", "\n", "我们确保使用缓冲预获取,以便您可以从磁盘生成数据,而不会导致 I/O 阻塞。下面是加载数据时应当使用的两个重要方法。\n", "\n", "- 在第一个周期期间从磁盘加载图像后,`Dataset.cache()` 会将这些图像保留在内存中。这将确保在训练模型时数据集不会成为瓶颈。如果数据集太大无法装入内存,您也可以使用此方法创建高性能的磁盘缓存。\n", "- `Dataset.prefetch()` 会在训练时将数据预处理和模型执行重叠。\n", "\n", "感兴趣的读者可以在使用 tf.data API 提升性能指南的预提取部分了解更多有关这两种方法的详细信息,以及如何将数据缓存到磁盘。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ea3kbMe-pGDw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n", "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "XqHjIr6cplwY" }, "source": [ "### 训练模型\n", "\n", "为了完整起见,您现在将使用刚刚准备的数据集来训练一个简单模型。\n", "\n", "[序贯](https://tensorflow.google.cn/guide/keras/sequential_model)模型由三个卷积块 (`tf.keras.layers.Conv2D`) 组成,每个卷积块都有一个最大池化层 (`tf.keras.layers.MaxPooling2D`)。有一个全连接层 (`tf.keras.layers.Dense`),上面有 128 个单元,由 ReLU 激活函数 (`'relu'`) 激活。此模型尚未进行任何调整(目标是使用您刚刚创建的数据集展示机制)。要详细了解图像分类,请访问[图像分类](../images/classification.ipynb)教程。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LdR0BzCcqxw0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "num_classes = 5\n", "\n", "model = tf.keras.Sequential([\n", " tf.keras.layers.Rescaling(1./255),\n", " tf.keras.layers.Conv2D(32, 3, activation='relu'),\n", " tf.keras.layers.MaxPooling2D(),\n", " tf.keras.layers.Conv2D(32, 3, activation='relu'),\n", " tf.keras.layers.MaxPooling2D(),\n", " tf.keras.layers.Conv2D(32, 3, activation='relu'),\n", " tf.keras.layers.MaxPooling2D(),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(num_classes)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "d83f5aa7f3fb" }, "source": [ "选择 `tf.keras.optimizers.Adam` 优化器和 `tf.keras.losses.SparseCategoricalCrossentropy` 损失函数。要查看每个训练周期的训练和验证准确率,请将 `metrics` 参数传递给 `Model.compile`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t_BlmsnmsEr4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.compile(\n", " optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "ffwd44ldNMOE" }, "source": [ "注:您将仅训练几个周期,因此本教程的运行速度很快。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "S08ZKKODsnGW", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=3\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "MEtT9YGjSAOK" }, "source": [ "注:您也可以编写自定义训练循环而不是使用 `Model.fit`。要了解详情,请访问[从头编写训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch)教程。" ] }, { "cell_type": "markdown", "metadata": { "id": "BaW4wx5L7hrZ" }, "source": [ "您可能会注意到,与训练准确率相比,验证准确率较低,这表明我们的模型存在过拟合。您可以在此[教程](https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit)中详细了解过拟合以及如何减少过拟合。" ] }, { "cell_type": "markdown", "metadata": { "id": "AxS1cLzM8mEp" }, "source": [ "## 使用 tf.data 进行更精细的控制" ] }, { "cell_type": "markdown", "metadata": { "id": "Ylj9fgkamgWZ" }, "source": [ "利用上面的 Keras 预处理效用函数 `tf.keras.utils.image_dataset_from_directory`,可以方便地从头创建 `tf.data.Dataset`。\n", "\n", "要实现更精细的控制,您可以使用 tf.data 编写自己的输入流水线。本部分展示了如何做到这一点,从我们之前下载的 TGZ 文件中的文件路径开始。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lAkQp5uxoINu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)\n", "list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "coORvEH-NGwc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for f in list_ds.take(5):\n", " print(f.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "6NLQ_VJhWO4z" }, "source": [ "文件的树结构可用于编译 `class_names` 列表。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uRPHzDGhKACK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != \"LICENSE.txt\"]))\n", "print(class_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "CiptrWmAlmAa" }, "source": [ "将数据集拆分为训练集和测试集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GWHNPzXclpVr", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "val_size = int(image_count * 0.2)\n", "train_ds = list_ds.skip(val_size)\n", "val_ds = list_ds.take(val_size)" ] }, { "cell_type": "markdown", "metadata": { "id": "rkB-IR4-pS3U" }, "source": [ "您可以按照如下方式打印每个数据集的长度:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SiKQrb9ppS-7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(tf.data.experimental.cardinality(train_ds).numpy())\n", "print(tf.data.experimental.cardinality(val_ds).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "91CPfUUJ_8SZ" }, "source": [ "编写一个将文件路径转换为 `(img, label)` 对的短函数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "arSQzIey-4D4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def get_label(file_path):\n", " # Convert the path to a list of path components\n", " parts = tf.strings.split(file_path, os.path.sep)\n", " # The second to last is the class-directory\n", " one_hot = parts[-2] == class_names\n", " # Integer encode the label\n", " return tf.argmax(one_hot)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MGlq4IP4Aktb", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def decode_img(img):\n", " # Convert the compressed string to a 3D uint8 tensor\n", " img = tf.io.decode_jpeg(img, channels=3)\n", " # Resize the image to the desired size\n", " return tf.image.resize(img, [img_height, img_width])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-xhBRgvNqRRe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def process_path(file_path):\n", " label = get_label(file_path)\n", " # Load the raw data from the file as a string\n", " img = tf.io.read_file(file_path)\n", " img = decode_img(img)\n", " return img, label" ] }, { "cell_type": "markdown", "metadata": { "id": "S9a5GpsUOBx8" }, "source": [ "使用 `Dataset.map` 创建 `image, label` 对的数据集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3SDhbo8lOBQv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.\n", "train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)\n", "val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kxrl0lGdnpRz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for image, label in train_ds.take(1):\n", " print(\"Image shape: \", image.numpy().shape)\n", " print(\"Label: \", label.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "vYGCgJuR_9Qp" }, "source": [ "### 训练的基本方法" ] }, { "cell_type": "markdown", "metadata": { "id": "wwZavzgsIytz" }, "source": [ "要使用此数据集训练模型,你将会想要数据:\n", "\n", "- 被充分打乱。\n", "- 被分割为 batch。\n", "- 永远重复。\n", "\n", "使用 `tf.data` API 可以轻松添加这些功能。有关详情,请访问[输入流水线性能](../../guide/performance/datasets.ipynb)指南。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uZmZJx8ePw_5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def configure_for_performance(ds):\n", " ds = ds.cache()\n", " ds = ds.shuffle(buffer_size=1000)\n", " ds = ds.batch(batch_size)\n", " ds = ds.prefetch(buffer_size=AUTOTUNE)\n", " return ds\n", "\n", "train_ds = configure_for_performance(train_ds)\n", "val_ds = configure_for_performance(val_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "45P7OvzRWzOB" }, "source": [ "### 呈现数据\n", "\n", "您可以通过与之前创建的数据集类似的方式呈现此数据集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UN_Dnl72YNIj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "image_batch, label_batch = next(iter(train_ds))\n", "\n", "plt.figure(figsize=(10, 10))\n", "for i in range(9):\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(image_batch[i].numpy().astype(\"uint8\"))\n", " label = label_batch[i]\n", " plt.title(class_names[label])\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "fMT8kh_uXPRU" }, "source": [ "### 继续训练模型\n", "\n", "您现在已经手动构建了一个与由上面的 `keras.preprocessing` 创建的数据集类似的 `tf.data.Dataset`。您可以继续用它来训练模型。和之前一样,您将只训练几个周期以确保较短的运行时间。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vm_bi7NKXOzW", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=3\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "EDJXAexrwsx8" }, "source": [ "## 使用 TensorFlow Datasets\n", "\n", "到目前为止,本教程的重点是从磁盘加载数据。此外,您还可以通过在 [TensorFlow Datasets](https://tensorflow.google.cn/datasets/catalog/overview) 上探索易于下载的大型数据集[目录](https://tensorflow.google.cn/datasets)来查找要使用的数据集。\n", "\n", "由于您之前已经从磁盘加载了花卉数据集,接下来看看如何使用 TensorFlow Datasets 导入它。" ] }, { "cell_type": "markdown", "metadata": { "id": "Qyu9wWDf1gfH" }, "source": [ "使用 TensorFlow Datasets 下载花卉[数据集](https://tensorflow.google.cn/datasets/catalog/tf_flowers):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NTQ-53DNwv8o", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "(train_ds, val_ds, test_ds), metadata = tfds.load(\n", " 'tf_flowers',\n", " split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n", " with_info=True,\n", " as_supervised=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "3hxXSgtj1iLV" }, "source": [ "花卉数据集有五个类:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kJvt6qzF1i4L", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "num_classes = metadata.features['label'].num_classes\n", "print(num_classes)" ] }, { "cell_type": "markdown", "metadata": { "id": "6dbvEz_F1lgE" }, "source": [ "从数据集中检索图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1lF3IUAO1ogi", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "get_label_name = metadata.features['label'].int2str\n", "\n", "image, label = next(iter(train_ds))\n", "_ = plt.imshow(image)\n", "_ = plt.title(get_label_name(label))" ] }, { "cell_type": "markdown", "metadata": { "id": "lHOOH_4TwaUb" }, "source": [ "和以前一样,请记得对训练集、验证集和测试集进行批处理、打乱顺序和配置以提高性能。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AMV6GtZiwfGP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = configure_for_performance(train_ds)\n", "val_ds = configure_for_performance(val_ds)\n", "test_ds = configure_for_performance(test_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "gmR7kT8l1w20" }, "source": [ "您可以通过访问[数据增强](../images/data_augmentation.ipynb)教程找到使用花卉数据集和 TensorFlow Datasets 的完整示例。" ] }, { "cell_type": "markdown", "metadata": { "id": "6cqkPenZIaHl" }, "source": [ "## 后续步骤\n", "\n", "本教程展示了从磁盘加载图像的两种方式。首先,您学习了如何使用 Keras 预处理层和效用函数加载和预处理图像数据集。接下来,您学习了如何使用 `tf.data` 从头开始编写输入流水线。最后,您学习了如何从 TensorFlow Datasets 下载数据集。\n", "\n", "后续步骤:\n", "\n", "- 您可以学习[如何添加数据增强](https://tensorflow.google.cn/tutorials/images/data_augmentation)。\n", "- 要详细了解 `tf.data`,您可以访问 [tf.data:构建 TensorFlow 输入流水线](https://tensorflow.google.cn/guide/data)指南。" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "images.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }