{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "mt9dL5dIir8X" }, "outputs": [], "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ufPx7EiCiqgR", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ucMoYase6URl" }, "source": [ "# Load and preprocess images" ] }, { "cell_type": "markdown", "metadata": { "id": "_Wwu5SXZmEkB" }, "source": [ "You can also write your own input pipeline using `tf.data`. This section shows how to do this, starting from the file paths in the TGZ file you downloaded earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lAkQp5uxoINu",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)\n",
"list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "coORvEH-NGwc",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"for f in list_ds.take(5):\n",
" print(f.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6NLQ_VJhWO4z"
},
"source": [
"The directory tree of the files can be used to build the `class_names` list."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uRPHzDGhKACK",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != \"LICENSE.txt\"]))\n",
"print(class_names)"
]
},
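{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the same idea, the snippet below builds a small, hypothetical directory tree in a temporary folder (the folder and file names are made up for illustration) and derives the class names from its subdirectories. It keeps entries with `Path.is_dir()` instead of excluding `LICENSE.txt` by name, which assumes every class is a directory.\n",
"\n",
"```python\n",
"import pathlib\n",
"import tempfile\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    data_dir = pathlib.Path(tmp)\n",
"    # Hypothetical class folders plus a stray license file, as in flower_photos\n",
"    for name in ['tulips', 'daisy', 'roses']:\n",
"        (data_dir / name).mkdir()\n",
"    (data_dir / 'LICENSE.txt').write_text('license')\n",
"    # Keep only subdirectories; sorting gives a stable label order\n",
"    class_names = sorted(item.name for item in data_dir.glob('*') if item.is_dir())\n",
"\n",
"print(class_names)  # ['daisy', 'roses', 'tulips']\n",
"```\n",
"\n",
"Sorting matters here because each label is the index of its class in this list, so the order must be the same every time the list is built."
]
},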
{
"cell_type": "markdown",
"metadata": {
"id": "CiptrWmAlmAa"
},
"source": [
"Split the dataset into training and validation sets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GWHNPzXclpVr",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"val_size = int(image_count * 0.2)\n",
"train_ds = list_ds.skip(val_size)\n",
"val_ds = list_ds.take(val_size)"
]
},
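{
"cell_type": "markdown",
"metadata": {},
"source": [
"This split depends on the `reshuffle_each_iteration=False` setting in the earlier `shuffle` call. `skip` and `take` each iterate over `list_ds` separately, so the order has to be the same on every pass, or the two subsets could overlap. The minimal sketch below checks this, using `tf.data.Dataset.range` as a stand-in for the file list. The size of 10 and the seed are arbitrary assumptions.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"image_count = 10  # hypothetical stand-in for the number of image files\n",
"list_ds = tf.data.Dataset.range(image_count).shuffle(\n",
"    image_count, seed=0, reshuffle_each_iteration=False)\n",
"\n",
"val_size = int(image_count * 0.2)\n",
"train_ids = [int(x) for x in list_ds.skip(val_size)]\n",
"val_ids = [int(x) for x in list_ds.take(val_size)]\n",
"\n",
"# The order is identical on every pass, so the two subsets never overlap\n",
"print(len(train_ids), len(val_ids))  # 8 2\n",
"print(set(train_ids).isdisjoint(val_ids))  # True\n",
"```\n",
"\n",
"With `reshuffle_each_iteration=True`, each pass would draw a new order, and files taken for validation could also appear in the training subset."
]
},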
{
"cell_type": "markdown",
"metadata": {
"id": "rkB-IR4-pS3U"
},
"source": [
"You can print the length of each dataset as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SiKQrb9ppS-7",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"print(tf.data.experimental.cardinality(train_ds).numpy())\n",
"print(tf.data.experimental.cardinality(val_ds).numpy())"
]
},
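{
"cell_type": "markdown",
"metadata": {},
"source": [
"When tf.data can work out the number of elements statically, `cardinality` reports that number. When it cannot, it returns a sentinel value instead. The small sketch below uses the `Dataset.cardinality()` method and the `tf.data.UNKNOWN_CARDINALITY` constant, both available in recent TensorFlow 2 releases (an assumption about your installed version).\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"ds = tf.data.Dataset.range(10)\n",
"print(ds.cardinality().numpy())  # 10\n",
"print(ds.skip(2).cardinality().numpy())  # 8\n",
"\n",
"# After `filter`, the size depends on the data, so tf.data reports UNKNOWN\n",
"evens = ds.filter(lambda x: x % 2 == 0)\n",
"print(evens.cardinality().numpy() == tf.data.UNKNOWN_CARDINALITY)  # True\n",
"```"
]
},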
{
"cell_type": "markdown",
"metadata": {
"id": "91CPfUUJ_8SZ"
},
"source": [
"Write a short function that converts a file path to an `(img, label)` pair:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "arSQzIey-4D4",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def get_label(file_path):\n",
" # Convert the path to a list of path components\n",
" parts = tf.strings.split(file_path, os.path.sep)\n",
" # The second to last is the class-directory\n",
" one_hot = parts[-2] == class_names\n",
" # Integer encode the label\n",
" return tf.argmax(one_hot)"
]
},
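{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `get_label` returns, you can call it on a single path string. This sketch makes two assumptions: a hypothetical `class_names` list, and a `/`-separated path (instead of `os.path.sep`) so it behaves the same on any platform. The boolean mask is cast to integers before `tf.argmax` to make the encoding explicit.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']  # hypothetical\n",
"\n",
"def get_label(file_path):\n",
"    # Split 'flower_photos/roses/123.jpg' into ['flower_photos', 'roses', '123.jpg']\n",
"    parts = tf.strings.split(file_path, '/')\n",
"    # Compare the class directory against every class name, giving a boolean mask\n",
"    one_hot = parts[-2] == class_names\n",
"    # The index of the single True entry is the integer label\n",
"    return tf.argmax(tf.cast(one_hot, tf.int32))\n",
"\n",
"label = get_label(tf.constant('flower_photos/roses/123.jpg'))\n",
"print(label.numpy())  # 2\n",
"```"
]
},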
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MGlq4IP4Aktb",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def decode_img(img):\n",
" # Convert the compressed string to a 3D uint8 tensor\n",
" img = tf.io.decode_jpeg(img, channels=3)\n",
" # Resize the image to the desired size\n",
" return tf.image.resize(img, [img_height, img_width])"
]
},
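{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tf.io.decode_jpeg` produces `uint8` pixels, but `tf.image.resize` (with its default bilinear method) returns `float32` values that stay in the 0 to 255 range. This is why the plotting code later converts images back to `uint8`. The sketch below JPEG-encodes a synthetic all-black image, standing in for a real file read with `tf.io.read_file`, and runs it through the same decode and resize steps. The 180 by 180 target size is a hypothetical choice.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"img_height, img_width = 180, 180  # hypothetical target size\n",
"\n",
"# Stand-in for tf.io.read_file: JPEG-encode a synthetic 64x48 black image\n",
"jpeg_bytes = tf.io.encode_jpeg(tf.zeros([64, 48, 3], dtype=tf.uint8))\n",
"\n",
"img = tf.io.decode_jpeg(jpeg_bytes, channels=3)  # uint8, shape (64, 48, 3)\n",
"resized = tf.image.resize(img, [img_height, img_width])  # float32\n",
"\n",
"print(img.dtype, img.shape)\n",
"print(resized.dtype, resized.shape)\n",
"```"
]
},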
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-xhBRgvNqRRe",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def process_path(file_path):\n",
" label = get_label(file_path)\n",
" # Load the raw data from the file as a string\n",
" img = tf.io.read_file(file_path)\n",
" img = decode_img(img)\n",
" return img, label"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S9a5GpsUOBx8"
},
"source": [
"Use `Dataset.map` to create a dataset of `image, label` pairs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3SDhbo8lOBQv",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.\n",
"train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)\n",
"val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)"
]
},
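{
"cell_type": "markdown",
"metadata": {},
"source": [
"Setting `num_parallel_calls` lets several elements be processed at once. By default, though, the output still keeps the input order, so each decoded image stays aligned with its place in the shuffled file list. The tiny sketch below shows this with a toy squaring function and `tf.data.AUTOTUNE`, the spelling in recent releases (older ones use `tf.data.experimental.AUTOTUNE`).\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"ds = tf.data.Dataset.range(6)\n",
"# Elements may be computed concurrently, but results come out in input order\n",
"squared = ds.map(lambda x: x * x, num_parallel_calls=tf.data.AUTOTUNE)\n",
"print([int(x) for x in squared])  # [0, 1, 4, 9, 16, 25]\n",
"```\n",
"\n",
"If you can trade ordering for throughput, `Dataset.map` in recent versions also accepts `deterministic=False`."
]
},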
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kxrl0lGdnpRz",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"for image, label in train_ds.take(1):\n",
" print(\"Image shape: \", image.numpy().shape)\n",
" print(\"Label: \", label.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vYGCgJuR_9Qp"
},
"source": [
"### Basic methods for training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wwZavzgsIytz"
},
"source": [
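"To train a model with this dataset, you will want the data:\n",
"\n",
"- To be well shuffled.\n",
"- To be batched.\n",
"- To have batches available as soon as possible.\n",
"\n",
"You can add these features with the `tf.data` API. For more details, visit the [Input pipeline performance](../../guide/performance/datasets.ipynb) guide."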
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uZmZJx8ePw_5",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
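"def configure_for_performance(ds):\n",
"  # Keep decoded images in memory after the first epoch\n",
"  ds = ds.cache()\n",
"  # Shuffle within a 1,000-element buffer\n",
"  ds = ds.shuffle(buffer_size=1000)\n",
"  ds = ds.batch(batch_size)\n",
"  # Prepare upcoming batches while the current one is being consumed\n",
"  ds = ds.prefetch(buffer_size=AUTOTUNE)\n",
"  return ds\n",
"\n",
"train_ds = configure_for_performance(train_ds)\n",
"val_ds = configure_for_performance(val_ds)"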
]
},
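{
"cell_type": "markdown",
"metadata": {},
"source": [
"`shuffle(buffer_size=1000)` does not shuffle the whole dataset at once. tf.data fills a buffer with `buffer_size` elements, then repeatedly samples one at random and replaces it with the next input element. The pure-Python sketch below is a simplified model of that behavior, not tf.data's actual implementation. It shows that an element can move at most `buffer_size - 1` positions earlier. This is why the file list was first shuffled with a buffer of `image_count`, which is large enough for a full shuffle.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def buffered_shuffle(items, buffer_size, seed=0):\n",
"    # Simplified model of a shuffle buffer holding `buffer_size` elements\n",
"    rng = random.Random(seed)\n",
"    buffer, output = [], []\n",
"    for item in items:\n",
"        buffer.append(item)\n",
"        if len(buffer) == buffer_size:\n",
"            # Emit a random element from the full buffer, freeing one slot\n",
"            output.append(buffer.pop(rng.randrange(len(buffer))))\n",
"    # Input exhausted: drain the remaining elements in random order\n",
"    rng.shuffle(buffer)\n",
"    return output + buffer\n",
"\n",
"data = list(range(20))\n",
"shuffled = buffered_shuffle(data, buffer_size=4)\n",
"# No element moves more than buffer_size - 1 positions earlier\n",
"print(all(pos >= item - 3 for pos, item in enumerate(shuffled)))  # True\n",
"# A buffer of one element leaves the order unchanged\n",
"print(buffered_shuffle(data, buffer_size=1) == data)  # True\n",
"```"
]
},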
{
"cell_type": "markdown",
"metadata": {
"id": "45P7OvzRWzOB"
},
"source": [
"### Visualize the data\n",
"\n",
"You can visualize this dataset the same way as the one you created earlier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UN_Dnl72YNIj",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"image_batch, label_batch = next(iter(train_ds))\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"for i in range(9):\n",
" ax = plt.subplot(3, 3, i + 1)\n",
" plt.imshow(image_batch[i].numpy().astype(\"uint8\"))\n",
" label = label_batch[i]\n",
" plt.title(class_names[label])\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fMT8kh_uXPRU"
},
"source": [
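"### Continue training the model\n",
"\n",
"You have now manually built a `tf.data.Dataset` similar to the one created by `keras.preprocessing` above. You can continue training the model with it. As before, you will train for just a few epochs to keep the running time short."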
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vm_bi7NKXOzW",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model.fit(\n",
" train_ds,\n",
" validation_data=val_ds,\n",
" epochs=3\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EDJXAexrwsx8"
},
"source": [
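"## Using TensorFlow Datasets\n",
"\n",
"So far, this tutorial has focused on loading data off disk. You can also find a dataset to use by exploring the large [catalog](https://tensorflow.google.cn/datasets) of easy-to-download datasets at [TensorFlow Datasets](https://tensorflow.google.cn/datasets/catalog/overview).\n",
"\n",
"You have already loaded the Flowers dataset off disk. Next, import it with TensorFlow Datasets."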
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qyu9wWDf1gfH"
},
"source": [
"Download the Flowers [dataset](https://tensorflow.google.cn/datasets/catalog/tf_flowers) using TensorFlow Datasets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NTQ-53DNwv8o",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"(train_ds, val_ds, test_ds), metadata = tfds.load(\n",
" 'tf_flowers',\n",
" split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n",
" with_info=True,\n",
" as_supervised=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3hxXSgtj1iLV"
},
"source": [
"The Flowers dataset has five classes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kJvt6qzF1i4L",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"num_classes = metadata.features['label'].num_classes\n",
"print(num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6dbvEz_F1lgE"
},
"source": [
"Retrieve an image from the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1lF3IUAO1ogi",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"get_label_name = metadata.features['label'].int2str\n",
"\n",
"image, label = next(iter(train_ds))\n",
"_ = plt.imshow(image)\n",
"_ = plt.title(get_label_name(label))"
]
},
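{
"cell_type": "markdown",
"metadata": {},
"source": [
"`metadata.features['label']` is a `tfds.features.ClassLabel`, which converts between integer labels and class-name strings. The sketch below builds one directly to show `num_classes`, `int2str`, and `str2int`. Its list of names is hypothetical and is not necessarily the order `tf_flowers` uses.\n",
"\n",
"```python\n",
"import tensorflow_datasets as tfds\n",
"\n",
"# Hypothetical class names; the real ones are in metadata.features['label'].names\n",
"label_feature = tfds.features.ClassLabel(names=['daisy', 'roses', 'tulips'])\n",
"\n",
"print(label_feature.num_classes)  # 3\n",
"print(label_feature.int2str(1))  # roses\n",
"print(label_feature.str2int('tulips'))  # 2\n",
"```"
]
},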
{
"cell_type": "markdown",
"metadata": {
"id": "lHOOH_4TwaUb"
},
"source": [
"As before, remember to batch, shuffle, and configure the training, validation, and test sets for performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AMV6GtZiwfGP",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
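"# `tf_flowers` images vary in size, so resize them before batching\n",
"def resize(image, label):\n",
"  return tf.image.resize(image, [img_height, img_width]), label\n",
"train_ds = configure_for_performance(train_ds.map(resize, num_parallel_calls=AUTOTUNE))\n",
"val_ds = configure_for_performance(val_ds.map(resize, num_parallel_calls=AUTOTUNE))\n",
"test_ds = configure_for_performance(test_ds.map(resize, num_parallel_calls=AUTOTUNE))"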
]
},
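{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `tf_flowers` images returned by `tfds.load` are not resized the way `decode_img` resized the files read from disk, and `Dataset.batch` can only stack elements that share a shape. The minimal sketch below shows the resize step that makes batching possible. It uses two synthetic images of different sizes as stand-ins for real flower photos, and the 32 by 32 target is arbitrary.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"# Two synthetic images with different heights and widths\n",
"images = [np.zeros((50, 60, 3), np.float32), np.zeros((70, 40, 3), np.float32)]\n",
"ds = tf.data.Dataset.from_generator(\n",
"    lambda: iter(images),\n",
"    output_signature=tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32))\n",
"\n",
"# Resize every image to a common shape, then batch\n",
"resized = ds.map(lambda img: tf.image.resize(img, [32, 32])).batch(2)\n",
"batch = next(iter(resized))\n",
"print(batch.shape)  # (2, 32, 32, 3)\n",
"```"
]
},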
{
"cell_type": "markdown",
"metadata": {
"id": "gmR7kT8l1w20"
},
"source": [
"For a complete example that uses the Flowers dataset and TensorFlow Datasets, see the [Data augmentation](../images/data_augmentation.ipynb) tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6cqkPenZIaHl"
},
"source": [
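"## Next steps\n",
"\n",
"This tutorial showed two ways of loading images off disk. First, you learned how to load and preprocess an image dataset using Keras preprocessing layers and utilities. Next, you learned how to write an input pipeline from scratch using `tf.data`. Finally, you learned how to download a dataset from TensorFlow Datasets.\n",
"\n",
"For your next steps:\n",
"\n",
"- You can learn [how to add data augmentation](https://tensorflow.google.cn/tutorials/images/data_augmentation).\n",
"- To learn more about `tf.data`, you can visit the [tf.data: Build TensorFlow input pipelines](https://tensorflow.google.cn/guide/data) guide."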
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "images.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}