{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "DweYe9FcbMK_" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "AVV2e0XKbJeX" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "sUtoed20cRJJ" }, "source": [ "# 加载 CSV 数据" ] }, { "cell_type": "markdown", "metadata": { "id": "1ap_W4aQcgNT" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "C-3Xbt0FfGfs" }, "source": [ "本教程提供了如何在 TensorFlow 中使用 CSV 数据的示例。\n", "\n", "其中包括两个主要部分:\n", "\n", "1. **从磁盘加载数据**\n", "2. **将数据预处理为适合训练的形式。**\n", "\n", "本教程侧重于加载,并提供了一些关于预处理的快速示例。要了解有关预处理方面的更多信息,请查看[使用预处理层](https://tensorflow.google.cn/guide/keras/preprocessing_layers)指南和[使用 Keras 预处理层对结构化数据进行分类](../structured_data/preprocessing_layers.ipynb)教程。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "fgZ9gjmPfSnK" }, "source": [ "## 安装" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "baYFZMW_bJHh" }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "# Make numpy values easier to read.\n", "np.set_printoptions(precision=3, suppress=True)\n", "\n", "import tensorflow as tf\n", "from tensorflow.keras import layers" ] }, { "cell_type": "markdown", "metadata": { "id": "1ZhJYbJxHNGJ" }, "source": [ "## 内存数据" ] }, { "cell_type": "markdown", "metadata": { "id": "ny5TEgcmHjVx" }, "source": [ "对于任何较小的 CSV 数据集,在其上训练 TensorFlow 模型的最简单方式是将其作为 Pandas Dataframe 或 NumPy 数组加载到内存中。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "LgpBOuU8PGFf" }, "source": [ "一个相对简单的示例是 [Abalone Dataset](https://archive.ics.uci.edu/ml/datasets/abalone)。\n", "\n", "- 数据集很小。\n", "- 所有输入特征都是有限范围的浮点值。\n", "\n", "以下是将数据下载到 [Pandas `DataFrame`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) 的方式:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IZVExo9DKoNz" }, "outputs": [], "source": [ "abalone_train = pd.read_csv(\n", " \"https://storage.googleapis.com/download.tensorflow.org/data/abalone_train.csv\",\n", " names=[\"Length\", \"Diameter\", \"Height\", \"Whole weight\", \"Shucked weight\",\n", " \"Viscera weight\", \"Shell weight\", \"Age\"])\n", "\n", "abalone_train.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "hP22mdyPQ1_t" }, "source": [ "该数据集包含一组[鲍鱼](https://en.wikipedia.org/wiki/Abalone)(一种海螺)的测量值。\n", "\n", "![an abalone shell](https://tensorflow.org/images/abalone_shell.jpg)\n", "\n", "[“鲍鱼壳”](https://www.flickr.com/photos/thenickster/16641048623/)(作者:[Nicki Dugan Pogue](https://www.flickr.com/photos/thenickster/),CC BY-SA 2.0)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vlfGrk_9N-wf" }, "source": [ "此数据集的名义任务是根据其他测量值预测年龄,因此要把特征和标签分开以进行训练:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "udOnDJOxNi7p" }, "outputs": [], "source": [ "abalone_features = abalone_train.copy()\n", "abalone_labels = abalone_features.pop('Age')" ] }, { "cell_type": "markdown", "metadata": { "id": "seK9n71-UBfT" }, "source": [ "对于此数据集,将以相同的方式处理所有特征。将这些特征打包成单个 NumPy 数组:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Dp3N5McbUMwb" }, "outputs": [], "source": [ "abalone_features = np.array(abalone_features)\n", "abalone_features" ] }, { "cell_type": "markdown", "metadata": { "id": "1C1yFOxLOdxh" }, "source": [ "接下来,制作一个回归模型来预测年龄。由于只有一个输入张量,这里使用 `tf.keras.Sequential` 模型就足够了。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d8zzNrZqOmfB" }, "outputs": [], "source": [ "abalone_model = tf.keras.Sequential([\n", " layers.Dense(64, activation='relu'),\n", " layers.Dense(1)\n", "])\n", "\n", "abalone_model.compile(loss = tf.keras.losses.MeanSquaredError(),\n", " optimizer = tf.keras.optimizers.Adam())" ] }, { "cell_type": "markdown", "metadata": { "id": "j6IWeP78O2wE" }, "source": [ "要训练该模型,请将特征和标签传递给 `Model.fit`:" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "id": "uZdpCD92SN3Z" }, "outputs": [], "source": [ "abalone_model.fit(abalone_features, abalone_labels, epochs=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "GapLOj1OOTQH" }, "source": [ "您刚刚看到了使用 CSV 数据训练模型的最基本方式。接下来,您将学习如何应用预处理来归一化数值列。" ] }, { "cell_type": "markdown", "metadata": { "id": "B87Rd1SOUv02" }, "source": [ "## 基本预处理" ] }, { "cell_type": "markdown", "metadata": { "id": "yCrB2Jd-U0Vt" }, "source": [ "对模型的输入进行归一化是一种很好的做法。Keras 预处理层提供了一种便捷方式来将此归一化构建到您的模型。\n", "\n", "`tf.keras.layers.Normalization` 层会预先计算每列的均值和方差,并使用这些值对数据进行归一化。\n", "\n", "首先,创建层:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "H2WQpDU5VRk7" }, "outputs": [], "source": [ "normalize = layers.Normalization()" ] }, { "cell_type": "markdown", "metadata": { "id": "hGgEZE-7Vpt6" }, "source": [ "然后,使用 `Normalization.adapt()` 方法使归一化层适应您的数据。\n", "\n", "注:仅将您的训练数据用于 `PreprocessingLayer.adapt` 方法。不要使用您的验证数据或测试数据。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2WgOPIiOVpLg" }, "outputs": [], "source": [ "normalize.adapt(abalone_features)" ] }, { "cell_type": "markdown", "metadata": { "id": "rE6vh0byV7cE" }, "source": [ "然后,将归一化层用于您的模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "quPcZ9dTWA9A" }, "outputs": [], "source": [ "norm_abalone_model = tf.keras.Sequential([\n", " normalize,\n", " layers.Dense(64, activation='relu'),\n", " layers.Dense(1)\n", "])\n", "\n", "norm_abalone_model.compile(loss = tf.keras.losses.MeanSquaredError(),\n", " optimizer = tf.keras.optimizers.Adam())\n", "\n", "norm_abalone_model.fit(abalone_features, abalone_labels, epochs=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wuqj601Qw0Ml" }, "source": [ "## 混合数据类型\n", "\n", "\"Titanic\" 数据集包含有关泰坦尼克号乘客的信息。该数据集的名义任务是预测幸存者。\n", "\n", "![交通堵塞。](images/csv/traffic.jpg)\n", "\n", "图片[来自 Wikimedia](https://commons.wikimedia.org/wiki/File:Trafficjam.jpg)\n", "\n", "原始数据可以轻松加载为 Pandas `DataFrame`,但不能立即用作 TensorFlow 模型的输入。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GS-dBMpuYMnz" }, "outputs": [], "source": [ "titanic = pd.read_csv(\"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n", "titanic.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D8rCGIK1ZzKx" }, "outputs": [], "source": [ "titanic_features = titanic.copy()\n", "titanic_labels = titanic_features.pop('survived')" ] }, { "cell_type": "markdown", "metadata": { "id": "urHOwpCDYtcI" }, "source": [ "由于数据类型和范围不同,您不能简单地将特征堆叠到 NumPy 数组中并将其传递给 `tf.keras.Sequential` 模型。每列都需要单独处理。\n", "\n", "作为一种选择,您可以(使用您喜欢的任何工具)离线预处理数据,将分类列转换为数值列,然后将处理后的输出传递给 TensorFlow 模型。这种方式的缺点是,如果保存并导出模型,预处理不会随之保存。Keras 预处理层能够避免这个问题,因为它们是模型的一部分。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Bta4Sx0Zau5v" }, "source": [ "在此示例中,您将构建一个使用 [Keras 函数式 API](https://tensorflow.google.cn/guide/keras/functional) 实现预处理逻辑的模型。您也可以通过[子类化](https://tensorflow.google.cn/guide/keras/custom_layers_and_models)来实现。\n", "\n", "函数式 API 会对“符号”张量进行运算。正常的 \"eager\" 张量有一个值。相比之下,这些“符号”张量则没有值。相反,它们会跟踪在它们上面运行的运算,并构建可以稍后运行的计算的表示。以下是一个简单示例:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "730F16_97D-3" }, "outputs": [], "source": [ "# Create a symbolic input\n", "input = tf.keras.Input(shape=(), dtype=tf.float32)\n", "\n", "# Perform a calculation using the input\n", "result = 2*input + 1\n", "\n", "# the result doesn't have a value\n", "result" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtcNXWB18kMJ" }, 
"outputs": [], "source": [ "calc = tf.keras.Model(inputs=input, outputs=result)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fUGQOUqZ8sa-" }, "outputs": [], "source": [ "print(calc(1).numpy())\n", "print(calc(2).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "rNS9lT7f6_U2" }, "source": [ "要构建预处理模型,首先要构建一组符号 `tf.keras.Input` 对象,匹配 CSV 列的名称和数据类型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5WODe_1da3yw" }, "outputs": [], "source": [ "inputs = {}\n", "\n", "for name, column in titanic_features.items():\n", " dtype = column.dtype\n", " if dtype == object:\n", " dtype = tf.string\n", " else:\n", " dtype = tf.float32\n", "\n", " inputs[name] = tf.keras.Input(shape=(1,), name=name, dtype=dtype)\n", "\n", "inputs" ] }, { "cell_type": "markdown", "metadata": { "id": "aaheJFmymq8l" }, "source": [ "预处理逻辑的第一步是将数值输入串联在一起,并通过归一化层运行它们:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wPRC_E6rkp8D" }, "outputs": [], "source": [ "numeric_inputs = {name:input for name,input in inputs.items()\n", " if input.dtype==tf.float32}\n", "\n", "x = layers.Concatenate()(list(numeric_inputs.values()))\n", "norm = layers.Normalization()\n", "norm.adapt(np.array(titanic[numeric_inputs.keys()]))\n", "all_numeric_inputs = norm(x)\n", "\n", "all_numeric_inputs" ] }, { "cell_type": "markdown", "metadata": { "id": "-JoR45Uj712l" }, "source": [ "收集所有符号预处理结果,稍后将它们串联起来:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M7jIJw5XntdN" }, "outputs": [], "source": [ "preprocessed_inputs = [all_numeric_inputs]" ] }, { "cell_type": "markdown", "metadata": { "id": "r0Hryylyosfm" }, "source": [ "对于字符串输入,请使用 `tf.keras.layers.StringLookup` 函数将字符串映射到词汇表中的整数索引。接下来,使用 `tf.keras.layers.CategoryEncoding` 将索引转换为适合模型的 `float32` 数据。\n", "\n", "`tf.keras.layers.CategoryEncoding` 层的默认设置会为每个输入创建一个独热向量。也可以使用 `tf.keras.layers.Embedding`。请参阅[使用预处理层](https://tensorflow.google.cn/guide/keras/preprocessing_layers)指南和[使用 Keras 预处理层对结构化数据进行分类](../structured_data/preprocessing_layers.ipynb)教程,了解有关此主题的更多信息。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "79fi1Cgan2YV" }, "outputs": [], "source": [ "for name, input in inputs.items():\n", " if input.dtype == tf.float32:\n", " continue\n", " \n", " lookup = layers.StringLookup(vocabulary=np.unique(titanic_features[name]))\n", " one_hot = layers.CategoryEncoding(num_tokens=lookup.vocabulary_size())\n", "\n", " x = lookup(input)\n", " x = one_hot(x)\n", " preprocessed_inputs.append(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wnhv0T7itnc7" }, "source": [ "您可以使用 `inputs` 和 `processed_inputs` 的集合将所有预处理的输入串联在一起,并构建处理预处理的模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJRzUTe8ukXc" }, "outputs": [], "source": [ "preprocessed_inputs_cat = layers.Concatenate()(preprocessed_inputs)\n", "\n", "titanic_preprocessing = tf.keras.Model(inputs, preprocessed_inputs_cat)\n", "\n", "tf.keras.utils.plot_model(model = titanic_preprocessing , rankdir=\"LR\", dpi=72, show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "PNHxrNW8vdda" }, "source": [ "此 `model` 仅包含输入预处理。您可以运行它以查看其对您的数据进行了哪些操作。Keras 模型不会自动转换 Pandas DataFrames,因为不清楚是应该将其转换为一个张量还是张量字典。因此,将其转换为张量字典:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5YjdYyMEacwQ" }, "outputs": [], "source": [ "titanic_features_dict = {name: np.array(value) \n", " for name, value in titanic_features.items()}" ] }, { "cell_type": "markdown", "metadata": { 
"id": "0nKJYoPByada" }, "source": [ "切出第一个训练样本并将其传递给此预处理模型,您会看到数字特征和字符串独热全部串联在一起:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SjnmU8PSv8T3" }, "outputs": [], "source": [ "features_dict = {name:values[:1] for name, values in titanic_features_dict.items()}\n", "titanic_preprocessing(features_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "qkBf4LvmzMDp" }, "source": [ "接下来,在此基础上构建模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "coIPtGaCzUV7" }, "outputs": [], "source": [ "def titanic_model(preprocessing_head, inputs):\n", " body = tf.keras.Sequential([\n", " layers.Dense(64, activation='relu'),\n", " layers.Dense(1)\n", " ])\n", "\n", " preprocessed_inputs = preprocessing_head(inputs)\n", " result = body(preprocessed_inputs)\n", " model = tf.keras.Model(inputs, result)\n", "\n", " model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", " optimizer=tf.keras.optimizers.Adam())\n", " return model\n", "\n", "titanic_model = titanic_model(titanic_preprocessing, inputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "LK5uBQQF2KbZ" }, "source": [ "训练模型时,将特征字典作为 `x` 传递,将标签作为 `y` 传递。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D1gVfwJ61ejz" }, "outputs": [], "source": [ "titanic_model.fit(x=titanic_features_dict, y=titanic_labels, epochs=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "LxgJarZk3bfH" }, "source": [ "由于预处理是模型的一部分,您可以保存模型并将其重新加载到其他地方并获得相同的结果:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ay-8ymNA2ZCh" }, "outputs": [], "source": [ "titanic_model.save('test.keras')\n", "reloaded = tf.keras.models.load_model('test.keras')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qm6jMTpD20lK" }, "outputs": [], "source": [ "features_dict = {name:values[:1] for name, values in titanic_features_dict.items()}\n", "\n", "before = titanic_model(features_dict)\n", "after = reloaded(features_dict)\n", "assert (before-after)<1e-3\n", "print(before)\n", "print(after)" ] }, { "cell_type": "markdown", "metadata": { "id": "7VsPlxIRZpXf" }, "source": [ "## 使用 tf.data\n" ] }, { "cell_type": "markdown", "metadata": { "id": "NyVDCwGzR5HW" }, "source": [ "在前一部分中,您在训练模型时依赖了模型的内置数据乱序和批处理。\n", "\n", "如果您需要对输入数据流水线进行更多控制或需要使用不易放入内存的数据:请使用 `tf.data`。\n", "\n", "有关更多示例,请参阅 [`tf.data`:构建 TensorFlow 输入流水线](../../guide/data.ipynb)指南。" ] }, { "cell_type": "markdown", "metadata": { "id": "gP5Y1jM2Sor0" }, "source": [ "### 有关内存数据\n", "\n", "作为将 `tf.data` 应用于 CSV 数据的第一个样本,请考虑使用以下代码手动切分上一个部分中的特征字典。对于每个索引,它会为每个特征获取该索引:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i8wE-MVuVu7_" }, "outputs": [], "source": [ "import itertools\n", "\n", "def slices(features):\n", " for i in itertools.count():\n", " # For each feature take index `i`\n", " example = {name:values[i] for name, values in features.items()}\n", " yield example" ] }, { "cell_type": "markdown", "metadata": { "id": "cQ3RTbS9YEal" }, "source": [ "运行此代码并打印第一个样本:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wwq8XK88WwFk" }, "outputs": [], "source": [ "for example in slices(titanic_features_dict):\n", " for name, value in example.items():\n", " print(f\"{name:19s}: {value}\")\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "vvp8Dct6YOIE" }, "source": [ "内存数据加载程序中最基本的 `tf.data.Dataset` 是 `Dataset.from_tensor_slices` 构造函数。这会返回一个 `tf.data.Dataset`,它将在 TensorFlow 中实现上述 `slices` 函数的泛化版本。" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "2gEJthslYxeV" }, "outputs": [], "source": [ "features_ds = tf.data.Dataset.from_tensor_slices(titanic_features_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "-ZC0rTpMZMZK" }, "source": [ "您可以像任何其他 Python 可迭代对象一样迭代 `tf.data.Dataset`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gOHbiefaY4ag" }, "outputs": [], "source": [ "for example in features_ds:\n", " for name, value in example.items():\n", " print(f\"{name:19s}: {value}\")\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "uwcFoVJWZY5F" }, "source": [ "`from_tensor_slices` 函数可以处理嵌套字典或元组的任何结构。以下代码创建了一个 `(features_dict, labels)` 对的数据集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xIHGBy76Zcrx" }, "outputs": [], "source": [ "titanic_ds = tf.data.Dataset.from_tensor_slices((titanic_features_dict, titanic_labels))" ] }, { "cell_type": "markdown", "metadata": { "id": "gQwxitt8c2GK" }, "source": [ "要使用此 `Dataset` 训练模型,您至少需要对数据进行 `shuffle` 和 `batch`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SbJcbldhddeC" }, "outputs": [], "source": [ "titanic_batches = titanic_ds.shuffle(len(titanic_labels)).batch(32)" ] }, { "cell_type": "markdown", "metadata": { "id": "-4FRqhRFuoJx" }, "source": [ "不是将 `features` 和 `labels` 传递给 `Model.fit`,而是传递数据集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8yXkNPumdBtB" }, "outputs": [], "source": [ "titanic_model.fit(titanic_batches, epochs=5)" ] }, { "cell_type": "markdown", "metadata": { "id": "qXuibiv9exT7" }, "source": [ "### 从单个文件\n", "\n", "到目前为止,本教程已经使用了内存数据。`tf.data` 是用于构建数据流水线的高度可扩展的工具包,并提供了一些用于处理加载 CSV 文件的函数。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ncf5t6tgL5ZI" }, "outputs": [], "source": [ "titanic_file_path = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")" ] }, { "cell_type": "markdown", "metadata": { "id": "t4N-plO4tDXd" }, "source": [ "现在,从文件中读取 CSV 数据并创建一个 `tf.data.Dataset`。\n", "\n", "(有关完整文档,请参阅 `tf.data.experimental.make_csv_dataset`)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yIbUscB9sqha" }, "outputs": [], "source": [ "titanic_csv_ds = tf.data.experimental.make_csv_dataset(\n", " titanic_file_path,\n", " batch_size=5, # Artificially small to make examples easier to show.\n", " label_name='survived',\n", " num_epochs=1,\n", " ignore_errors=True,)" ] }, { "cell_type": "markdown", "metadata": { "id": "Sf3v3BKgy4AG" }, "source": [ "此函数包括许多方便的功能,因此很容易处理数据。这包括:\n", "\n", "- 使用列标题作为字典键。\n", "- 自动确定每列的类型。\n", "\n", "小心:请确保在 `tf.data.experimental.make_csv_dataset` 中设置 `num_epochs` 参数,否则 `tf.data.Dataset` 的默认行为是无限循环。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v4oMO9MIxgTG" }, "outputs": [], "source": [ "for batch, label in titanic_csv_ds.take(1):\n", " for key, value in batch.items():\n", " print(f\"{key:20s}: {value}\")\n", " print()\n", " print(f\"{'label':20s}: {label}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "k-TgA6o2Ja6U" }, "source": [ "注:如果您运行两次上述代码单元,它将产生不同的结果。`tf.data.experimental.make_csv_dataset` 的默认设置包括 `shuffle_buffer_size=1000`,这对于这个小型数据集来说已经绰绰有余,但可能不适用于实际的数据集。" ] }, { "cell_type": "markdown", "metadata": { "id": "d6uviU_KCCWD" }, "source": [ "它还可以对数据进行即时解压。下面是一个用 gzip 压缩的 CSV 文件,其中包含 [Metro Interstate Traffic Dataset](https://archive.ics.uci.edu/ml/datasets/Metro+Interstate+Traffic+Volume)。\n", "\n", 
"![字体](images/csv/fonts.jpg)\n", "\n", "图片[来自 Wikimedia](https://commons.wikimedia.org/wiki/File:Trafficjam.jpg)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kT7oZI2E46Q8" }, "outputs": [], "source": [ "traffic_volume_csv_gz = tf.keras.utils.get_file(\n", " 'Metro_Interstate_Traffic_Volume.csv.gz', \n", " \"https://archive.ics.uci.edu/ml/machine-learning-databases/00492/Metro_Interstate_Traffic_Volume.csv.gz\",\n", " cache_dir='.', cache_subdir='traffic')" ] }, { "cell_type": "markdown", "metadata": { "id": "F-IOsFHbCw0i" }, "source": [ "将 `compression_type` 参数设置为直接从压缩文件中读取:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ar0MPEVJ5NeA" }, "outputs": [], "source": [ "traffic_volume_csv_gz_ds = tf.data.experimental.make_csv_dataset(\n", " traffic_volume_csv_gz,\n", " batch_size=256,\n", " label_name='traffic_volume',\n", " num_epochs=1,\n", " compression_type=\"GZIP\")\n", "\n", "for batch, label in traffic_volume_csv_gz_ds.take(1):\n", " for key, value in batch.items():\n", " print(f\"{key:20s}: {value[:5]}\")\n", " print()\n", " print(f\"{'label':20s}: {label[:5]}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "p12Y6tGq8D6M" }, "source": [ "注:如果需要在 `tf.data` 流水线中解析这些日期时间字符串,您可以使用 `tfa.text.parse_time`。" ] }, { "cell_type": "markdown", "metadata": { "id": "EtrAXzYGP3l0" }, "source": [ "### 缓存" ] }, { "cell_type": "markdown", "metadata": { "id": "fN2dL_LRP83r" }, "source": [ "解析 CSV 数据有一些开销。对于小型模型,这可能是训练的瓶颈。\n", "\n", "根据您的用例,使用 `Dataset.cache` 或 `tf.data.Dataset.snapshot` 可能是个好主意,这样 CSV 数据仅会在第一个周期进行解析。\n", "\n", "`cache` 和 `snapshot` 方法的主要区别在于 `cache` 文件只能由创建它们的 TensorFlow 进程使用,而 `snapshot` 文件可以被其他进程读取。\n", "\n", "例如,在没有缓存的情况下迭代 `traffic_volume_csv_gz_ds` 20 次可能需要大约 15 秒,而使用缓存大约需要 2 秒。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qk38Sw4MO4eh" }, "outputs": [], "source": [ "%%time\n", "for i, (batch, label) in enumerate(traffic_volume_csv_gz_ds.repeat(20)):\n", " if i % 40 == 0:\n", " print('.', end='')\n", "print()" ] }, { "cell_type": "markdown", "metadata": { "id": "pN3HtDONh5TX" }, "source": [ "注:`Dataset.cache` 会存储第一个周期的数据并按顺序回放。因此,使用 `cache` 方法会停用流水线中较早的任何重排。下面,在 `Dataset.cache` 之后重新添加了 `Dataset.shuffle`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "r5Jj72MrPbnh" }, "outputs": [], "source": [ "%%time\n", "caching = traffic_volume_csv_gz_ds.cache().shuffle(1000)\n", "\n", "for i, (batch, label) in enumerate(caching.shuffle(1000).repeat(20)):\n", " if i % 40 == 0:\n", " print('.', end='')\n", "print()" ] }, { "cell_type": "markdown", "metadata": { "id": "wN7uUBjmgNZ9" }, "source": [ "注:`tf.data.Dataset.snapshot` 文件用于在使用时*临时*存储数据集。这*不是*长期存储的格式。文件格式被视为内部详细信息,无法在 TensorFlow 各版本之间保证。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PHGD1E8ktUvW" }, "outputs": [], "source": [ "%%time\n", "snapshotting = traffic_volume_csv_gz_ds.snapshot('titanic.tfsnap').shuffle(1000)\n", "\n", "for i, (batch, label) in enumerate(snapshotting.shuffle(1000).repeat(20)):\n", " if i % 40 == 0:\n", " print('.', end='')\n", "print()" ] }, { "cell_type": "markdown", "metadata": { "id": "fUSSegnMCGRz" }, "source": [ "如果加载 CSV 文件减慢了数据加载速度,并且 `Dataset.cache` 和 `tf.data.Dataset.snapshot` 不足以满足您的用例,请考虑将数据重新编码为更简化的格式。" ] }, { "cell_type": "markdown", "metadata": { "id": "M0iGXv9pC5kr" }, "source": [ "### 多个文件" ] }, { "cell_type": "markdown", "metadata": { "id": "9FFzHQrCDH4w" }, "source": [ "到目前为止,本部分中的所有示例都可以在没有 `tf.data` 的情况下轻松完成。处理文件集合时,`tf.data` 可以真正简化事情。\n", "\n", 
"例如,将 [Character Font Images](https://archive.ics.uci.edu/ml/datasets/Character+Font+Images) 数据集作为 CSV 文件的集合分发,每种字体一个集合。\n", "\n", "![Fonts](https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/load_data/images/csv/fonts.jpg?raw=true)\n", "\n", "图像作者:Willi Heidelbach,来源:Pixabay\n", "\n", "下载数据集,并检查里面的文件:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RmVknMdJh5ks" }, "outputs": [], "source": [ "fonts_zip = tf.keras.utils.get_file(\n", " 'fonts.zip', \"https://archive.ics.uci.edu/ml/machine-learning-databases/00417/fonts.zip\",\n", " cache_dir='.', cache_subdir='fonts',\n", " extract=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xsDlMCnyi55e" }, "outputs": [], "source": [ "import pathlib\n", "font_csvs = sorted(str(p) for p in pathlib.Path('fonts').glob(\"*.csv\"))\n", "\n", "font_csvs[:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lRAEJx9ROAGl" }, "outputs": [], "source": [ "len(font_csvs)" ] }, { "cell_type": "markdown", "metadata": { "id": "19Udrw9iG-FS" }, "source": [ "在处理一堆文件时,可以将 glob 样式的 `file_pattern` 传递给 `tf.data.experimental.make_csv_dataset` 函数。每次迭代都会重排文件的顺序。\n", "\n", "使用 `num_parallel_reads` 参数设置并行读取并交错在一起的文件数量。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6TSUNdT6iG58" }, "outputs": [], "source": [ "fonts_ds = tf.data.experimental.make_csv_dataset(\n", " file_pattern = \"fonts/*.csv\",\n", " batch_size=10, num_epochs=1,\n", " num_parallel_reads=20,\n", " shuffle_buffer_size=10000)" ] }, { "cell_type": "markdown", "metadata": { "id": "XMoexinLHYFa" }, "source": [ "这些 CSV 文件会将图像展平成一行。列名的格式为 `r{row}c{column}`。下面是第一个批次:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RmFvBWxxi3pq" }, "outputs": [], "source": [ "for features in fonts_ds.take(1):\n", " for i, (name, value) in enumerate(features.items()):\n", " if i>15:\n", " break\n", " print(f\"{name:20s}: {value}\")\n", "print('...')\n", "print(f\"[total: {len(features)} features]\")" ] }, { "cell_type": "markdown", "metadata": { "id": "xrC3sKdeOhb5" }, "source": [ "#### 可选:打包字段\n", "\n", "您可能不想像这样在单独的列中处理每个像素。在尝试使用此数据集之前,请务必将像素打包到图像张量中。\n", "\n", "下面是解析列名,从而为每个示例构建图像的代码:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hct5EMEWNyfH" }, "outputs": [], "source": [ "import re\n", "\n", "def make_images(features):\n", " image = [None]*400\n", " new_feats = {}\n", "\n", " for name, value in features.items():\n", " match = re.match('r(\\d+)c(\\d+)', name)\n", " if match:\n", " image[int(match.group(1))*20+int(match.group(2))] = value\n", " else:\n", " new_feats[name] = value\n", "\n", " image = tf.stack(image, axis=0)\n", " image = tf.reshape(image, [20, 20, -1])\n", " new_feats['image'] = image\n", "\n", " return new_feats" ] }, { "cell_type": "markdown", "metadata": { "id": "61qy8utAwARP" }, "source": [ "将该函数应用于数据集中的每个批次:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DJnnfIW9baE4" }, "outputs": [], "source": [ "fonts_image_ds = fonts_ds.map(make_images)\n", "\n", "for features in fonts_image_ds.take(1):\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "_ThqrthGwHSm" }, "source": [ "绘制生成的图像:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I5dcey31T_tk" }, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", "\n", "plt.figure(figsize=(6,6), dpi=120)\n", "\n", "for n in range(9):\n", " plt.subplot(3,3,n+1)\n", " plt.imshow(features['image'][..., n])\n", " 
{ "cell_type": "markdown", "metadata": { "id": "M0iGXv9pC5kr" }, "source": [ "### Multiple files" ] },
{ "cell_type": "markdown", "metadata": { "id": "9FFzHQrCDH4w" }, "source": [ "So far, all the examples in this section could easily have been done without `tf.data`. Where `tf.data` can really simplify things is when dealing with collections of files.\n", "\n", "For example, the [Character Font Images](https://archive.ics.uci.edu/ml/datasets/Character+Font+Images) dataset is distributed as a collection of CSV files, one per font.\n", "\n", "![Fonts](images/csv/fonts.jpg)\n", "\n", "Image by Willi Heidelbach, from Pixabay\n", "\n", "Download the dataset, and review the files inside:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "RmVknMdJh5ks" }, "outputs": [], "source": [ "fonts_zip = tf.keras.utils.get_file(\n", "    'fonts.zip', \"https://archive.ics.uci.edu/ml/machine-learning-databases/00417/fonts.zip\",\n", "    cache_dir='.', cache_subdir='fonts',\n", "    extract=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "xsDlMCnyi55e" }, "outputs": [], "source": [ "import pathlib\n", "font_csvs = sorted(str(p) for p in pathlib.Path('fonts').glob(\"*.csv\"))\n", "\n", "font_csvs[:10]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "lRAEJx9ROAGl" }, "outputs": [], "source": [ "len(font_csvs)" ] },
{ "cell_type": "markdown", "metadata": { "id": "19Udrw9iG-FS" }, "source": [ "When dealing with a bunch of files, you can pass a glob-style `file_pattern` to the `tf.data.experimental.make_csv_dataset` function. The order of the files is shuffled each iteration.\n", "\n", "Use the `num_parallel_reads` argument to set how many files are read in parallel and interleaved together." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "6TSUNdT6iG58" }, "outputs": [], "source": [ "fonts_ds = tf.data.experimental.make_csv_dataset(\n", "    file_pattern = \"fonts/*.csv\",\n", "    batch_size=10, num_epochs=1,\n", "    num_parallel_reads=20,\n", "    shuffle_buffer_size=10000)" ] },
{ "cell_type": "markdown", "metadata": { "id": "XMoexinLHYFa" }, "source": [ "These CSV files have the images flattened out into a single row. The column names are formatted `r{row}c{column}`. Here's the first batch:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "RmFvBWxxi3pq" }, "outputs": [], "source": [ "for features in fonts_ds.take(1):\n", "  for i, (name, value) in enumerate(features.items()):\n", "    if i>15:\n", "      break\n", "    print(f\"{name:20s}: {value}\")\n", "print('...')\n", "print(f\"[total: {len(features)} features]\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "xrC3sKdeOhb5" }, "source": [ "#### Optional: Packing fields\n", "\n", "You probably don't want to work with each pixel in a separate column like this. Before trying to use this dataset, be sure to pack the pixels into an image tensor.\n", "\n", "Here is code that parses the column names to build images for each example:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "hct5EMEWNyfH" }, "outputs": [], "source": [ "import re\n", "\n", "def make_images(features):\n", "  image = [None]*400\n", "  new_feats = {}\n", "\n", "  for name, value in features.items():\n", "    match = re.match(r'r(\\d+)c(\\d+)', name)\n", "    if match:\n", "      image[int(match.group(1))*20+int(match.group(2))] = value\n", "    else:\n", "      new_feats[name] = value\n", "\n", "  image = tf.stack(image, axis=0)\n", "  image = tf.reshape(image, [20, 20, -1])\n", "  new_feats['image'] = image\n", "\n", "  return new_feats" ] },
{ "cell_type": "markdown", "metadata": { "id": "61qy8utAwARP" }, "source": [ "Apply that function to each batch in the dataset:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "DJnnfIW9baE4" }, "outputs": [], "source": [ "fonts_image_ds = fonts_ds.map(make_images)\n", "\n", "for features in fonts_image_ds.take(1):\n", "  break" ] },
{ "cell_type": "markdown", "metadata": { "id": "_ThqrthGwHSm" }, "source": [ "Plot the resulting images:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "I5dcey31T_tk" }, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", "\n", "plt.figure(figsize=(6,6), dpi=120)\n", "\n", "for n in range(9):\n", "  plt.subplot(3,3,n+1)\n", "  plt.imshow(features['image'][..., n])\n", "  plt.title(chr(features['m_label'][n]))\n", "  plt.axis('off')" ] },
{ "cell_type": "markdown", "metadata": { "id": "7-nNR0Nncdd1" }, "source": [ "## Lower-level functions" ] },
{ "cell_type": "markdown", "metadata": { "id": "3jiGZeUijJNd" }, "source": [ "So far this tutorial has focused on the highest-level utilities for reading CSV data. There are two other APIs that may be helpful for advanced users if your use case doesn't fit the basic patterns.\n", "\n", "- `tf.io.decode_csv`: a function for parsing lines of text into a list of CSV column tensors.\n", "- `tf.data.experimental.CsvDataset`: a lower-level CSV dataset constructor.\n", "\n", "This section recreates functionality provided by `tf.data.experimental.make_csv_dataset`, to demonstrate how this lower-level functionality can be used.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "LL_ixywomOHW" }, "source": [ "### `tf.io.decode_csv`\n", "\n", "This function decodes a string, or list of strings, into a list of columns.\n", "\n", "Unlike `tf.data.experimental.make_csv_dataset`, this function does not try to guess column data types. You specify the column types by providing a list of `record_defaults` containing a value of the correct type for each column.\n", "\n", "To read the Titanic data as strings using `tf.io.decode_csv`, you would say:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "m1D2C-qdlqeW" }, "outputs": [], "source": [ "text = pathlib.Path(titanic_file_path).read_text()\n", "lines = text.split('\\n')[1:-1]\n", "\n", "all_strings = [str()]*10\n", "all_strings" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "9W4UeJYyHPx5" }, "outputs": [], "source": [ "features = tf.io.decode_csv(lines, record_defaults=all_strings)\n", "\n", "for f in features:\n", "  print(f\"type: {f.dtype.name}, shape: {f.shape}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "j8TaHSQFoQL4" }, "source": [ "To parse them with their actual types, create a list of `record_defaults` of the corresponding types:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "rzUjR59yoUe1" }, "outputs": [], "source": [ "print(lines[0])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "7sPTunxwoeWU" }, "outputs": [], "source": [ "titanic_types = [int(), str(), float(), int(), int(), float(), str(), str(), str(), str()]\n", "titanic_types" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "n3NlViCzoB7F" }, "outputs": [], "source": [ "features = tf.io.decode_csv(lines, record_defaults=titanic_types)\n", "\n", "for f in features:\n", "  print(f\"type: {f.dtype.name}, shape: {f.shape}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "m-LkTUTnpn2P" }, "source": [ "Note: It is more efficient to call `tf.io.decode_csv` on large batches of lines than on individual lines of CSV text." ] },
{ "cell_type": "markdown", "metadata": { "id": "Yp1UItJmqGqw" }, "source": [ "### `tf.data.experimental.CsvDataset`\n", "\n", "The `tf.data.experimental.CsvDataset` class provides a minimal CSV `Dataset` interface without the convenience features of the `tf.data.experimental.make_csv_dataset` function: column header parsing, column type inference, automatic shuffling, and file interleaving.\n", "\n", "This constructor uses `record_defaults` the same way as `tf.io.decode_csv`:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "9OzZLp3krP-t" }, "outputs": [], "source": [ "simple_titanic = tf.data.experimental.CsvDataset(titanic_file_path, record_defaults=titanic_types, header=True)\n", "\n", "for example in simple_titanic.take(1):\n", "  print([e.numpy() for e in example])" ] },
{ "cell_type": "markdown", "metadata": { "id": "_HBmfI-Ks7dw" }, "source": [ "The above code is basically equivalent to:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "E5O5d69Yq7gG" }, "outputs": [], "source": [ "def decode_titanic_line(line):\n", "  return tf.io.decode_csv(line, titanic_types)\n", "\n", "manual_titanic = (\n", "    # Load the lines of text\n", "    tf.data.TextLineDataset(titanic_file_path)\n", "    # Skip the header row.\n", "    .skip(1)\n", "    # Decode the line.\n", "    .map(decode_titanic_line)\n", ")\n", "\n", "for example in manual_titanic.take(1):\n", "  print([e.numpy() for e in example])" ] },
"cell_type": "markdown", "metadata": { "id": "5R3ralsnt2AC" }, "source": [ "#### 多个文件\n", "\n", "要使用 `tf.data.experimental.CsvDataset` 解析字体数据集,您首先需要确定 `record_defaults` 的列类型。首先检查一个文件的第一行:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3tlFOTjCvAI5" }, "outputs": [], "source": [ "font_line = pathlib.Path(font_csvs[0]).read_text().splitlines()[1]\n", "print(font_line)" ] }, { "cell_type": "markdown", "metadata": { "id": "etyGu8K_ySRz" }, "source": [ "只有前两个字段是字符串,其余的都是整数或浮点数,通过计算逗号的个数可以得到特征总数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "crgZZn0BzkSB" }, "outputs": [], "source": [ "num_font_features = font_line.count(',')+1\n", "font_column_types = [str(), str()] + [float()]*(num_font_features-2)" ] }, { "cell_type": "markdown", "metadata": { "id": "YeK2Pw540RNj" }, "source": [ "`tf.data.experimental.CsvDataset` 构造函数可以获取输入文件列表,但会按顺序读取它们。CSV 列表中的第一个文件是 `AGENCY.csv`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_SvL5Uvl0r0N" }, "outputs": [], "source": [ "font_csvs[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "EfAX3G8Xywy6" }, "source": [ "因此,当您将文件列表传递给 `CsvDataset` 时,会首先读取 `AGENCY.csv` 中的记录:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Gtr1E66VmBqj" }, "outputs": [], "source": [ "simple_font_ds = tf.data.experimental.CsvDataset(\n", " font_csvs, \n", " record_defaults=font_column_types, \n", " header=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k750Mgq4yt_o" }, "outputs": [], "source": [ "for row in simple_font_ds.take(10):\n", " print(row[0].numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "NiqWKQV21FrE" }, "source": [ "要交错多个文件,请使用 `Dataset.interleave`。\n", "\n", "这是一个包含 CSV 文件名的初始数据集: " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t9dS3SNb23W8" }, "outputs": [], "source": [ "font_files = tf.data.Dataset.list_files(\"fonts/*.csv\")" ] }, { "cell_type": "markdown", "metadata": { "id": "TNiLHMXpzHy5" }, "source": [ "这会在每个周期重排文件名:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zNd-TYyNzIgg" }, "outputs": [], "source": [ "print('Epoch 1:')\n", "for f in list(font_files)[:5]:\n", " print(\" \", f.numpy())\n", "print(' ...')\n", "print()\n", "\n", "print('Epoch 2:')\n", "for f in list(font_files)[:5]:\n", " print(\" \", f.numpy())\n", "print(' ...')" ] }, { "cell_type": "markdown", "metadata": { "id": "B0QB1PtU3WAN" }, "source": [ "`interleave` 方法采用 `map_func`,它会为父 `Dataset`的每个元素创建一个子 `Dataset`。\n", "\n", "在这里,您要从文件数据集的每个元素创建一个 `tf.data.experimental.CsvDataset`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QWp4rH0Q4uPh" }, "outputs": [], "source": [ "def make_font_csv_ds(path):\n", " return tf.data.experimental.CsvDataset(\n", " path, \n", " record_defaults=font_column_types, \n", " header=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "VxRGdLMB5nRF" }, "source": [ "交错返回的 `Dataset` 通过循环遍历多个子 `Dataset` 来返回元素。请注意,下面的数据集如何在 `cycle_length=3` 三个字体文件中循环:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OePMNF_x1_Cc" }, "outputs": [], "source": [ "font_rows = font_files.interleave(make_font_csv_ds,\n", " cycle_length=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UORIGWLy54-E" }, "outputs": [], "source": [ "fonts_dict = {'font_name':[], 'character':[]}\n", "\n", "for row in font_rows.take(10):\n", " fonts_dict['font_name'].append(row[0].numpy().decode())\n", " 
{ "cell_type": "markdown", "metadata": { "id": "mkKZa_HX8zAm" }, "source": [ "#### Performance\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "8BtGHraUApdJ" }, "source": [ "Earlier, it was noted that `tf.io.decode_csv` is more efficient when run on a batch of strings.\n", "\n", "It is possible to take advantage of this fact, when using large batch sizes, to improve CSV loading performance (but try [caching](#caching) first)." ] },
{ "cell_type": "markdown", "metadata": { "id": "d35zWMH7MDL1" }, "source": [ "With the built-in loader, 20 batches of 2048 examples take about 17 seconds:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ieUVAPryjpJS" }, "outputs": [], "source": [ "BATCH_SIZE=2048\n", "fonts_ds = tf.data.experimental.make_csv_dataset(\n", "    file_pattern = \"fonts/*.csv\",\n", "    batch_size=BATCH_SIZE, num_epochs=1,\n", "    num_parallel_reads=100)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "MUC2KW4LkQIz" }, "outputs": [], "source": [ "%%time\n", "for i, batch in enumerate(fonts_ds.take(20)):\n", "  print('.', end='')\n", "\n", "print()" ] },
{ "cell_type": "markdown", "metadata": { "id": "5lhnh6rZEDS2" }, "source": [ "Passing **batches of text lines** to `decode_csv` runs faster, at about 5 seconds:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "4XbPZV1okVF9" }, "outputs": [], "source": [ "fonts_files = tf.data.Dataset.list_files(\"fonts/*.csv\")\n", "fonts_lines = fonts_files.interleave(\n", "    lambda fname: tf.data.TextLineDataset(fname).skip(1),\n", "    cycle_length=100).batch(BATCH_SIZE)\n", "\n", "fonts_fast = fonts_lines.map(lambda x: tf.io.decode_csv(x, record_defaults=font_column_types))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "te9C2km-qO8W" }, "outputs": [], "source": [ "%%time\n", "for i, batch in enumerate(fonts_fast.take(20)):\n", "  print('.', end='')\n", "\n", "print()" ] },
{ "cell_type": "markdown", "metadata": { "id": "aebC1plsMeOi" }, "source": [ "For another example of improving CSV performance by using large batches, refer to the [Overfit and underfit tutorial](../keras/overfit_and_underfit.ipynb).\n", "\n", "This sort of approach may work, but consider other options like `Dataset.cache` and `tf.data.Dataset.snapshot`, or re-encoding your data into a more streamlined format." ] }
], "metadata": { "colab": { "name": "csv.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }