{ "cells": [
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "pL--_KGdYoBz" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "uBDvXpYzYnGj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "HQzaEQuJiW_d" }, "source": [ "# TFRecord and tf.Example" ] },
{ "cell_type": "markdown", "metadata": { "id": "3pkUd_9IZCFO" }, "source": [ "The TFRecord format is a simple format for storing a sequence of binary records.\n", "\n", "[Protocol buffers](https://developers.google.com/protocol-buffers/) are a cross-platform, cross-language library for efficient serialization of structured data.\n", "\n", "Protocol messages are defined by `.proto` files, which are often the easiest way to understand a message type.\n", "\n", "The `tf.train.Example` message (or protobuf) is a flexible message type that represents a `{\"string\": value}` mapping. It is designed for use with TensorFlow and is used throughout higher-level APIs such as [TFX](https://tensorflow.google.cn/tfx/)." ] },
{ "cell_type": "markdown", "metadata": { "id": "Ac83J0QxjhFt" }, "source": [ "This notebook demonstrates how to create, parse, and use the `tf.Example` message, and then serialize, write, and read `tf.Example` messages to and from `.tfrecord` files.\n", "\n", "Note: While useful, these structures are optional. There is no need to convert existing code to use TFRecords, unless you are [using tf.data](https://tensorflow.google.cn/guide/data) and reading data is still the bottleneck to training. See [Better performance with the tf.data API](https://tensorflow.google.cn/guide/data_performance) for dataset performance tips.\n", "\n", "Note: In general, you should shard your data across multiple files so that you can parallelize I/O (within a single host or across multiple hosts). The rule of thumb is to have at least 10 times as many files as there will be hosts reading data. At the same time, each file should be large enough (at least 10 MB+, ideally 100 MB+) so that you can benefit from I/O prefetching. For example, say you have `X` GB of data and you plan to train on up to `N` hosts. Ideally, you should shard the data to ~`10*N` files, as long as ~`X/(10*N)` is 10+ MB (and ideally 100+ MB). If it is less than that, you may need to create fewer shards to trade off the benefits of parallelism against the benefits of I/O prefetching." ] },
{ "cell_type": "markdown", "metadata": { "id": "WkRreBf1eDVc" }, "source": [ "## Setup" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Ja7sezsmnXph", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import numpy as np\n", "import IPython.display as display" ] },
{ "cell_type": "markdown", "metadata": { "id": "e5Kq88ccUWQV" }, "source": [ "## `tf.train.Example`" ] },
{ "cell_type": "markdown", "metadata": { "id": "VrdQHgvNijTi" }, "source": [ "### Data types for `tf.Example`" ] },
{ "cell_type": "markdown", "metadata": { "id": "lZw57Qrn4CTE" }, "source": [ "Fundamentally, a `tf.Example` is a `{\"string\": tf.train.Feature}` mapping.\n", "\n", "The `tf.train.Feature` message type can accept one of the following three types (see the [`.proto` file](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto) for reference). Most other generic types can be coerced into one of these:\n", "\n", "1. `tf.train.BytesList` (the following types can be coerced)\n", "\n", "- `string`\n", "- `byte`\n", "\n", "1. `tf.train.FloatList` (the following types can be coerced)\n", "\n", "- `float` (`float32`)\n", "- `double` (`float64`)\n", "\n",
"1. `tf.train.Int64List` (the following types can be coerced)\n", "\n", "- `bool`\n", "- `enum`\n", "- `int32`\n", "- `uint32`\n", "- `int64`\n", "- `uint64`" ] },
{ "cell_type": "markdown", "metadata": { "id": "_e3g9ExathXP" }, "source": [ "In order to convert a standard TensorFlow type to a `tf.Example`-compatible `tf.train.Feature`, you can use the shortcut functions below. Note that each function takes a scalar input value and returns a `tf.train.Feature` containing one of the three `list` types above:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "mbsPOUpVtYxA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# The following functions can be used to convert a value to a type compatible\n", "# with tf.train.Example.\n", "\n", "def _bytes_feature(value):\n", "  \"\"\"Returns a bytes_list from a string / byte.\"\"\"\n", "  if isinstance(value, type(tf.constant(0))):\n", "    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.\n", "  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n", "\n", "def _float_feature(value):\n", "  \"\"\"Returns a float_list from a float / double.\"\"\"\n", "  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))\n", "\n", "def _int64_feature(value):\n", "  \"\"\"Returns an int64_list from a bool / enum / int / uint.\"\"\"\n", "  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))" ] },
{ "cell_type": "markdown", "metadata": { "id": "Wst0v9O8hgzy" }, "source": [ "Note: To stay simple, this example only uses scalar inputs. The simplest way to handle non-scalar features is to use `tf.io.serialize_tensor` to convert tensors to binary-strings. Strings are scalars in TensorFlow. Use `tf.io.parse_tensor` to convert the binary-string back to a tensor." ] },
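{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick, hedged illustration of the note above (this cell is a sketch, not part of the recipe that follows): serialize a non-scalar tensor with `tf.io.serialize_tensor`, wrap the resulting byte-string with the `_bytes_feature` helper defined above, and recover it with `tf.io.parse_tensor`. The 2x2 tensor used here is an arbitrary example value." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Illustrative sketch: store a non-scalar feature as a serialized tensor.\n", "# The 2x2 tensor below is an arbitrary example value.\n", "t = tf.constant([[1.0, 2.0], [3.0, 4.0]])\n", "\n", "serialized_t = tf.io.serialize_tensor(t)     # a scalar binary-string tensor\n", "feature_of_t = _bytes_feature(serialized_t)  # reuses the helper defined above\n", "\n", "# Round trip: recover the original tensor from the stored bytes.\n", "tf.io.parse_tensor(feature_of_t.bytes_list.value[0], out_type=tf.float32)" ] },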
{ "cell_type": "markdown", "metadata": { "id": "vsMbkkC8xxtB" }, "source": [ "Below are some examples of how these shortcut functions work. Note the varying input types and the standardized output types. If the input type for a function does not match one of the coercible types stated above, the function will raise an exception (for example, `_int64_feature(1.0)` will error out because `1.0` is a float and should be used with the `_float_feature` function instead):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "hZzyLGr0u73y", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(_bytes_feature(b'test_string'))\n", "print(_bytes_feature(u'test_bytes'.encode('utf-8')))\n", "\n", "print(_float_feature(np.exp(1)))\n", "\n", "print(_int64_feature(True))\n", "print(_int64_feature(1))" ] },
{ "cell_type": "markdown", "metadata": { "id": "nj1qpfQU5qmi" }, "source": [ "All proto messages can be serialized to a binary-string using the `.SerializeToString` method:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "5afZkORT5pjm", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "feature = _float_feature(np.exp(1))\n", "\n", "feature.SerializeToString()" ] },
{ "cell_type": "markdown", "metadata": { "id": "laKnw9F3hL-W" }, "source": [ "### Creating a `tf.Example` message" ] },
{ "cell_type": "markdown", "metadata": { "id": "b_MEnhxchQPC" }, "source": [ "Suppose you want to create a `tf.Example` message from existing data. In practice, the dataset may come from anywhere, but the procedure of creating the `tf.Example` message from a single observation is the same:\n", "\n", "1. Within each observation, convert each value to a `tf.train.Feature` containing one of the three compatible types, using one of the functions above.\n", "\n", "2. Create a map (dictionary) from the feature name string to the encoded feature value produced in step 1.\n", "\n", "3. Convert the map produced in step 2 into a [`Features` message](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto#L85)." ] },
{ "cell_type": "markdown", "metadata": { "id": "4EgFQ2uHtchc" }, "source": [ "In this notebook, you will create a dataset using NumPy.\n", "\n", "This dataset will have 4 features:\n", "\n", "- a boolean feature, `False` or `True` with equal probability\n", "- an integer feature uniformly randomly chosen from `[0, 5)`\n", "- a string feature generated from a string table by using the integer feature as an index\n", "- a float feature from a standard normal distribution\n", "\n", "Consider a sample consisting of 10,000 independently and identically distributed observations from each of the above distributions:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "CnrguFAy3YQv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# The number of observations in the dataset.\n", "n_observations = int(1e4)\n", "\n", "# Boolean feature, encoded as False or True.\n", "feature0 = np.random.choice([False, True], n_observations)\n", "\n", "# Integer feature, random from 0 to 4.\n", "feature1 = np.random.randint(0, 5, n_observations)\n", "\n", "# String feature.\n", "strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])\n", "feature2 = strings[feature1]\n", "\n", "# Float feature, from a standard normal distribution.\n", "feature3 = np.random.randn(n_observations)" ] },
{ "cell_type": "markdown", "metadata": { "id": "aGrscehJr7Jd" }, "source": [ "Each of these features can be coerced into a `tf.Example`-compatible type using one of `_bytes_feature`, `_float_feature`, or `_int64_feature`. You can then create a `tf.Example` message from these encoded features:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "RTCS49Ij_kUw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def serialize_example(feature0, feature1, feature2, feature3):\n", "  \"\"\"\n", "  Creates a tf.train.Example message ready to be written to a file.\n", "  \"\"\"\n", "  # Create a dictionary mapping the feature name to the tf.train.Example-compatible\n", "  # data type.\n", "  feature = {\n", "      'feature0': _int64_feature(feature0),\n", "      'feature1': _int64_feature(feature1),\n", "      'feature2': _bytes_feature(feature2),\n", "      'feature3': _float_feature(feature3),\n", "  }\n", "\n", "  # Create a Features message using tf.train.Example.\n", "\n", "  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))\n", "  return example_proto.SerializeToString()" ] },
{ "cell_type": "markdown", "metadata": { "id": "XftzX9CN_uGT" }, "source": [ "For example, suppose you have a single observation from the dataset, `[False, 4, bytes('goat'), 0.9876]`. You can create and print the `tf.Example` message for this observation using `serialize_example()`. Each single observation will be written as a `Features` message as per the above. Note that the `tf.Example` [message](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto#L88) is just a wrapper around the `Features` message:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "N8BtSx2RjYcb", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# This is an example observation from the dataset.\n", "\n", "example_observation = []\n", "\n", "serialized_example = serialize_example(False, 4, b'goat', 0.9876)\n", "serialized_example" ] },
{ "cell_type": "markdown", "metadata": { "id": "_pbGATlG6u-4" }, "source": [ "To decode the message, use the `tf.train.Example.FromString` method." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "dGim-mEm6vit", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "example_proto = tf.train.Example.FromString(serialized_example)\n", "example_proto" ] },
{ "cell_type": "markdown", "metadata": { "id": "o6qxofy89obI" }, "source": [ "## TFRecords format details\n", "\n", "A TFRecord file contains a sequence of records. The file can only be read sequentially.\n", "\n", "Each record contains a byte-string for the data payload, plus the data length, and CRC32C (a 32-bit CRC using the Castagnoli polynomial) hashes for integrity checking.\n", "\n", "Each record is stored in the following format:\n", "\n", "```\n", "uint64 length\n", "uint32 masked_crc32_of_length\n", "byte   data[length]\n", "uint32 masked_crc32_of_data\n", "```\n", "\n", "The records are concatenated together to produce the file. CRCs are [described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check), and the mask of a CRC is:\n", "\n", "```\n", "masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul\n", "```\n" ] },
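{ "cell_type": "markdown", "metadata": {}, "source": [ "As an illustration (this cell is a hedged sketch, not part of the format specification above): the framing can be read directly with Python's `struct` module. CRC verification is omitted here, since checking the hashes would require a CRC32C implementation (for example the third-party `crc32c` package). The function can be pointed at any TFRecord file, such as the `test.tfrecord` file written later in this notebook." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Illustrative sketch: read the raw TFRecord framing described above.\n", "# CRC verification is omitted; it would require a CRC32C implementation.\n", "import struct\n", "\n", "def read_raw_tfrecords(path):\n", "  \"\"\"Returns the raw byte-string payloads stored in a TFRecord file.\"\"\"\n", "  records = []\n", "  with open(path, 'rb') as f:\n", "    while True:\n", "      header = f.read(8 + 4)  # uint64 length + uint32 masked CRC of the length\n", "      if len(header) < 12:\n", "        break                 # end of file\n", "      length, = struct.unpack('<Q', header[:8])\n", "      records.append(f.read(length))  # the record payload itself\n", "      f.read(4)               # uint32 masked CRC of the data (skipped here)\n", "  return records\n", "\n", "def mask_crc(crc):\n", "  \"\"\"Applies the masking described above, using unsigned 32-bit arithmetic.\"\"\"\n", "  return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff" ] },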
{ "cell_type": "markdown", "metadata": { "id": "-0iHagLQCJv6" }, "source": [ "Note: There is no requirement to use `tf.train.Example` in TFRecord files. `tf.train.Example` is just a method of serializing dictionaries to byte-strings. Any byte-string that can be decoded in TensorFlow can be stored in a TFRecord file. Examples include: lines of text, JSON (using `tf.io.decode_json_example`), encoded image data, or serialized `tf.Tensors` (using `tf.io.serialize_tensor`/`tf.io.parse_tensor`). See the `tf.io` module for more options." ] },
{ "cell_type": "markdown", "metadata": { "id": "y-Hjmee-fbLH" }, "source": [ "## TFRecord files using `tf.data`" ] },
{ "cell_type": "markdown", "metadata": { "id": "GmehkCCT81Ez" }, "source": [ "The `tf.data` module also provides tools for reading and writing data in TensorFlow." ] },
{ "cell_type": "markdown", "metadata": { "id": "1FISEuz8ubu3" }, "source": [ "### Writing a TFRecord file\n", "\n", "The easiest way to get the data into a dataset is to use the `from_tensor_slices` method.\n", "\n", "Applied to an array, it returns a dataset of scalars:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "mXeaukvwu5_-", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tf.data.Dataset.from_tensor_slices(feature1)" ] },
{ "cell_type": "markdown", "metadata": { "id": "f-q0VKyZvcad" }, "source": [ "Applied to a tuple of arrays, it returns a dataset of tuples:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "H5sWyu1kxnvg", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))\n", "features_dataset" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "m1C-t71Nywze", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Use `take(1)` to only pull one example from the dataset.\n", "for f0,f1,f2,f3 in features_dataset.take(1):\n", "  print(f0)\n", "  print(f1)\n", "  print(f2)\n", "  print(f3)" ] },
{ "cell_type": "markdown", "metadata": { "id": "mhIe63awyZYd" }, "source": [ "Use the `tf.data.Dataset.map` method to apply a function to each element of a `Dataset`.\n", "\n", "The mapped function must operate in TensorFlow graph mode: it must operate on and return `tf.Tensors`. A non-tensor function, like `serialize_example`, can be wrapped with `tf.py_function` to make it compatible.\n", "\n", "Using `tf.py_function` requires specifying the shape and type information that is otherwise unavailable:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "apB5KYrJzjPI", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def tf_serialize_example(f0,f1,f2,f3):\n", "  tf_string = tf.py_function(\n", "    serialize_example,\n", "    (f0, f1, f2, f3),  # Pass these args to the above function.\n", "    tf.string)         # The return type is `tf.string`.\n", "  return tf.reshape(tf_string, ()) # The result is a scalar." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "lHFjW4u4Npz9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tf_serialize_example(f0, f1, f2, f3)" ] },
{ "cell_type": "markdown", "metadata": { "id": "CrFZ9avE3HUF" }, "source": [ "Apply this function to each element in the dataset:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "VDeqYVbW3ww9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "serialized_features_dataset = features_dataset.map(tf_serialize_example)\n", "serialized_features_dataset" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "DlDfuh46bRf6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def generator():\n", "  for features in features_dataset:\n", "    yield serialize_example(*features)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "iv9oXKrcbhvX", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "serialized_features_dataset = tf.data.Dataset.from_generator(\n", "    generator, output_types=tf.string, output_shapes=())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Dqz8C4D5cIj9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "serialized_features_dataset" ] },
{ "cell_type": "markdown", "metadata": { "id": "p6lw5VYpjZZC" }, "source": [ "And write them to a TFRecord file:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "vP1VgTO44UIE", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "filename = 'test.tfrecord'\n", "writer = tf.data.experimental.TFRecordWriter(filename)\n", "writer.write(serialized_features_dataset)" ] },
{ "cell_type": "markdown", "metadata": { "id": "6aV0GQhV8tmp" }, "source": [ "### Reading a TFRecord file" ] },
{ "cell_type": "markdown", "metadata": { "id": "o3J5D4gcSy8N" }, "source": [ "You can also read the TFRecord file using the `tf.data.TFRecordDataset` class.\n", "\n", "More information on consuming TFRecord files using `tf.data` can be found [here](https://tensorflow.google.cn/guide/datasets#consuming_tfrecord_data).\n", "\n", "Using `TFRecordDataset`s can be useful for standardizing input data and optimizing performance." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "6OjX6UZl-bHC", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "filenames = [filename]\n", "raw_dataset = tf.data.TFRecordDataset(filenames)\n", "raw_dataset" ] },
{ "cell_type": "markdown", "metadata": { "id": "6_EQ9i2E_-Fz" }, "source": [ "At this point the dataset contains serialized `tf.train.Example` messages. When iterated over, it returns these as scalar string tensors.\n", "\n", "Use the `.take` method to only show the first 10 records.\n", "\n", "Note: Iterating over a `tf.data.Dataset` only works with eager execution enabled." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "hxVXpLz_AJlm", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for raw_record in raw_dataset.take(10):\n", "  print(repr(raw_record))" ] },
{ "cell_type": "markdown", "metadata": { "id": "W-6oNzM4luFQ" }, "source": [ "These tensors can be parsed using the function below. Note that the `feature_description` is necessary here because datasets use graph execution, and need this description to build their shape and type signature:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "zQjbIR1nleiy", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create a description of the features.\n", "feature_description = {\n", "    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),\n", "    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),\n", "    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),\n", "    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),\n", "}\n", "\n", "def _parse_function(example_proto):\n", "  # Parse the input `tf.train.Example` proto using the dictionary above.\n", "  return tf.io.parse_single_example(example_proto, feature_description)" ] },
{ "cell_type": "markdown", "metadata": { "id": "gWETjUqhEQZf" }, "source": [ "Alternatively, use `tf.io.parse_example` to parse a whole batch at once. Apply this function to each item in the dataset using the `tf.data.Dataset.map` method:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "6Ob7D-zmBm1w", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "parsed_dataset = raw_dataset.map(_parse_function)\n", "parsed_dataset" ] },
{ "cell_type": "markdown", "metadata": { "id": "sNV-XclGnOvn" }, "source": [ "Use eager execution to display the observations in the dataset. There are 10,000 observations in this dataset, but only the first 10 are displayed. The data is displayed as a dictionary of features. Each item is a `tf.Tensor`, and the `numpy` element of this tensor displays the value of the feature:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "x2LT2JCqhoD_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for parsed_record in parsed_dataset.take(10):\n", "  print(repr(parsed_record))" ] },
{ "cell_type": "markdown", "metadata": { "id": "Cig9EodTlDmg" }, "source": [ "Here, the `tf.io.parse_single_example` function unpacks the `tf.Example` fields into standard tensors." ] },
{ "cell_type": "markdown", "metadata": { "id": "jyg1g3gU7DNn" }, "source": [ "## TFRecord files in Python" ] },
{ "cell_type": "markdown", "metadata": { "id": "3FXG3miA7Kf1" }, "source": [ "The `tf.io` module also contains pure-Python functions for reading and writing TFRecord files." ] },
{ "cell_type": "markdown", "metadata": { "id": "CKn5uql2lAaN" }, "source": [ "### Writing a TFRecord file" ] },
{ "cell_type": "markdown", "metadata": { "id": "LNW_FA-GQWXs" }, "source": [ "Next, write the 10,000 observations to the file `test.tfrecord`. Each observation is converted to a `tf.Example` message, then written to file. You can then verify that the file `test.tfrecord` has been created:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "MKPHzoGv7q44", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Write the `tf.train.Example` observations to the file.\n", "with tf.io.TFRecordWriter(filename) as writer:\n", "  for i in range(n_observations):\n", "    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])\n", "    writer.write(example)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "EjdFHHJMpUUo", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!du -sh {filename}" ] },
{ "cell_type": "markdown", "metadata": { "id": "2osVRnYNni-E" }, "source": [ "### Reading a TFRecord file\n", "\n", "These serialized tensors can be easily parsed using `tf.train.Example.ParseFromString`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "U3tnd3LerOtV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "filenames = [filename]\n", "raw_dataset = tf.data.TFRecordDataset(filenames)\n", "raw_dataset" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "nsEAACHcnm3f", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for raw_record in raw_dataset.take(1):\n", "  example = tf.train.Example()\n", "  example.ParseFromString(raw_record.numpy())\n", "  print(example)" ] },
{ "cell_type": "markdown", "metadata": { "id": "yhnZZmhm1miG" }, "source": [ "That returns a `tf.train.Example` proto which is difficult to use as is, but is fundamentally a representation of a:\n", "\n", "```\n", "Dict[str,\n", "     Union[List[float],\n", "           List[int],\n", "           List[str]]]\n", "```\n", "\n", "The following code manually converts the `Example` to a dictionary of NumPy arrays, without using TensorFlow ops. Refer to [the PROTO file](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto) for details." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Ziv9tiNE1l6J", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "result = {}\n", "# example.features.feature is the dictionary\n", "for key, feature in example.features.feature.items():\n", "  # The values are the Feature objects which contain a `kind` which contains:\n", "  # one of three fields: bytes_list, float_list, int64_list\n", "\n", "  kind = feature.WhichOneof('kind')\n", "  result[key] = np.array(getattr(feature, kind).value)\n", "\n", "result" ] },
{ "cell_type": "markdown", "metadata": { "id": "S0tFDrwdoj3q" }, "source": [ "## Walkthrough: Reading and writing image data" ] },
{ "cell_type": "markdown", "metadata": { "id": "rjN2LFxFpcR9" }, "source": [ "This is an end-to-end example of how to read and write image data using TFRecords. Using an image as input data, you will write the data as a TFRecord file, then read the file back and display the image.\n", "\n", "This can be useful if, for example, you want to use several models on the same input dataset. Instead of storing the image data raw, it can be preprocessed into the TFRecord format, and that can be used in all further processing and modelling.\n", "\n", "First, let's download [this image](https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg) of a cat in the snow and [this photo](https://upload.wikimedia.org/wikipedia/commons/f/fe/New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg) of the Williamsburg Bridge, NYC under construction." ] },
{ "cell_type": "markdown", "metadata": { "id": "5Lk2qrKvN0yu" }, "source": [ "### Fetch the images" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "3a0fmwg8lHdF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "cat_in_snow = tf.keras.utils.get_file(\n", "    '320px-Felis_catus-cat_on_snow.jpg',\n", "    'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')\n", "\n", "williamsburg_bridge = tf.keras.utils.get_file(\n", "    '194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',\n", "    'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "7aJJh7vENeE4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "display.display(display.Image(filename=cat_in_snow))\n", "display.display(display.HTML('Image cc-by: Von.grzanka'))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "KkW0uuhcXZqA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "display.display(display.Image(filename=williamsburg_bridge))\n", "display.display(display.HTML('From Wikimedia'))" ] },
{ "cell_type": "markdown", "metadata": { "id": "VSOgJSwoN5TQ" }, "source": [ "### Write the TFRecord file" ] },
{ "cell_type": "markdown", "metadata": { "id": "Azx83ryQEU6T" }, "source": [ "As before, encode the features as types compatible with `tf.Example`. This stores the raw image string feature, as well as the height, width, depth, and an arbitrary `label` feature. The latter is used when you write the file to distinguish between the cat image and the bridge image. Use `0` for the cat image and `1` for the bridge image:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "kC4TS1ZEONHr", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "image_labels = {\n", "    cat_in_snow : 0,\n", "    williamsburg_bridge : 1,\n", "}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "c5njMSYNEhNZ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# This is an example, just using the cat image.\n", "image_string = open(cat_in_snow, 'rb').read()\n", "\n", "label = image_labels[cat_in_snow]\n", "\n", "# Create a dictionary with features that may be relevant.\n", "def image_example(image_string, label):\n", "  image_shape = tf.io.decode_jpeg(image_string).shape\n", "\n", "  feature = {\n", "      'height': _int64_feature(image_shape[0]),\n", "      'width': _int64_feature(image_shape[1]),\n", "      'depth': _int64_feature(image_shape[2]),\n", "      'label': _int64_feature(label),\n", "      'image_raw': _bytes_feature(image_string),\n", "  }\n", "\n",
"  return tf.train.Example(features=tf.train.Features(feature=feature))\n", "\n", "for line in str(image_example(image_string, label)).split('\\n')[:15]:\n", "  print(line)\n", "print('...')" ] },
{ "cell_type": "markdown", "metadata": { "id": "2G_o3O9MN0Qx" }, "source": [ "Notice that all of the features are now stored in the `tf.Example` message. Next, functionalize the code above and write the example messages to a file named `images.tfrecords`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "qcw06lQCOCZU", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Write the raw image files to `images.tfrecords`.\n", "# First, process the two images into `tf.train.Example` messages.\n", "# Then, write to a `.tfrecords` file.\n", "record_file = 'images.tfrecords'\n", "with tf.io.TFRecordWriter(record_file) as writer:\n", "  for filename, label in image_labels.items():\n", "    image_string = open(filename, 'rb').read()\n", "    tf_example = image_example(image_string, label)\n", "    writer.write(tf_example.SerializeToString())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "yJrTe6tHPCfs", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!du -sh {record_file}" ] },
{ "cell_type": "markdown", "metadata": { "id": "jJSsCkZLPH6K" }, "source": [ "### Read the TFRecord file\n", "\n", "You now have the file `images.tfrecords`, and can iterate over the records in it to read back what you wrote. Given that in this example you will only reproduce the image, the only feature you need is the raw image string. Extract it using the getters described above, namely `example.features.feature['image_raw'].bytes_list.value[0]`. You can also use the labels to determine which record is the cat and which one is the bridge:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "M6Cnfd3cTKHN", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')\n", "\n", "# Create a dictionary describing the features.\n", "image_feature_description = {\n", "    'height': tf.io.FixedLenFeature([], tf.int64),\n", "    'width': tf.io.FixedLenFeature([], tf.int64),\n", "    'depth': tf.io.FixedLenFeature([], tf.int64),\n", "    'label': tf.io.FixedLenFeature([], tf.int64),\n", "    'image_raw': tf.io.FixedLenFeature([], tf.string),\n", "}\n", "\n", "def _parse_image_function(example_proto):\n", "  # Parse the input tf.train.Example proto using the dictionary above.\n", "  return tf.io.parse_single_example(example_proto, image_feature_description)\n", "\n", "parsed_image_dataset = raw_image_dataset.map(_parse_image_function)\n", "parsed_image_dataset" ] },
{ "cell_type": "markdown", "metadata": { "id": "0PEEFPk4NEg1" }, "source": [ "Recover the images from the TFRecord file:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "yZf8jOyEIjSF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for image_features in parsed_image_dataset:\n", "  image_raw = image_features['image_raw'].numpy()\n", "  display.display(display.Image(data=image_raw))" ] }
], "metadata": { "colab": { "collapsed_sections": [ "pL--_KGdYoBz" ], "name": "tfrecord.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }