{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "TBFXQGKYUc4X" }, "outputs": [], "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "1z4xy2gTUc4a", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "FE7KNzPPVrVV" }, "source": [ "# 图像分类" ] }, { "cell_type": "markdown", "metadata": { "id": "KwQtSOz0VrVX" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
{ "cell_type": "markdown", "metadata": { "id": "KwQtSOz0VrVX" }, "source": [ "View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook" ] },
" ] }, { "cell_type": "markdown", "metadata": { "id": "gN7G9GFmVrVY" }, "source": [ "本教程展示了如何使用 `tf.keras.Sequential` 模型对花卉图像进行分类,并使用 `tf.keras.utils.image_dataset_from_directory` 加载数据。其中演示了以下概念:\n", "\n", "- 从磁盘高效加载数据集。\n", "- 识别过拟合,并应用数据增强和随机失活等技术缓解过拟合。\n", "\n", "本教程遵循基本的机器学习工作流:\n", "\n", "1. 检查并理解数据\n", "2. 构建输入流水线\n", "3. 构建模型\n", "4. 训练模型\n", "5. 测试模型\n", "6. 改进模型并重复整个过程\n", "\n", "此外,该笔记本还演示了如何将[保存的模型](../../../guide/saved_model.ipynb)转换为 [TensorFlow Lite](https://tensorflow.google.cn/lite/) 模型,以便在移动设备、嵌入式设备和 IoT 设备上进行设备端机器学习。" ] }, { "cell_type": "markdown", "metadata": { "id": "zF9uvbXNVrVY" }, "source": [ "## 设置\n", "\n", "导入 TensorFlow 和其他必要的库:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L1WtoaOHVrVh", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import PIL\n", "import tensorflow as tf\n", "\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "from tensorflow.keras.models import Sequential" ] }, { "cell_type": "markdown", "metadata": { "id": "UZZI6lNkVrVm" }, "source": [ "## 下载并探索数据集" ] }, { "cell_type": "markdown", "metadata": { "id": "DPHx8-t-VrVo" }, "source": [ "本教程使用一个包含约 3,700 张花卉照片的数据集。该数据集包含 5 个子目录,每个子目录对应一个类:\n", "\n", "```\n", "flower_photo/\n", " daisy/\n", " dandelion/\n", " roses/\n", " sunflowers/\n", " tulips/\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "57CcilYSG0zv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import pathlib\n", "\n", "dataset_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n", "data_dir = tf.keras.utils.get_file('flower_photos.tar', origin=dataset_url, extract=True)\n", "data_dir = pathlib.Path(data_dir).with_suffix('')" ] }, { "cell_type": "markdown", "metadata": { "id": "VpmywIlsVrVx" }, "source": [ "下载后,您现在应该拥有一个数据集的副本。总共有 3,670 个图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SbtTDYhOHZb6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "image_count = len(list(data_dir.glob('*/*.jpg')))\n", "print(image_count)" ] }, { "cell_type": "markdown", "metadata": { "id": "PVmwkOSdHZ5A" }, "source": [ "下面是一些玫瑰:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N1loMlbYHeiJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "roses = list(data_dir.glob('roses/*'))\n", "PIL.Image.open(str(roses[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RQbZBOTLHiUP", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "PIL.Image.open(str(roses[1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "DGEqiBbRHnyI" }, "source": [ "和一些郁金香:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HyQkfPGdHilw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tulips = list(data_dir.glob('tulips/*'))\n", "PIL.Image.open(str(tulips[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wtlhWJPAHivf", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "PIL.Image.open(str(tulips[1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "gIjgz7_JIo_m" }, "source": [ "## 使用 Keras 效用函数加载数据\n", "\n", "接下来,使用有用的 `tf.keras.utils.image_dataset_from_directory` 实用工具从磁盘上加载这些图像。只需几行代码就能将磁盘上的图像目录转移到 `tf.data.Dataset`。如果愿意,您也可以访问[加载和预处理图像](../load_data/images.ipynb)教程,从头开始编写您自己的数据加载代码。" ] }, { "cell_type": 
"markdown", "metadata": { "id": "xyDNn9MbIzfT" }, "source": [ "### 创建数据集" ] }, { "cell_type": "markdown", "metadata": { "id": "anqiK_AGI086" }, "source": [ "为加载程序定义一些参数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "H74l2DoDI2XD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "batch_size = 32\n", "img_height = 180\n", "img_width = 180" ] }, { "cell_type": "markdown", "metadata": { "id": "pFBhRrrEI49z" }, "source": [ "开发模型时,使用验证拆分是一种很好的做法。将 80% 的图像用于训练,将 20% 的图像用于验证。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fIR0kRZiI_AT", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=0.2,\n", " subset=\"training\",\n", " seed=123,\n", " image_size=(img_height, img_width),\n", " batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iscU3UoVJBXj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "val_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=0.2,\n", " subset=\"validation\",\n", " seed=123,\n", " image_size=(img_height, img_width),\n", " batch_size=batch_size)" ] }, { "cell_type": "markdown", "metadata": { "id": "WLQULyAvJC3X" }, "source": [ "您可以在这些数据集的 `class_names` 特性中找到类名称。这些名称按照字母顺序与目录名称相对应。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZHAxkHX5JD3k", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class_names = train_ds.class_names\n", "print(class_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "_uoVvxSLJW9m" }, "source": [ "## 呈现数据\n", "\n", "下面是训练数据集中的前 9 个图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wBmEA9c0JYes", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.figure(figsize=(10, 10))\n", "for images, labels in train_ds.take(1):\n", " for i in range(9):\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(images[i].numpy().astype(\"uint8\"))\n", " plt.title(class_names[labels[i]])\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "5M6BXtXFJdW0" }, "source": [ "您将把这些数据集传递给 Keras `Model.fit` 方法,以便在本教程的后面部分进行训练。如果愿意,您还可以手动迭代数据集并检索批量图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2-MfMoenJi8s", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for image_batch, labels_batch in train_ds:\n", " print(image_batch.shape)\n", " print(labels_batch.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "Wj4FrKxxJkoW" }, "source": [ "`image_batch` 是形状为 `(32, 180, 180, 3)` 的张量。这是由 32 个形状为 `180x180x3`(最后一个维度是指颜色通道 RGB)的图像组成的批次。`label_batch` 是形状为 `(32,)` 的张量,这些是 32 个图像的对应标签。\n", "\n", "您可以在 `image_batch` 和 `labels_batch` 张量上调用 `.numpy()`,将其转换为 `numpy.ndarray`。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4Dr0at41KcAU" }, "source": [ "## 配置数据集以提高性能\n", "\n", "请确保使用缓冲预提取,以便从磁盘产生数据,而不会阻塞 I/O。这是您在加载数据时应该使用的两种重要方法。\n", "\n", "- 在第一个周期期间从磁盘加载图像后,`Dataset.cache()` 会将这些图像保留在内存中。这将确保在训练模型时数据集不会成为瓶颈。如果数据集太大无法装入内存,您也可以使用此方法创建高性能的磁盘缓存。\n", "- `Dataset.prefetch()` 会在训练时将数据预处理和模型执行重叠。\n", "\n", "感兴趣的读者可以在[使用 tf.data API 获得更佳性能](../../guide/data_performance.ipynb)指南的*预提取*部分了解更多有关这两种方法的详细信息,以及如何将数据缓存到磁盘。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nOjJSm7DKoZA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ 
"AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)\n", "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "8GUnmPF4JvEf" }, "source": [ "## 标准化数据" ] }, { "cell_type": "markdown", "metadata": { "id": "e56VXHMWJxYT" }, "source": [ "RGB 通道值在 `[0, 255]` 范围内。这对于神经网络来说并不理想;一般而言,您应当设法使您的输入值变小。\n", "\n", "在这里,我们通过使用 `tf.keras.layers.Rescaling` 将值标准化为在 `[0, 1]` 范围内。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PEYxo2CTJvY9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "normalization_layer = layers.Rescaling(1./255)" ] }, { "cell_type": "markdown", "metadata": { "id": "Bl4RmanbJ4g0" }, "source": [ "可以通过两种方式使用该层。您可以通过调用 `Dataset.map` 将其应用于数据集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X9o9ESaJJ502", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))\n", "image_batch, labels_batch = next(iter(normalized_ds))\n", "first_image = image_batch[0]\n", "# Notice the pixel values are now in `[0,1]`.\n", "print(np.min(first_image), np.max(first_image))" ] }, { "cell_type": "markdown", "metadata": { "id": "XWEOmRSBJ9J8" }, "source": [ "或者,您可以在模型定义中包括该层,从而简化部署。在这里,请使用第二种方式。" ] }, { "cell_type": "markdown", "metadata": { "id": "XsRk1xCwKZR4" }, "source": [ "注:您之前使用 `tf.keras.utils.image_dataset_from_directory` 的 `image_size` 参数调整了图像大小。如果您还希望在模型中包括调整大小的逻辑,可以使用 `tf.keras.layers.Resizing` 层。" ] }, { "cell_type": "markdown", "metadata": { "id": "WcUTyDOPKucd" }, "source": [ "## 基本 Keras 模型\n", "\n", "### 创建模型\n", "\n", "Keras [序贯](https://tensorflow.google.cn/guide/keras/sequential_model)模型由三个卷积块 (`tf.keras.layers.Conv2D`) 组成,每个卷积块都有一个最大池化层 (`tf.keras.layers.MaxPooling2D`)。有一个全连接层 (`tf.keras.layers.Dense`),上方有 128 个单元,由 ReLU 激活函数 (`'relu'`) 激活。此模型尚未针对高准确率进行调整;本教程的目标是展示一种标准方式。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QR6argA1K074", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "num_classes = len(class_names)\n", "\n", "model = Sequential([\n", " layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),\n", " layers.Conv2D(16, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(32, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(64, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Flatten(),\n", " layers.Dense(128, activation='relu'),\n", " layers.Dense(num_classes)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "EaKFzz72Lqpg" }, "source": [ "### 编译模型\n", "\n", "对于本教程,选择 `tf.keras.optimizers.Adam` 优化器和 `tf.keras.losses.SparseCategoricalCrossentropy` 损失函数。要查看每个训练周期的训练和验证准确率,请将 `metrics` 参数传递给 `Model.compile`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jloGNS1MLx3A", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "aMJ4DnuJL55A" }, "source": [ "### 模型摘要\n", "\n", "使用 Keras `Model.summary` 方法查看网络的所有层:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "llLYH-BXL7Xe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": 
{ "id": "NiYHcbvaL9H-" }, "source": [ "### 训练模型" ] }, { "cell_type": "markdown", "metadata": { "id": "j30F69T4sIVN" }, "source": [ "使用 Keras `Model.fit` 方法将模型训练 10 个 周期:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5fWToCqYMErH", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "epochs=10\n", "history = model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=epochs\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SyFKdQpXMJT4" }, "source": [ "## 呈现训练结果" ] }, { "cell_type": "markdown", "metadata": { "id": "dFvOvmAmMK9w" }, "source": [ "在训练集和验证集上创建损失和准确率的图表:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jWnopEChMMCn", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "acc = history.history['accuracy']\n", "val_acc = history.history['val_accuracy']\n", "\n", "loss = history.history['loss']\n", "val_loss = history.history['val_loss']\n", "\n", "epochs_range = range(epochs)\n", "\n", "plt.figure(figsize=(8, 8))\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs_range, acc, label='Training Accuracy')\n", "plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n", "plt.legend(loc='lower right')\n", "plt.title('Training and Validation Accuracy')\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot(epochs_range, loss, label='Training Loss')\n", "plt.plot(epochs_range, val_loss, label='Validation Loss')\n", "plt.legend(loc='upper right')\n", "plt.title('Training and Validation Loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "hO_jT7HwMrEn" }, "source": [ "图表显示,训练准确率和验证准确率相差很大,并且模型在验证集上仅达到了 60% 左右的准确率。\n", "\n", "以下教程部分展示了如何检查出了什么问题并尝试提高模型的整体性能。" ] }, { "cell_type": "markdown", "metadata": { "id": "hqtyGodAMvNV" }, "source": [ "## 过拟合" ] }, { "cell_type": "markdown", "metadata": { "id": "ixsz9XFfMxcu" }, "source": [ "在上面的图表中,训练准确率随时间呈线性提升,而验证准确率在训练过程中停滞在 60% 左右。同时,训练准确率和验证准确率之间的差异也很明显,这是[过拟合](https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit)的标志。\n", "\n", "当训练样本数量较少时,模型有时会从训练样本中的噪声或不需要的细节中学习,以至于对模型在新样本上的性能产生负面影响。这种现象被称为过拟合。这意味着模型将很难在新数据集上泛化。\n", "\n", "在训练过程中有多种方式解决过拟合问题。在本教程中,您将使用*数据增强*并将*随机失活*添加到模型中。" ] }, { "cell_type": "markdown", "metadata": { "id": "BDMfYqwmM1C-" }, "source": [ "## 数据增强" ] }, { "cell_type": "markdown", "metadata": { "id": "GxYwix81M2YO" }, "source": [ "过拟合通常会在训练样本数量较少的情况下发生。[数据增强](./data_augmentation.ipynb)采用的方法是:通过增强然后使用随机转换,从现有样本中生成其他训练数据,产生看起来可信的图像。这有助于向模型公开数据的更多方面,且有助于更好地进行泛化。\n", "\n", "您将使用以下 Keras 预处理层实现数据增强:`tf.keras.layers.RandomFlip`、 `tf.keras.layers.RandomRotation` 和 `tf.keras.layers.RandomZoom`。这些层可以像其他层一样包含在您的模型中,并在 GPU 上运行。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9J80BAbIMs21", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "data_augmentation = keras.Sequential(\n", " [\n", " layers.RandomFlip(\"horizontal\",\n", " input_shape=(img_height,\n", " img_width,\n", " 3)),\n", " layers.RandomRotation(0.1),\n", " layers.RandomZoom(0.1),\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "PN4k1dK3S6eV" }, "source": [ "通过对同一图像多次应用数据增强来呈现一些增强示例:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7Z90k539S838", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for images, _ in train_ds.take(1):\n", " for i in range(9):\n", " augmented_images = data_augmentation(images)\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(augmented_images[0].numpy().astype(\"uint8\"))\n", " 
plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tsjXCBLYYNs5" }, "source": [ "在下一步训练之前,您将在模型中添加数据增强。" ] }, { "cell_type": "markdown", "metadata": { "id": "ZeD3bXepYKXs" }, "source": [ "## 随机失活\n", "\n", "另一种减少过拟合的技术是向网络中引入[随机失活](https://developers.google.com/machine-learning/glossary#dropout_regularization){:.external}正则化。\n", "\n", "将随机失活应用于层时,它会在训练过程中随机从该层丢弃(通过将激活设置为零)一些输出单元。随机失活会接受小数作为输入值,形式如 0.1、0.2、0.4 等。这意味着从应用了随机失活的层中随机丢弃 10%、20% 或 40% 的输出单元。\n", "\n", "在使用增强图像对其进行训练之前,我们来使用 `tf.keras.layers.Dropout` 创建一个新的神经网络:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2Zeg8zsqXCsm", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = Sequential([\n", " data_augmentation,\n", " layers.Rescaling(1./255),\n", " layers.Conv2D(16, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(32, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(64, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Dropout(0.2),\n", " layers.Flatten(),\n", " layers.Dense(128, activation='relu'),\n", " layers.Dense(num_classes, name=\"outputs\")\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "L4nEcuqgZLbi" }, "source": [ "## 编译并训练模型" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EvyAINs9ZOmJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wWLkKoKjZSoC", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LWS-vvNaZDag", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "epochs = 15\n", "history = model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=epochs\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Lkdl8VsBbZOu" }, "source": [ "## 呈现训练结果\n", "\n", "应用数据增强和 `tf.keras.layers.Dropout` 后,过拟合的情况比以前少了,训练准确率和验证准确率也变得更为接近:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dduoLfKsZVIA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "acc = history.history['accuracy']\n", "val_acc = history.history['val_accuracy']\n", "\n", "loss = history.history['loss']\n", "val_loss = history.history['val_loss']\n", "\n", "epochs_range = range(epochs)\n", "\n", "plt.figure(figsize=(8, 8))\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs_range, acc, label='Training Accuracy')\n", "plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n", "plt.legend(loc='lower right')\n", "plt.title('Training and Validation Accuracy')\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot(epochs_range, loss, label='Training Loss')\n", "plt.plot(epochs_range, val_loss, label='Validation Loss')\n", "plt.legend(loc='upper right')\n", "plt.title('Training and Validation Loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "dtv5VbaVb-3W" }, "source": [ "## 根据新数据进行预测" ] }, { "cell_type": "markdown", "metadata": { "id": "10buWpJbcCQz" }, "source": [ "使用您的模型对一个未包含在训练集或验证集中的图像进行分类。" ] }, { "cell_type": "markdown", "metadata": { "id": "NKgMZ4bDcHf7" }, "source": [ "注:数据增强层和随机失活层在推断时处于非活动状态。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dC40sRITBSsQ", "vscode": { "languageId": 
"python" } }, "outputs": [], "source": [ "sunflower_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg\"\n", "sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)\n", "\n", "img = tf.keras.utils.load_img(\n", " sunflower_path, target_size=(img_height, img_width)\n", ")\n", "img_array = tf.keras.utils.img_to_array(img)\n", "img_array = tf.expand_dims(img_array, 0) # Create a batch\n", "\n", "predictions = model.predict(img_array)\n", "score = tf.nn.softmax(predictions[0])\n", "\n", "print(\n", " \"This image most likely belongs to {} with a {:.2f} percent confidence.\"\n", " .format(class_names[np.argmax(score)], 100 * np.max(score))\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "aOc3PZ2N2r18" }, "source": [ "## 实用 TensorFlow Lite\n", "\n", "TensorFlow Lite 是一组工具,可帮助开发者在移动设备、嵌入式设备和边缘设备上运行模型,从而实现设备端机器学习。" ] }, { "cell_type": "markdown", "metadata": { "id": "cThu25rh4LPP" }, "source": [ "### 将 Keras 序贯模型转换为 TensorFlow Lite 模型\n", "\n", "要将经过训练的模型与设备端应用程序一起使用,请首先[将其转换](https://tensorflow.google.cn/lite/models/convert)为更小、更高效的模型格式,称为 [TensorFlow Lite](https://tensorflow.google.cn/lite/) 模型。\n", "\n", "在此示例中,采用经过训练的 Keras 序贯模型并使用 `tf.lite.TFLiteConverter.from_keras_model` 生成 [TensorFlow Lite](https://tensorflow.google.cn/lite/) 模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mXo6ftuL2ufx", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Convert the model.\n", "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "tflite_model = converter.convert()\n", "\n", "# Save the model.\n", "with open('model.tflite', 'wb') as f:\n", " f.write(tflite_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "4R26OU4gGKhh" }, "source": [ "您在上一步中保存的 TensorFlow Lite 模型可以包含多个函数签名。 Keras 模型转换器 API 会自动使用默认签名。详细了解 [TensorFlow Lite 签名](https://tensorflow.google.cn/lite/guide/signatures)。" ] }, { "cell_type": "markdown", "metadata": { "id": "7fjQfXaV2l-5" }, "source": [ "### 运行 TensorFlow Lite 模型\n", "\n", "您可以通过 `tf.lite.Interpreter` 类在 Python 中访问 TensorFlow Lite 保存的模型签名。\n", "\n", "使用 `Interpreter` 加载模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cHYcip_FOaHq", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "TF_MODEL_FILE_PATH = 'model.tflite' # The default path to the saved TensorFlow Lite model\n", "\n", "interpreter = tf.lite.Interpreter(model_path=TF_MODEL_FILE_PATH)" ] }, { "cell_type": "markdown", "metadata": { "id": "nPUXY6BdHDHo" }, "source": [ "打印转换后的模型中的签名以获得输入(和输出)的名称:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZdDl00E2OaHq", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "interpreter.get_signature_list()" ] }, { "cell_type": "markdown", "metadata": { "id": "4eVFqT0je3YG" }, "source": [ "在此示例中,您有一个名为 `serving_default` 的默认签名。此外,`'inputs'` 的名称是 `'sequential_1_input'`,而 `'outputs'` 的名称为 `'outputs'`。如本教程前面所述,您可以在运行 `Model.summary` 时查找这些第一个和最后一个 Keras 层名称。\n", "\n", "现在,您可以使用 `tf.lite.Interpreter.get_signature_runner` 通过传递签名名称对示例图像执行推断来测试加载的 TensorFlow 模型,如下所示:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yFoT_7W_OaHq", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "classify_lite = interpreter.get_signature_runner('serving_default')\n", "classify_lite" ] }, { "cell_type": "markdown", "metadata": { "id": "b1mfRcBOnEx0" }, "source": [ "与您在本教程前面所做的类似,您可以使用 TensorFlow Lite 模型对未包含在训练集或验证集中的图像进行分类。\n", 
"\n", "您已经对该图像进行了张量化并将其保存为 `img_array`。现在,将其传递给已加载的 TensorFlow Lite 模型 (`predictions_lite`) 的第一个参数(`'inputs'` 的名称),计算 Softmax 激活,然后打印具有最高计算概率的类的预测。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sEqR27YcnFvc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "predictions_lite = classify_lite(sequential_1_input=img_array)['outputs']\n", "score_lite = tf.nn.softmax(predictions_lite)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZKP_GFeKUWb5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(\n", " \"This image most likely belongs to {} with a {:.2f} percent confidence.\"\n", " .format(class_names[np.argmax(score_lite)], 100 * np.max(score_lite))\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Poz_iYgeUg_U" }, "source": [ "Lite 模型生成的预测应该与原始模型生成的预测几乎相同:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "InXXDJL8UYC1", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(np.max(np.abs(predictions - predictions_lite)))" ] }, { "cell_type": "markdown", "metadata": { "id": "5hJzY8XijM7N" }, "source": [ "在 `'daisy'`、`'dandelion'`、`'roses'`、`'sunflowers'` 和 `'tulips'` 这五个类中,模型应该预测图像属于向日葵,这与 TensorFlow Lite 转换之前的结果相同。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1RlfCY9v2_ir" }, "source": [ "## 后续步骤\n", "\n", "本教程展示了如何训练用于图像分类的模型,对其进行测试,将其转换为 TensorFlow Lite 格式以用于设备端应用(例如图像分类应用),以及使用 Python API 通过 TensorFlow Lite 模型执行推断。\n", "\n", "您可以通过[教程](https://tensorflow.google.cn/lite/tutorials)和[指南](https://tensorflow.google.cn/lite/guide)了解有关 TensorFlow Lite 的更多信息。" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "classification.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }