{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "DweYe9FcbMK_" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "AVV2e0XKbJeX" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "sUtoed20cRJJ" }, "source": [ "# 加载 CSV 数据" ] }, { "cell_type": "markdown", "metadata": { "id": "1ap_W4aQcgNT" }, "source": [ "
Keras models don't automatically convert pandas DataFrames, because it's not clear whether it should be converted to one tensor or to a dictionary of tensors. So, convert it to a dictionary of tensors:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5YjdYyMEacwQ"
},
"outputs": [],
"source": [
"titanic_features_dict = {name: np.array(value) \n",
" for name, value in titanic_features.items()}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nKJYoPByada"
},
"source": [
"切出第一个训练样本并将其传递给此预处理模型,您会看到数字特征和字符串独热全部串联在一起:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SjnmU8PSv8T3"
},
"outputs": [],
"source": [
"features_dict = {name:values[:1] for name, values in titanic_features_dict.items()}\n",
"titanic_preprocessing(features_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkBf4LvmzMDp"
},
"source": [
"接下来,在此基础上构建模型:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "coIPtGaCzUV7"
},
"outputs": [],
"source": [
"def titanic_model(preprocessing_head, inputs):\n",
" body = tf.keras.Sequential([\n",
" layers.Dense(64, activation='relu'),\n",
" layers.Dense(1)\n",
" ])\n",
"\n",
" preprocessed_inputs = preprocessing_head(inputs)\n",
" result = body(preprocessed_inputs)\n",
" model = tf.keras.Model(inputs, result)\n",
"\n",
" model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
" optimizer=tf.keras.optimizers.Adam())\n",
" return model\n",
"\n",
"titanic_model = titanic_model(titanic_preprocessing, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LK5uBQQF2KbZ"
},
"source": [
"训练模型时,将特征字典作为 `x` 传递,将标签作为 `y` 传递。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "D1gVfwJ61ejz"
},
"outputs": [],
"source": [
"titanic_model.fit(x=titanic_features_dict, y=titanic_labels, epochs=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LxgJarZk3bfH"
},
"source": [
"由于预处理是模型的一部分,您可以保存模型并将其重新加载到其他地方并获得相同的结果:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ay-8ymNA2ZCh"
},
"outputs": [],
"source": [
"titanic_model.save('test.keras')\n",
"reloaded = tf.keras.models.load_model('test.keras')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qm6jMTpD20lK"
},
"outputs": [],
"source": [
"features_dict = {name:values[:1] for name, values in titanic_features_dict.items()}\n",
"\n",
"before = titanic_model(features_dict)\n",
"after = reloaded(features_dict)\n",
"assert (before-after)<1e-3\n",
"print(before)\n",
"print(after)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7VsPlxIRZpXf"
},
"source": [
"## 使用 tf.data\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NyVDCwGzR5HW"
},
"source": [
"在前一部分中,您在训练模型时依赖了模型的内置数据乱序和批处理。\n",
"\n",
"如果您需要对输入数据流水线进行更多控制或需要使用不易放入内存的数据:请使用 `tf.data`。\n",
"\n",
"有关更多示例,请参阅 [`tf.data`:构建 TensorFlow 输入流水线](../../guide/data.ipynb)指南。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gP5Y1jM2Sor0"
},
"source": [
"### 有关内存数据\n",
"\n",
"作为将 `tf.data` 应用于 CSV 数据的第一个样本,请考虑使用以下代码手动切分上一个部分中的特征字典。对于每个索引,它会为每个特征获取该索引:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i8wE-MVuVu7_"
},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"def slices(features):\n",
" for i in itertools.count():\n",
" # For each feature take index `i`\n",
" example = {name:values[i] for name, values in features.items()}\n",
" yield example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cQ3RTbS9YEal"
},
"source": [
"运行此代码并打印第一个样本:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wwq8XK88WwFk"
},
"outputs": [],
"source": [
"for example in slices(titanic_features_dict):\n",
" for name, value in example.items():\n",
" print(f\"{name:19s}: {value}\")\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vvp8Dct6YOIE"
},
"source": [
"内存数据加载程序中最基本的 `tf.data.Dataset` 是 `Dataset.from_tensor_slices` 构造函数。这会返回一个 `tf.data.Dataset`,它将在 TensorFlow 中实现上述 `slices` 函数的泛化版本。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2gEJthslYxeV"
},
"outputs": [],
"source": [
"features_ds = tf.data.Dataset.from_tensor_slices(titanic_features_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ZC0rTpMZMZK"
},
"source": [
"您可以像任何其他 Python 可迭代对象一样迭代 `tf.data.Dataset`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gOHbiefaY4ag"
},
"outputs": [],
"source": [
"for example in features_ds:\n",
" for name, value in example.items():\n",
" print(f\"{name:19s}: {value}\")\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uwcFoVJWZY5F"
},
"source": [
"`from_tensor_slices` 函数可以处理嵌套字典或元组的任何结构。以下代码创建了一个 `(features_dict, labels)` 对的数据集:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xIHGBy76Zcrx"
},
"outputs": [],
"source": [
"titanic_ds = tf.data.Dataset.from_tensor_slices((titanic_features_dict, titanic_labels))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gQwxitt8c2GK"
},
"source": [
"要使用此 `Dataset` 训练模型,您至少需要对数据进行 `shuffle` 和 `batch`。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SbJcbldhddeC"
},
"outputs": [],
"source": [
"titanic_batches = titanic_ds.shuffle(len(titanic_labels)).batch(32)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4FRqhRFuoJx"
},
"source": [
"不是将 `features` 和 `labels` 传递给 `Model.fit`,而是传递数据集:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8yXkNPumdBtB"
},
"outputs": [],
"source": [
"titanic_model.fit(titanic_batches, epochs=5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qXuibiv9exT7"
},
"source": [
"### 从单个文件\n",
"\n",
"到目前为止,本教程已经使用了内存数据。`tf.data` 是用于构建数据流水线的高度可扩展的工具包,并提供了一些用于处理加载 CSV 文件的函数。 "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ncf5t6tgL5ZI"
},
"outputs": [],
"source": [
"titanic_file_path = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t4N-plO4tDXd"
},
"source": [
"现在,从文件中读取 CSV 数据并创建一个 `tf.data.Dataset`。\n",
"\n",
"(有关完整文档,请参阅 `tf.data.experimental.make_csv_dataset`)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yIbUscB9sqha"
},
"outputs": [],
"source": [
"titanic_csv_ds = tf.data.experimental.make_csv_dataset(\n",
" titanic_file_path,\n",
" batch_size=5, # Artificially small to make examples easier to show.\n",
" label_name='survived',\n",
" num_epochs=1,\n",
" ignore_errors=True,)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sf3v3BKgy4AG"
},
"source": [
"此函数包括许多方便的功能,因此很容易处理数据。这包括:\n",
"\n",
"- 使用列标题作为字典键。\n",
"- 自动确定每列的类型。\n",
"\n",
"小心:请确保在 `tf.data.experimental.make_csv_dataset` 中设置 `num_epochs` 参数,否则 `tf.data.Dataset` 的默认行为是无限循环。"
]
},
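{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, because `num_epochs` defaults to looping forever, you can bound iteration explicitly with `Dataset.take`. The following is a minimal sketch (the `endless_titanic` name is illustrative, not part of the original tutorial):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `num_epochs` is left unset here, so this dataset repeats indefinitely.\n",
"endless_titanic = tf.data.experimental.make_csv_dataset(\n",
"    titanic_file_path,\n",
"    batch_size=5,\n",
"    label_name='survived',\n",
"    ignore_errors=True)\n",
"\n",
"# `Dataset.take` bounds the endless stream to a fixed number of batches.\n",
"for batch, label in endless_titanic.take(2):\n",
"  print(label.numpy())"
]
},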
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v4oMO9MIxgTG"
},
"outputs": [],
"source": [
"for batch, label in titanic_csv_ds.take(1):\n",
" for key, value in batch.items():\n",
" print(f\"{key:20s}: {value}\")\n",
" print()\n",
" print(f\"{'label':20s}: {label}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k-TgA6o2Ja6U"
},
"source": [
"注:如果您运行两次上述代码单元,它将产生不同的结果。`tf.data.experimental.make_csv_dataset` 的默认设置包括 `shuffle_buffer_size=1000`,这对于这个小型数据集来说已经绰绰有余,但可能不适用于实际的数据集。"
]
},
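{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you need a repeatable order while debugging, you can disable shuffling, or fix a seed, via the `shuffle` and `shuffle_seed` arguments. A minimal sketch (the `ordered_titanic` name is illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# shuffle=False preserves the file order; alternatively, keep shuffle=True\n",
"# and pass shuffle_seed for a fixed pseudo-random order.\n",
"ordered_titanic = tf.data.experimental.make_csv_dataset(\n",
"    titanic_file_path,\n",
"    batch_size=5,\n",
"    label_name='survived',\n",
"    num_epochs=1,\n",
"    shuffle=False)\n",
"\n",
"for batch, label in ordered_titanic.take(1):\n",
"  print(label.numpy())  # The same first batch on every run."
]
},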
{
"cell_type": "markdown",
"metadata": {
"id": "d6uviU_KCCWD"
},
"source": [
"它还可以对数据进行即时解压。下面是一个用 gzip 压缩的 CSV 文件,其中包含 [Metro Interstate Traffic Dataset](https://archive.ics.uci.edu/ml/datasets/Metro+Interstate+Traffic+Volume)。\n",
"\n",
"\n",
"\n",
"图片[来自 Wikimedia](https://commons.wikimedia.org/wiki/File:Trafficjam.jpg)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kT7oZI2E46Q8"
},
"outputs": [],
"source": [
"traffic_volume_csv_gz = tf.keras.utils.get_file(\n",
" 'Metro_Interstate_Traffic_Volume.csv.gz', \n",
" \"https://archive.ics.uci.edu/ml/machine-learning-databases/00492/Metro_Interstate_Traffic_Volume.csv.gz\",\n",
" cache_dir='.', cache_subdir='traffic')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F-IOsFHbCw0i"
},
"source": [
"将 `compression_type` 参数设置为直接从压缩文件中读取:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ar0MPEVJ5NeA"
},
"outputs": [],
"source": [
"traffic_volume_csv_gz_ds = tf.data.experimental.make_csv_dataset(\n",
" traffic_volume_csv_gz,\n",
" batch_size=256,\n",
" label_name='traffic_volume',\n",
" num_epochs=1,\n",
" compression_type=\"GZIP\")\n",
"\n",
"for batch, label in traffic_volume_csv_gz_ds.take(1):\n",
" for key, value in batch.items():\n",
" print(f\"{key:20s}: {value[:5]}\")\n",
" print()\n",
" print(f\"{'label':20s}: {label[:5]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p12Y6tGq8D6M"
},
"source": [
"注:如果需要在 `tf.data` 流水线中解析这些日期时间字符串,您可以使用 `tfa.text.parse_time`。"
]
},
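{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, here is a minimal sketch that maps `tfa.text.parse_time` over the `date_time` column. It assumes the TensorFlow Addons package is installed, and that the timestamps follow the `%Y-%m-%d %H:%M:%S` format:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_addons as tfa\n",
"\n",
"def parse_timestamps(features, label):\n",
"  features = dict(features)\n",
"  # Convert the date-time strings to integer seconds since the epoch.\n",
"  features['date_time'] = tfa.text.parse_time(\n",
"      features['date_time'],\n",
"      time_format='%Y-%m-%d %H:%M:%S',\n",
"      output_unit='SECOND')\n",
"  return features, label\n",
"\n",
"timestamped_ds = traffic_volume_csv_gz_ds.map(parse_timestamps)"
]
},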
{
"cell_type": "markdown",
"metadata": {
"id": "EtrAXzYGP3l0"
},
"source": [
"### 缓存"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fN2dL_LRP83r"
},
"source": [
"解析 CSV 数据有一些开销。对于小型模型,这可能是训练的瓶颈。\n",
"\n",
"根据您的用例,使用 `Dataset.cache` 或 `tf.data.Dataset.snapshot` 可能是个好主意,这样 CSV 数据仅会在第一个周期进行解析。\n",
"\n",
"`cache` 和 `snapshot` 方法的主要区别在于 `cache` 文件只能由创建它们的 TensorFlow 进程使用,而 `snapshot` 文件可以被其他进程读取。\n",
"\n",
"例如,在没有缓存的情况下迭代 `traffic_volume_csv_gz_ds` 20 次可能需要大约 15 秒,而使用缓存大约需要 2 秒。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qk38Sw4MO4eh"
},
"outputs": [],
"source": [
"%%time\n",
"for i, (batch, label) in enumerate(traffic_volume_csv_gz_ds.repeat(20)):\n",
" if i % 40 == 0:\n",
" print('.', end='')\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pN3HtDONh5TX"
},
"source": [
"注:`Dataset.cache` 会存储第一个周期的数据并按顺序回放。因此,使用 `cache` 方法会停用流水线中较早的任何重排。下面,在 `Dataset.cache` 之后重新添加了 `Dataset.shuffle`。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r5Jj72MrPbnh"
},
"outputs": [],
"source": [
"%%time\n",
"caching = traffic_volume_csv_gz_ds.cache().shuffle(1000)\n",
"\n",
"for i, (batch, label) in enumerate(caching.shuffle(1000).repeat(20)):\n",
" if i % 40 == 0:\n",
" print('.', end='')\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wN7uUBjmgNZ9"
},
"source": [
"注:`tf.data.Dataset.snapshot` 文件用于在使用时*临时*存储数据集。这*不是*长期存储的格式。文件格式被视为内部详细信息,无法在 TensorFlow 各版本之间保证。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PHGD1E8ktUvW"
},
"outputs": [],
"source": [
"%%time\n",
"snapshotting = traffic_volume_csv_gz_ds.snapshot('titanic.tfsnap').shuffle(1000)\n",
"\n",
"for i, (batch, label) in enumerate(snapshotting.shuffle(1000).repeat(20)):\n",
" if i % 40 == 0:\n",
" print('.', end='')\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fUSSegnMCGRz"
},
"source": [
"如果加载 CSV 文件减慢了数据加载速度,并且 `Dataset.cache` 和 `tf.data.Dataset.snapshot` 不足以满足您的用例,请考虑将数据重新编码为更简化的格式。"
]
},
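{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, one streamlined option is `tf.data`'s own serialized format. This sketch assumes TensorFlow 2.7 or later, where `Dataset.save` and `tf.data.Dataset.load` are available; the `'traffic_ds'` path is just an example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the already-parsed dataset so later runs can skip CSV parsing entirely.\n",
"traffic_volume_csv_gz_ds.save('traffic_ds')\n",
"\n",
"# Reload it (possibly from another process) without touching the CSV files.\n",
"restored_ds = tf.data.Dataset.load('traffic_ds')\n",
"\n",
"for batch, label in restored_ds.take(1):\n",
"  print(label[:5])"
]
},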
{
"cell_type": "markdown",
"metadata": {
"id": "M0iGXv9pC5kr"
},
"source": [
"### 多个文件"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9FFzHQrCDH4w"
},
"source": [
"到目前为止,本部分中的所有示例都可以在没有 `tf.data` 的情况下轻松完成。处理文件集合时,`tf.data` 可以真正简化事情。\n",
"\n",
"例如,将 [Character Font Images](https://archive.ics.uci.edu/ml/datasets/Character+Font+Images) 数据集作为 CSV 文件的集合分发,每种字体一个集合。\n",
"\n",
"\n",
"\n",
"图像作者:Willi Heidelbach,来源:Pixabay\n",
"\n",
"下载数据集,并检查里面的文件:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RmVknMdJh5ks"
},
"outputs": [],
"source": [
"fonts_zip = tf.keras.utils.get_file(\n",
" 'fonts.zip', \"https://archive.ics.uci.edu/ml/machine-learning-databases/00417/fonts.zip\",\n",
" cache_dir='.', cache_subdir='fonts',\n",
" extract=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xsDlMCnyi55e"
},
"outputs": [],
"source": [
"import pathlib\n",
"font_csvs = sorted(str(p) for p in pathlib.Path('fonts').glob(\"*.csv\"))\n",
"\n",
"font_csvs[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lRAEJx9ROAGl"
},
"outputs": [],
"source": [
"len(font_csvs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "19Udrw9iG-FS"
},
"source": [
"在处理一堆文件时,可以将 glob 样式的 `file_pattern` 传递给 `tf.data.experimental.make_csv_dataset` 函数。每次迭代都会重排文件的顺序。\n",
"\n",
"使用 `num_parallel_reads` 参数设置并行读取并交错在一起的文件数量。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6TSUNdT6iG58"
},
"outputs": [],
"source": [
"fonts_ds = tf.data.experimental.make_csv_dataset(\n",
" file_pattern = \"fonts/*.csv\",\n",
" batch_size=10, num_epochs=1,\n",
" num_parallel_reads=20,\n",
" shuffle_buffer_size=10000)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XMoexinLHYFa"
},
"source": [
"这些 CSV 文件会将图像展平成一行。列名的格式为 `r{row}c{column}`。下面是第一个批次:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RmFvBWxxi3pq"
},
"outputs": [],
"source": [
"for features in fonts_ds.take(1):\n",
" for i, (name, value) in enumerate(features.items()):\n",
" if i>15:\n",
" break\n",
" print(f\"{name:20s}: {value}\")\n",
"print('...')\n",
"print(f\"[total: {len(features)} features]\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xrC3sKdeOhb5"
},
"source": [
"#### 可选:打包字段\n",
"\n",
"您可能不想像这样在单独的列中处理每个像素。在尝试使用此数据集之前,请务必将像素打包到图像张量中。\n",
"\n",
"下面是解析列名,从而为每个示例构建图像的代码:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hct5EMEWNyfH"
},
"outputs": [],
"source": [
"import re\n",
"\n",
"def make_images(features):\n",
" image = [None]*400\n",
" new_feats = {}\n",
"\n",
" for name, value in features.items():\n",
" match = re.match('r(\\d+)c(\\d+)', name)\n",
" if match:\n",
" image[int(match.group(1))*20+int(match.group(2))] = value\n",
" else:\n",
" new_feats[name] = value\n",
"\n",
" image = tf.stack(image, axis=0)\n",
" image = tf.reshape(image, [20, 20, -1])\n",
" new_feats['image'] = image\n",
"\n",
" return new_feats"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "61qy8utAwARP"
},
"source": [
"将该函数应用于数据集中的每个批次:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DJnnfIW9baE4"
},
"outputs": [],
"source": [
"fonts_image_ds = fonts_ds.map(make_images)\n",
"\n",
"for features in fonts_image_ds.take(1):\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ThqrthGwHSm"
},
"source": [
"绘制生成的图像:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I5dcey31T_tk"
},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"plt.figure(figsize=(6,6), dpi=120)\n",
"\n",
"for n in range(9):\n",
" plt.subplot(3,3,n+1)\n",
" plt.imshow(features['image'][..., n])\n",
" plt.title(chr(features['m_label'][n]))\n",
" plt.axis('off')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7-nNR0Nncdd1"
},
"source": [
"## 低级函数"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3jiGZeUijJNd"
},
"source": [
"到目前为止,本教程重点介绍了用于读取 CSV 数据的最高级别效用函数。如果您的用例不符合基本模式,还有其他两个 API 可能对高级用户有所帮助。\n",
"\n",
"- `tf.io.decode_csv`:用于将文本行解析为 CSV 列张量列表的函数。\n",
"- `tf.data.experimental.CsvDataset`:较低级别的 CSV 数据集构造函数。\n",
"\n",
"本部分会重新创建 `tf.data.experimental.make_csv_dataset` 提供的功能,以演示如何使用此较低级别的功能。\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LL_ixywomOHW"
},
"source": [
"### `tf.io.decode_csv`\n",
"\n",
"此函数会将字符串或字符串列表解码为列列表。\n",
"\n",
"与 `tf.data.experimental.make_csv_dataset` 不同,此函数不会尝试猜测列数据类型。您可以通过为每列提供包含正确类型值的记录 `record_defaults` 值列表来指定列类型。\n",
"\n",
"要使用 tf.io.decode_csv
将 Titanic 数据作为字符串读取,您可以使用以下代码:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1D2C-qdlqeW"
},
"outputs": [],
"source": [
"text = pathlib.Path(titanic_file_path).read_text()\n",
"lines = text.split('\\n')[1:-1]\n",
"\n",
"all_strings = [str()]*10\n",
"all_strings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9W4UeJYyHPx5"
},
"outputs": [],
"source": [
"features = tf.io.decode_csv(lines, record_defaults=all_strings) \n",
"\n",
"for f in features:\n",
" print(f\"type: {f.dtype.name}, shape: {f.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j8TaHSQFoQL4"
},
"source": [
"要使用实际类型解析它们,请创建相应类型的 `record_defaults` 列表: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rzUjR59yoUe1"
},
"outputs": [],
"source": [
"print(lines[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7sPTunxwoeWU"
},
"outputs": [],
"source": [
"titanic_types = [int(), str(), float(), int(), int(), float(), str(), str(), str(), str()]\n",
"titanic_types"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n3NlViCzoB7F"
},
"outputs": [],
"source": [
"features = tf.io.decode_csv(lines, record_defaults=titanic_types) \n",
"\n",
"for f in features:\n",
" print(f\"type: {f.dtype.name}, shape: {f.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m-LkTUTnpn2P"
},
"source": [
"注:在大批量行上调用 `tf.io.decode_csv` 比在单个 CSV 文本行上调用更有效。"
]
},
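{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration, using the `lines` and `titanic_types` defined above, a single vectorized call runs much faster than a Python loop over individual lines:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# One vectorized call over all of the lines:\n",
"start = time.time()\n",
"_ = tf.io.decode_csv(lines, record_defaults=titanic_types)\n",
"print(f'batched:  {time.time() - start:.3f}s')\n",
"\n",
"# One call per line pays the op-dispatch overhead once per row:\n",
"start = time.time()\n",
"for line in lines:\n",
"  _ = tf.io.decode_csv(line, record_defaults=titanic_types)\n",
"print(f'per-line: {time.time() - start:.3f}s')"
]
},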
{
"cell_type": "markdown",
"metadata": {
"id": "Yp1UItJmqGqw"
},
"source": [
"### `tf.data.experimental.CsvDataset`\n",
"\n",
"`tf.data.experimental.CsvDataset` 类提供了一个最小的 CSV `Dataset` 接口,没有 `tf.data.experimental.make_csv_dataset` 函数的便利功能:列标题解析、列类型推断、自动重排、文件交错。\n",
"\n",
"此构造函数使用 `record_defaults` 的方式与 `tf.io.decode_csv` 相同:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9OzZLp3krP-t"
},
"outputs": [],
"source": [
"simple_titanic = tf.data.experimental.CsvDataset(titanic_file_path, record_defaults=titanic_types, header=True)\n",
"\n",
"for example in simple_titanic.take(1):\n",
" print([e.numpy() for e in example])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_HBmfI-Ks7dw"
},
"source": [
"上面的代码基本等价于:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5O5d69Yq7gG"
},
"outputs": [],
"source": [
"def decode_titanic_line(line):\n",
" return tf.io.decode_csv(line, titanic_types)\n",
"\n",
"manual_titanic = (\n",
" # Load the lines of text\n",
" tf.data.TextLineDataset(titanic_file_path)\n",
" # Skip the header row.\n",
" .skip(1)\n",
" # Decode the line.\n",
" .map(decode_titanic_line)\n",
")\n",
"\n",
"for example in manual_titanic.take(1):\n",
" print([e.numpy() for e in example])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5R3ralsnt2AC"
},
"source": [
"#### 多个文件\n",
"\n",
"要使用 `tf.data.experimental.CsvDataset` 解析字体数据集,您首先需要确定 `record_defaults` 的列类型。首先检查一个文件的第一行:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3tlFOTjCvAI5"
},
"outputs": [],
"source": [
"font_line = pathlib.Path(font_csvs[0]).read_text().splitlines()[1]\n",
"print(font_line)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "etyGu8K_ySRz"
},
"source": [
"只有前两个字段是字符串,其余的都是整数或浮点数,通过计算逗号的个数可以得到特征总数:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "crgZZn0BzkSB"
},
"outputs": [],
"source": [
"num_font_features = font_line.count(',')+1\n",
"font_column_types = [str(), str()] + [float()]*(num_font_features-2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YeK2Pw540RNj"
},
"source": [
"`tf.data.experimental.CsvDataset` 构造函数可以获取输入文件列表,但会按顺序读取它们。CSV 列表中的第一个文件是 `AGENCY.csv`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_SvL5Uvl0r0N"
},
"outputs": [],
"source": [
"font_csvs[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EfAX3G8Xywy6"
},
"source": [
"因此,当您将文件列表传递给 `CsvDataset` 时,会首先读取 `AGENCY.csv` 中的记录:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gtr1E66VmBqj"
},
"outputs": [],
"source": [
"simple_font_ds = tf.data.experimental.CsvDataset(\n",
" font_csvs, \n",
" record_defaults=font_column_types, \n",
" header=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k750Mgq4yt_o"
},
"outputs": [],
"source": [
"for row in simple_font_ds.take(10):\n",
" print(row[0].numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NiqWKQV21FrE"
},
"source": [
"要交错多个文件,请使用 `Dataset.interleave`。\n",
"\n",
"这是一个包含 CSV 文件名的初始数据集: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t9dS3SNb23W8"
},
"outputs": [],
"source": [
"font_files = tf.data.Dataset.list_files(\"fonts/*.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNiLHMXpzHy5"
},
"source": [
"这会在每个周期重排文件名:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zNd-TYyNzIgg"
},
"outputs": [],
"source": [
"print('Epoch 1:')\n",
"for f in list(font_files)[:5]:\n",
" print(\" \", f.numpy())\n",
"print(' ...')\n",
"print()\n",
"\n",
"print('Epoch 2:')\n",
"for f in list(font_files)[:5]:\n",
" print(\" \", f.numpy())\n",
"print(' ...')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B0QB1PtU3WAN"
},
"source": [
"`interleave` 方法采用 `map_func`,它会为父 `Dataset`的每个元素创建一个子 `Dataset`。\n",
"\n",
"在这里,您要从文件数据集的每个元素创建一个 `tf.data.experimental.CsvDataset`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QWp4rH0Q4uPh"
},
"outputs": [],
"source": [
"def make_font_csv_ds(path):\n",
" return tf.data.experimental.CsvDataset(\n",
" path, \n",
" record_defaults=font_column_types, \n",
" header=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VxRGdLMB5nRF"
},
"source": [
"交错返回的 `Dataset` 通过循环遍历多个子 `Dataset` 来返回元素。请注意,下面的数据集如何在 `cycle_length=3` 三个字体文件中循环:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OePMNF_x1_Cc"
},
"outputs": [],
"source": [
"font_rows = font_files.interleave(make_font_csv_ds,\n",
" cycle_length=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UORIGWLy54-E"
},
"outputs": [],
"source": [
"fonts_dict = {'font_name':[], 'character':[]}\n",
"\n",
"for row in font_rows.take(10):\n",
" fonts_dict['font_name'].append(row[0].numpy().decode())\n",
" fonts_dict['character'].append(chr(row[2].numpy()))\n",
"\n",
"pd.DataFrame(fonts_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mkKZa_HX8zAm"
},
"source": [
"#### 性能\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8BtGHraUApdJ"
},
"source": [
"早些时候,有人注意到 `tf.io.decode_csv` 在一个批次字符串上运行时效率更高。\n",
"\n",
"当使用大批次时,可以利用这一事实来提高 CSV 加载性能(但请先尝试使用[缓存](#caching))。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d35zWMH7MDL1"
},
"source": [
"使用内置加载器 20,2048 个样本批次大约需要 17 秒。 "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ieUVAPryjpJS"
},
"outputs": [],
"source": [
"BATCH_SIZE=2048\n",
"fonts_ds = tf.data.experimental.make_csv_dataset(\n",
" file_pattern = \"fonts/*.csv\",\n",
" batch_size=BATCH_SIZE, num_epochs=1,\n",
" num_parallel_reads=100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MUC2KW4LkQIz"
},
"outputs": [],
"source": [
"%%time\n",
"for i,batch in enumerate(fonts_ds.take(20)):\n",
" print('.',end='')\n",
"\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5lhnh6rZEDS2"
},
"source": [
"将**批量文本行**传递给 `decode_csv` 运行速度更快,大约需要 5 秒:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4XbPZV1okVF9"
},
"outputs": [],
"source": [
"fonts_files = tf.data.Dataset.list_files(\"fonts/*.csv\")\n",
"fonts_lines = fonts_files.interleave(\n",
" lambda fname:tf.data.TextLineDataset(fname).skip(1), \n",
" cycle_length=100).batch(BATCH_SIZE)\n",
"\n",
"fonts_fast = fonts_lines.map(lambda x: tf.io.decode_csv(x, record_defaults=font_column_types))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "te9C2km-qO8W"
},
"outputs": [],
"source": [
"%%time\n",
"for i,batch in enumerate(fonts_fast.take(20)):\n",
" print('.',end='')\n",
"\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aebC1plsMeOi"
},
"source": [
"有关通过使用大批次提高 CSV 性能的另一个示例,请参阅[过拟合和欠拟合教程](../keras/overfit_and_underfit.ipynb)。\n",
"\n",
"这种方式可能有效,但请考虑其他选项,例如 `Dataset.cache` 和 `tf.data.Dataset.snapshot`,或者将您的数据重新编码为更简化的格式。"
]
}
],
"metadata": {
"colab": {
"name": "csv.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "xxx",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}