{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "DweYe9FcbMK_" }, "outputs": [], "source": [ "##### Copyright 2018 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "AVV2e0XKbJeX", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "sZfSvVcDo6GQ" }, "source": [ "# 加载文本" ] }, { "cell_type": "markdown", "metadata": { "id": "giK0nMbZFnoR" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "dwlfPb11GH8J" }, "source": [ "本教程演示了两种加载和预处理文本的方法。\n", "\n", "- 首先,您将使用 Keras 效用函数和预处理层。这包括用于将数据转换为 `tf.data.Dataset` 的 `tf.keras.utils.text_dataset_from_directory` 和用于数据标准化、词例化和向量化的 `tf.keras.layers.TextVectorization`。如果您是 TensorFlow 新手,则应当从这些开始。\n", "- 然后,您将使用 `tf.data.TextLineDataset` 等较低级别的效用函数来加载文本文件,并使用 [TensorFlow Text](https://tensorflow.google.cn/text) API(如 `text.UnicodeScriptTokenizer` 和 `text.case_fold_utf8`)来预处理数据以实现粒度更细的控制。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sa6IKWvADqH7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!pip install \"tensorflow-text==2.11.*\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "baYFZMW_bJHh", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import collections\n", "import pathlib\n", "\n", "import tensorflow as tf\n", "\n", "from tensorflow.keras import layers\n", "from tensorflow.keras import losses\n", "from tensorflow.keras import utils\n", "from tensorflow.keras.layers import TextVectorization\n", "\n", "import tensorflow_datasets as tfds\n", "import tensorflow_text as tf_text" ] }, { "cell_type": "markdown", "metadata": { "id": "Az-d_K5_HQ5k" }, "source": [ "## 示例 1:预测 Stack Overflow 问题的标签\n", "\n", "作为第一个示例,您将从 Stack Overflow 下载一个编程问题的数据集。每个问题(*“How do I sort a dictionary by value?”*)都会添加一个标签(`Python`、`CSharp`、`JavaScript` 或 `Java`)。您的任务是开发一个模型来预测问题的标签。这是多类分类的一个示例,多类分类是一种重要且广泛适用的机器学习问题。" ] }, { "cell_type": "markdown", "metadata": { "id": "tjC3yLa5IjP7" }, "source": [ "### 下载并探索数据集\n", "\n", "首先,使用 `tf.keras.utils.get_file` 下载 Stack Overflow 数据集,然后探索目录结构:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8ELgzA6SHTuV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz'\n", "\n", "dataset_dir = utils.get_file(\n", " origin=data_url,\n", " untar=True,\n", " cache_dir='stack_overflow',\n", " cache_subdir='')\n", "\n", "dataset_dir = pathlib.Path(dataset_dir).parent" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIrPl5fUH2gb", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "list(dataset_dir.iterdir())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fEoV7YByJoWQ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_dir = dataset_dir/'train'\n", "list(train_dir.iterdir())" ] }, { "cell_type": "markdown", "metadata": { "id": "3mxAN17MhEh0" }, "source": [ "`train/csharp`、`train/java`、`train/python` 和 `train/javascript` 目录包含许多文本文件,每个文件都是一个 Stack Overflow 问题。\n", "\n", "打印示例文件并检查数据:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Go1vTSGdJu08", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "sample_file = train_dir/'python/1755.txt'\n", "\n", "with open(sample_file) as f:\n", " print(f.read())" ] }, { "cell_type": "markdown", "metadata": { "id": "deWBTkpJiO7D" }, "source": [ "### 加载数据集\n", "\n", "接下来,您将从磁盘加载数据并将其准备成适合训练的格式。为此,您将使用 `tf.keras.utils.text_dataset_from_directory` 效用函数来创建带标签的 `tf.data.Dataset`。如果您是 `tf.data` 新手,它是用于构建输入流水线的强大工具集合。(要了解更多信息,请参阅 [tf.data:构建 TensorFlow 输入流水线](../../guide/data.ipynb)指南。)\n", "\n", "`tf.keras.utils.text_dataset_from_directory` API 需要如下目录结构:\n", "\n", "```\n", "train/\n", "...csharp/\n", "......1.txt\n", "......2.txt\n", "...java/\n", "......1.txt\n", "......2.txt\n", 
"...javascript/\n", "......1.txt\n", "......2.txt\n", "...python/\n", "......1.txt\n", "......2.txt\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "Dyl6JTAjlbQV" }, "source": [ "运行机器学习实验时,最佳做法是将数据集拆成三份:[训练](https://developers.google.com/machine-learning/glossary#training_set)、[验证](https://developers.google.com/machine-learning/glossary#validation_set)和[测试](https://developers.google.com/machine-learning/glossary#test-set)。\n", "\n", "Stack Overflow 数据集已经拆分为训练集和测试集,但缺少验证集。\n", "\n", "通过使用 `tf.keras.utils.text_dataset_from_directory` 并将 `validation_split` 设置为 `0.2`(即 20%),使用训练数据的 80:20 拆分创建验证集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qqyliMw8N-az", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "batch_size = 32\n", "seed = 42\n", "\n", "raw_train_ds = utils.text_dataset_from_directory(\n", " train_dir,\n", " batch_size=batch_size,\n", " validation_split=0.2,\n", " subset='training',\n", " seed=seed)" ] }, { "cell_type": "markdown", "metadata": { "id": "DMI_gPLfloD7" }, "source": [ "正如前面的单元输出所示,训练文件夹中有 8,000 个样本,您将使用其中的 80%(即 6,400 个)进行训练。稍后您将学习到,可以通过将 `tf.data.Dataset` 直接传递给 `Model.fit` 来训练模型。\n", "\n", "首先,遍历数据集并打印出一些样本来感受一下数据。\n", "\n", "注:为了增加分类问题的难度,数据集作者将编程问题中出现的单词 *Python*、*CSharp*、*JavaScript* 或 *Java* 替换为 *blank* 一词。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_JMTyZ6Glt_C", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for text_batch, label_batch in raw_train_ds.take(1):\n", " for i in range(10):\n", " print(\"Question: \", text_batch.numpy()[i])\n", " print(\"Label:\", label_batch.numpy()[i])" ] }, { "cell_type": "markdown", "metadata": { "id": "jCZGl4Q5l2sS" }, "source": [ "标签为 `0`、`1`、`2` 或 `3`。要查看其中哪些对应于哪个字符串标签,可以检查数据集上的 `class_names` 属性:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gIpCS7YjmGkj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for i, label in enumerate(raw_train_ds.class_names):\n", " print(\"Label\", i, \"corresponds to\", label)" ] }, { "cell_type": "markdown", "metadata": { "id": "oUsdn-37qol9" }, "source": [ "接下来,您将使用 `tf.keras.utils.text_dataset_from_directory` 创建验证集和测试集。您将使用训练集中剩余的 1,600 条评论进行验证。\n", "\n", "注:使用 `tf.keras.utils.text_dataset_from_directory` 的 `validation_split` 和 `subset` 参数时,请确保要么指定随机种子,要么传递 `shuffle=False`,这样验证拆分和训练拆分就不会重叠。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x7m6sCWJQuYt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create a validation set.\n", "raw_val_ds = utils.text_dataset_from_directory(\n", " train_dir,\n", " batch_size=batch_size,\n", " validation_split=0.2,\n", " subset='validation',\n", " seed=seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BXMZc7fMQwKE", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "test_dir = dataset_dir/'test'\n", "\n", "# Create a test set.\n", "raw_test_ds = utils.text_dataset_from_directory(\n", " test_dir,\n", " batch_size=batch_size)" ] }, { "cell_type": "markdown", "metadata": { "id": "Xdt-ATrGRGDL" }, "source": [ "### 准备用于训练的数据集" ] }, { "cell_type": "markdown", "metadata": { "id": "N6fRti45Rlj8" }, "source": [ "接下来,您将使用 `tf.keras.layers.TextVectorization` 层对数据进行标准化、词例化和向量化。\n", "\n", "- *标准化*是指预处理文本,通常是移除标点符号或 HTML 元素以简化数据集。\n", "- *词例化*是指将字符串拆分为词例(例如,通过按空格分割将一个句子拆分为各个单词)。\n", "- *向量化*是指将词例转换为编号,以便将它们输入到神经网络中。\n", "\n", "所有这些任务都可以通过这一层来完成。(您可以在 `tf.keras.layers.TextVectorization` API 文档中了解有关这些内容的更多信息。)\n", "\n", 
"请注意:\n", "\n", "- 默认标准化会将文本转换为小写并移除标点符号 (`standardize='lower_and_strip_punctuation'`)。\n", "- 默认分词器会按空格分割 (`split='whitespace'`)。\n", "- 默认向量化模式为 `'int'` (`output_mode='int'`)。这会输出整数索引(每个词例一个)。此模式可用于构建考虑词序的模型。您还可以使用其他模式(例如 `'binary'`)来构建[词袋](https://developers.google.com/machine-learning/glossary#bag-of-words)模型。\n", "\n", "您将使用 `TextVectorization` 构建两个模型来详细了解标准化、词例化和向量化:\n", "\n", "- 首先,您将使用 `'binary'` 向量化模式来构建词袋模型。\n", "- 随后,您将使用具有 1D ConvNet 的 `'int'` 模式。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "voaC43rZR0jc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "VOCAB_SIZE = 10000\n", "\n", "binary_vectorize_layer = TextVectorization(\n", " max_tokens=VOCAB_SIZE,\n", " output_mode='binary')" ] }, { "cell_type": "markdown", "metadata": { "id": "ifDPFxuf2Hfz" }, "source": [ "对于 `'int'` 模式,除了最大词汇量之外,您还需要设置显式最大序列长度 (`MAX_SEQUENCE_LENGTH`),这会导致层将序列精确地填充或截断为 `output_sequence_length` 值:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XWsY01Zl2aRe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "MAX_SEQUENCE_LENGTH = 250\n", "\n", "int_vectorize_layer = TextVectorization(\n", " max_tokens=VOCAB_SIZE,\n", " output_mode='int',\n", " output_sequence_length=MAX_SEQUENCE_LENGTH)" ] }, { "cell_type": "markdown", "metadata": { "id": "ts6h9b5atD-Y" }, "source": [ "接下来,调用 `TextVectorization.adapt` 以使预处理层的状态适合数据集。这会使模型构建字符串到整数的索引。\n", "\n", "注:在调用 `TextVectorization.adapt` 时请务必仅使用您的训练数据(使用测试集会泄漏信息)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yTXsdDEqSf9e", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.\n", "train_text = raw_train_ds.map(lambda text, labels: text)\n", "binary_vectorize_layer.adapt(train_text)\n", "int_vectorize_layer.adapt(train_text)" ] }, { "cell_type": "markdown", "metadata": { "id": "XKVO6Jg7Sls0" }, "source": [ "打印使用这些层预处理数据的结果:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RngfPyArSsvM", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def binary_vectorize_text(text, label):\n", " text = tf.expand_dims(text, -1)\n", " return binary_vectorize_layer(text), label" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_1W54wf0LhQ0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def int_vectorize_text(text, label):\n", " text = tf.expand_dims(text, -1)\n", " return int_vectorize_layer(text), label" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vi_sElMiSmXe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Retrieve a batch (of 32 reviews and labels) from the dataset.\n", "text_batch, label_batch = next(iter(raw_train_ds))\n", "first_question, first_label = text_batch[0], label_batch[0]\n", "print(\"Question\", first_question)\n", "print(\"Label\", first_label)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UGukZoYv2v3v", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(\"'binary' vectorized question:\",\n", " binary_vectorize_text(first_question, first_label)[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Lu07FsIw2yH5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(\"'int' vectorized question:\",\n", " int_vectorize_text(first_question, first_label)[0])" ] }, { "cell_type": "markdown", "metadata": { "id": 
"wgjeF9PdS7tN" }, "source": [ "如上所示,`TextVectorization` 的 `'binary'` 模式返回一个数组,表示哪些词例在输入中至少存在一次,而 `'int'` 模式将每个词例替换为一个整数,从而保留它们的顺序。\n", "\n", "您可以通过在层上调用 `TextVectorization.get_vocabulary` 来查找每个整数对应的词例(字符串):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WpBnTZilS8wt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(\"1289 ---> \", int_vectorize_layer.get_vocabulary()[1289])\n", "print(\"313 ---> \", int_vectorize_layer.get_vocabulary()[313])\n", "print(\"Vocabulary size: {}\".format(len(int_vectorize_layer.get_vocabulary())))" ] }, { "cell_type": "markdown", "metadata": { "id": "0kHgPE_YwHvp" }, "source": [ "差不多可以训练您的模型了。\n", "\n", "作为最后的预处理步骤,将之前创建的 `TextVectorization` 层应用于训练集、验证集和测试集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "46LeHmnD55wJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "binary_train_ds = raw_train_ds.map(binary_vectorize_text)\n", "binary_val_ds = raw_val_ds.map(binary_vectorize_text)\n", "binary_test_ds = raw_test_ds.map(binary_vectorize_text)\n", "\n", "int_train_ds = raw_train_ds.map(int_vectorize_text)\n", "int_val_ds = raw_val_ds.map(int_vectorize_text)\n", "int_test_ds = raw_test_ds.map(int_vectorize_text)" ] }, { "cell_type": "markdown", "metadata": { "id": "NHuAF8hYfP5Z" }, "source": [ "### 配置数据集以提高性能\n", "\n", "以下是加载数据时应该使用的两种重要方法,以确保 I/O 不会阻塞。\n", "\n", "- 从磁盘加载后,`Dataset.cache` 会将数据保存在内存中。这将确保数据集在训练模型时不会成为瓶颈。如果您的数据集太大而无法放入内存,也可以使用此方法创建高性能的磁盘缓存,这比许多小文件的读取效率更高。\n", "- `Dataset.prefetch` 会在训练时将数据预处理和模型执行重叠。\n", "\n", "您可以在[使用 tf.data API 提升性能](../../guide/data_performance.ipynb)指南的*预提取*部分中详细了解这两种方法,以及如何将数据缓存到磁盘。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PabA9DFIfSz7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "def configure_dataset(dataset):\n", " return dataset.cache().prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8GcJLvb3JH0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "binary_train_ds = configure_dataset(binary_train_ds)\n", "binary_val_ds = configure_dataset(binary_val_ds)\n", "binary_test_ds = configure_dataset(binary_test_ds)\n", "\n", "int_train_ds = configure_dataset(int_train_ds)\n", "int_val_ds = configure_dataset(int_val_ds)\n", "int_test_ds = configure_dataset(int_test_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "NYGb7z_bfpGm" }, "source": [ "### 训练模型。\n", "\n", "是时候创建您的神经网络了。\n", "\n", "对于 `'binary'` 向量化数据,定义一个简单的词袋线性模型,然后对其进行配置和训练:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2q8iAU-VMzaN", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "binary_model = tf.keras.Sequential([layers.Dense(4)])\n", "\n", "binary_model.compile(\n", " loss=losses.SparseCategoricalCrossentropy(from_logits=True),\n", " optimizer='adam',\n", " metrics=['accuracy'])\n", "\n", "history = binary_model.fit(\n", " binary_train_ds, validation_data=binary_val_ds, epochs=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "EwidD-SwNIkz" }, "source": [ "接下来,您将使用 `'int'` 向量化层来构建 1D ConvNet:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5ztw2XH_LbVz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def create_model(vocab_size, num_labels):\n", " model = tf.keras.Sequential([\n", " layers.Embedding(vocab_size, 64, mask_zero=True),\n", " layers.Conv1D(64, 5, padding=\"valid\", 
activation=\"relu\", strides=2),\n", " layers.GlobalMaxPooling1D(),\n", " layers.Dense(num_labels)\n", " ])\n", " return model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s9rG1cFRL31Z", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# `vocab_size` is `VOCAB_SIZE + 1` since `0` is used additionally for padding.\n", "int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=4)\n", "int_model.compile(\n", " loss=losses.SparseCategoricalCrossentropy(from_logits=True),\n", " optimizer='adam',\n", " metrics=['accuracy'])\n", "history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)" ] }, { "cell_type": "markdown", "metadata": { "id": "x3J9Eeuv97zE" }, "source": [ "比较两个模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N8ViDXw99v_u", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(\"Linear model on binary vectorized data:\")\n", "print(binary_model.summary())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P9BOeoCwborD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(\"ConvNet model on int vectorized data:\")\n", "print(int_model.summary())" ] }, { "cell_type": "markdown", "metadata": { "id": "zYYW9tUdCtTy" }, "source": [ "在测试数据上评估两个模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5dTc4nZqf7fK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)\n", "int_loss, int_accuracy = int_model.evaluate(int_test_ds)\n", "\n", "print(\"Binary model accuracy: {:2.2%}\".format(binary_accuracy))\n", "print(\"Int model accuracy: {:2.2%}\".format(int_accuracy))" ] }, { "cell_type": "markdown", "metadata": { "id": "F9dhj8Hey9DS" }, "source": [ "注:此示例数据集代表了一个相当简单的分类问题。更复杂的数据集和问题会在预处理策略和模型架构上带来微妙但显著的差异。务必尝试不同的超参数和周期来比较各种方法。" ] }, { "cell_type": "markdown", "metadata": { "id": "h9GaXTsIgP-3" }, "source": [ "### 导出模型\n", "\n", "在上面的代码中,您在向模型馈送文本之前对数据集应用了 `tf.keras.layers.TextVectorization`。如果您想让模型能够处理原始字符串(例如,为了简化部署),您可以在模型中包含 `TextVectorization` 层。\n", "\n", "为此,您可以使用刚刚训练的权重创建一个新模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_bRe3KX8gRCX", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "export_model = tf.keras.Sequential(\n", " [binary_vectorize_layer, binary_model,\n", " layers.Activation('sigmoid')])\n", "\n", "export_model.compile(\n", " loss=losses.SparseCategoricalCrossentropy(from_logits=False),\n", " optimizer='adam',\n", " metrics=['accuracy'])\n", "\n", "# Test it with `raw_test_ds`, which yields raw strings\n", "loss, accuracy = export_model.evaluate(raw_test_ds)\n", "print(\"Accuracy: {:2.2%}\".format(accuracy))" ] }, { "cell_type": "markdown", "metadata": { "id": "m2eqTVBP4DUN" }, "source": [ "现在,您的模型可以将原始字符串作为输入,并使用 `Model.predict` 预测每个标签的得分。定义一个函数来查找得分最高的标签:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GU53uRXz45iO", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def get_string_labels(predicted_scores_batch):\n", " predicted_int_labels = tf.math.argmax(predicted_scores_batch, axis=1)\n", " predicted_labels = tf.gather(raw_train_ds.class_names, predicted_int_labels)\n", " return predicted_labels" ] }, { "cell_type": "markdown", "metadata": { "id": "yqnWc7Nn5eou" }, "source": [ "### 在新数据上运行推断" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BOR2MupW1_zS", "vscode": { "languageId": 
"python" } }, "outputs": [], "source": [ "inputs = [\n", " \"how do I extract keys from a dict into a list?\", # 'python'\n", " \"debug public static void main(string[] args) {...}\", # 'java'\n", "]\n", "predicted_scores = export_model.predict(inputs)\n", "predicted_labels = get_string_labels(predicted_scores)\n", "for input, label in zip(inputs, predicted_labels):\n", " print(\"Question: \", input)\n", " print(\"Predicted label: \", label.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "0QDVfii_4slI" }, "source": [ "将文本预处理逻辑包含在模型中后,您可以导出用于生产的模型,从而简化部署并降低[训练/测试偏差](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew)的可能性。\n", "\n", "在选择应用 `tf.keras.layers.TextVectorization` 层的位置时,需要注意性能差异。在模型之外使用它可以让您在 GPU 上训练时进行异步 CPU 处理和数据缓冲。因此,如果您在 GPU 上训练模型,您应该在开发模型时使用此选项以获得最佳性能,然后在准备好部署时进行切换,在模型中包含 `TextVectorization` 层。\n", "\n", "请参阅[保存和加载模型](../keras/save_and_load.ipynb)教程,详细了解如何保存模型。" ] }, { "cell_type": "markdown", "metadata": { "id": "p4cvuFzavTRy" }, "source": [ "## 例 2:预测《伊利亚特》翻译的作者\n" ] }, { "cell_type": "markdown", "metadata": { "id": "fOlJ22508RIe" }, "source": [ "下面提供了一个使用 `tf.data.TextLineDataset` 从文本文件中加载样本,以及使用 [TensorFlow Text](https://tensorflow.google.cn/text) 预处理数据的示例。您将使用同一作品(荷马的《伊利亚特》)的三种不同英语翻译,训练一个模型来识别给定单行文本的译者。" ] }, { "cell_type": "markdown", "metadata": { "id": "-pCgKbOSk7kU" }, "source": [ "### 下载并探索数据集\n", "\n", "三个译本的文本来自:\n", "\n", "- [William Cowper](https://en.wikipedia.org/wiki/William_Cowper):[文本](https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt)\n", "- [Edward, Earl of Derby](https://en.wikipedia.org/wiki/Edward_Smith-Stanley,_14th_Earl_of_Derby):[文本](https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt)\n", "- [Samuel Butler](https://en.wikipedia.org/wiki/Samuel_Butler_%28novelist%29):[文本](https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt)\n", "\n", "本教程中使用的文本文件经历了一些典型的预处理任务,例如移除文档页眉和页脚、行号和章节标题。\n", "\n", "将这些稍微改动过的文件下载到本地:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4YlKQthEYlFw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'\n", "FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']\n", "\n", "for name in FILE_NAMES:\n", " text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)\n", "\n", "parent_dir = pathlib.Path(text_dir).parent\n", "list(parent_dir.iterdir())" ] }, { "cell_type": "markdown", "metadata": { "id": "M8PHK5J_cXE5" }, "source": [ "### 加载数据集\n", "\n", "以前,使用 `tf.keras.utils.text_dataset_from_directory` 时,文件的所有内容都会被视为单个样本。在这里,您将使用 `tf.data.TextLineDataset`,它旨在从文本文件创建 `tf.data.Dataset`,其中每个样本都是原始文件中的一行文本。`TextLineDataset` 对于主要基于行的文本数据(例如,诗歌或错误日志)非常有用。\n", "\n", "遍历这些文件,将每个文件加载到自己的数据集中。每个样本都需要单独加标签,因此请使用 `Dataset.map` 为每个样本应用标签添加器功能。这将遍历数据集中的每个样本,同时返回 (`example, label`) 对。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YIIWIdPXgk7I", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def labeler(example, index):\n", " return example, tf.cast(index, tf.int64)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Ajx7AmZnEg3", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "labeled_data_sets = []\n", "\n", "for i, file_name in enumerate(FILE_NAMES):\n", " lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))\n", " labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))\n", " 
labeled_data_sets.append(labeled_dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "wPOsVK1e9NGM" }, "source": [ "接下来,您将使用 `Dataset.concatenate` 将这些带标签的数据集组合到一个数据集中,并使用 `Dataset.shuffle` 打乱其顺序:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6jAeYkTIi9-2", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "BUFFER_SIZE = 50000\n", "BATCH_SIZE = 64\n", "VALIDATION_SIZE = 5000" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qd544E-Sh63L", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "all_labeled_data = labeled_data_sets[0]\n", "for labeled_dataset in labeled_data_sets[1:]:\n", " all_labeled_data = all_labeled_data.concatenate(labeled_dataset)\n", "\n", "all_labeled_data = all_labeled_data.shuffle(\n", " BUFFER_SIZE, reshuffle_each_iteration=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "r4JEHrJXeG5k" }, "source": [ "像以前一样打印出几个样本。数据集尚未经过批处理,因此 `all_labeled_data` 中的每个条目都对应一个数据点:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gywKlN0xh6u5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for text, label in all_labeled_data.take(10):\n", " print(\"Sentence: \", text.numpy())\n", " print(\"Label:\", label.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "5rrpU2_sfDh0" }, "source": [ "### 准备用于训练的数据集\n", "\n", "现在,将不再使用 `tf.keras.layers.TextVectorization` 来预处理文本数据集,而是使用 TensorFlow Text API 对数据进行标准化和词例化、构建词汇表并使用 `tf.lookup.StaticVocabularyTable` 将词例映射到整数以馈送给模型。(详细了解 [TensorFlow Text](https://tensorflow.google.cn/text))。\n", "\n", "定义一个将文本转换为小写并对其进行词例化的函数:\n", "\n", "- TensorFlow Text 提供各种分词器。在此示例中,您将使用 `text.UnicodeScriptTokenizer` 对数据集进行词例化。\n", "- 您将使用 `Dataset.map` 将词例化应用于数据集。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v4DpQW-Y12rm", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tokenizer = tf_text.UnicodeScriptTokenizer()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pz8xEj0ugu51", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def tokenize(text, unused_label):\n", " lower_case = tf_text.case_fold_utf8(text)\n", " return tokenizer.tokenize(lower_case)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vzUrAzOq31QL", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tokenized_ds = all_labeled_data.map(tokenize)" ] }, { "cell_type": "markdown", "metadata": { "id": "jx4Q2i8XLV7o" }, "source": [ "您可以遍历数据集并打印出一些词例化的样本:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g2mkWri7LiGq", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for text_batch in tokenized_ds.take(5):\n", " print(\"Tokens: \", text_batch.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "JPd4PsskJ_Xt" }, "source": [ "接下来,您将通过按频率对词例进行排序并保留顶部 `VOCAB_SIZE` 词例来构建词汇表:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YkHtbGnDh6mg", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tokenized_ds = configure_dataset(tokenized_ds)\n", "\n", "vocab_dict = collections.defaultdict(lambda: 0)\n", "for toks in tokenized_ds.as_numpy_iterator():\n", " for tok in toks:\n", " vocab_dict[tok] += 1\n", "\n", "vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)\n", "vocab = [token for token, count in vocab]\n", "vocab = vocab[:VOCAB_SIZE]\n", "vocab_size = len(vocab)\n", "print(\"Vocab size: \", vocab_size)\n", 
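"# Note: the vocabulary entries printed below are UTF-8 byte strings produced by the tokenizer;\n", "# they will later serve as the keys of the `tf.lookup.StaticVocabularyTable`.\n",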
"print(\"First five vocab entries:\", vocab[:5])" ] }, { "cell_type": "markdown", "metadata": { "id": "PyKSsaNAKi17" }, "source": [ "要将词例转换为整数,请使用 `vocab` 集创建 `tf.lookup.StaticVocabularyTable`。您将词例映射到 [`2`, `vocab_size + 2`] 范围内的整数。与 `TextVectorization` 层一样,保留 `0` 表示填充,保留 `1` 表示词汇表外 (OOV) 词例。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kCBo2yFHD7y6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "keys = vocab\n", "values = range(2, len(vocab) + 2) # Reserve `0` for padding, `1` for OOV tokens.\n", "\n", "init = tf.lookup.KeyValueTensorInitializer(\n", " keys, values, key_dtype=tf.string, value_dtype=tf.int64)\n", "\n", "num_oov_buckets = 1\n", "vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z5F-EiBpOADE" }, "source": [ "最后,定义一个函数来使用分词器和查找表对数据集进行标准化、词例化和向量化:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HcIQ7LOTh6eT", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def preprocess_text(text, label):\n", " standardized = tf_text.case_fold_utf8(text)\n", " tokenized = tokenizer.tokenize(standardized)\n", " vectorized = vocab_table.lookup(tokenized)\n", " return vectorized, label" ] }, { "cell_type": "markdown", "metadata": { "id": "v6S5Qyabi-vo" }, "source": [ "您可以在单个样本上尝试此操作并打印输出:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jgxPZaxUuTbk", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "example_text, example_label = next(iter(all_labeled_data))\n", "print(\"Sentence: \", example_text.numpy())\n", "vectorized_text, example_label = preprocess_text(example_text, example_label)\n", "print(\"Vectorized sentence: \", vectorized_text.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "p9qHM0v8k_Mg" }, "source": [ "现在,使用 `Dataset.map` 在数据集上运行预处理函数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KmQVsAgJ-RM0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "all_encoded_data = all_labeled_data.map(preprocess_text)" ] }, { "cell_type": "markdown", "metadata": { "id": "_YZToSXSm0qr" }, "source": [ "### 将数据集拆分为训练集和测试集\n" ] }, { "cell_type": "markdown", "metadata": { "id": "itxIJwkrUXgv" }, "source": [ "Keras `TextVectorization` 层还会对向量化数据进行批处理和填充。填充是必需的,因为批次内的样本需要具有相同的大小和形状,但这些数据集中的样本并非全部相同 – 每行文本具有不同数量的单词。\n", "\n", "`tf.data.Dataset` 支持拆分和填充批次数据集:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "r-rmbijQh6bf", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)\n", "validation_data = all_encoded_data.take(VALIDATION_SIZE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qTP0IwHBCn0Q", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_data = train_data.padded_batch(BATCH_SIZE)\n", "validation_data = validation_data.padded_batch(BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": { "id": "m-wmFq8uW1zS" }, "source": [ "现在,`validation_data` 和 `train_data` 不是 (`example, label`) 对的集合,而是批次的集合。每个批次都是一对表示为数组的(*许多样本*、*许多标签*)。\n", "\n", "为了说明这一点:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kMslWfuwoqpB", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "sample_text, sample_labels = next(iter(validation_data))\n", "print(\"Text batch shape: \", sample_text.shape)\n", "print(\"Label batch shape: \", 
sample_labels.shape)\n", "print(\"First text example: \", sample_text[0])\n", "print(\"First label example: \", sample_labels[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "UI4I6_Sa0vWu" }, "source": [ "由于您将 `0` 用于填充,将 `1` 用于词汇外 (OOV) 词例,词汇量增加了两倍:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u21LlkO8QGRX", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "vocab_size += 2" ] }, { "cell_type": "markdown", "metadata": { "id": "h44Ox11OYLP-" }, "source": [ "像以前一样配置数据集以提高性能:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BpT0b_7mYRXV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_data = configure_dataset(train_data)\n", "validation_data = configure_dataset(validation_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "K8SUhGFNsmRi" }, "source": [ "### 训练模型\n", "\n", "您可以像以前一样在此数据集上训练模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QJgI1pow2YR9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = create_model(vocab_size=vocab_size, num_labels=3)\n", "\n", "model.compile(\n", " optimizer='adam',\n", " loss=losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])\n", "\n", "history = model.fit(train_data, validation_data=validation_data, epochs=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KTPCYf_Jh6TH", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "loss, accuracy = model.evaluate(validation_data)\n", "\n", "print(\"Loss: \", loss)\n", "print(\"Accuracy: {:2.2%}\".format(accuracy))" ] }, { "cell_type": "markdown", "metadata": { "id": "_knIsO-r4pHb" }, "source": [ "### 导出模型" ] }, { "cell_type": "markdown", "metadata": { "id": "FEuMLJA_Xiwo" }, "source": [ "为了使模型能够将原始字符串作为输入,您将创建一个 Keras `TextVectorization` 层,该层执行与您的自定义预处理函数相同的步骤。由于您已经训练了一个词汇表,可以使用 `TextVectorization.set_vocabulary`(而不是 `TextVectorization.adapt`)来训练一个新词汇表。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_ODkRXbk6aHb", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "preprocess_layer = TextVectorization(\n", " max_tokens=vocab_size,\n", " standardize=tf_text.case_fold_utf8,\n", " split=tokenizer.tokenize,\n", " output_mode='int',\n", " output_sequence_length=MAX_SEQUENCE_LENGTH)\n", "\n", "preprocess_layer.set_vocabulary(vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G-Cvd27y4qwt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "export_model = tf.keras.Sequential(\n", " [preprocess_layer, model,\n", " layers.Activation('sigmoid')])\n", "\n", "export_model.compile(\n", " loss=losses.SparseCategoricalCrossentropy(from_logits=False),\n", " optimizer='adam',\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pyg0B4zsc-UD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create a test dataset of raw strings.\n", "test_ds = all_labeled_data.take(VALIDATION_SIZE).batch(BATCH_SIZE)\n", "test_ds = configure_dataset(test_ds)\n", "\n", "loss, accuracy = export_model.evaluate(test_ds)\n", "\n", "print(\"Loss: \", loss)\n", "print(\"Accuracy: {:2.2%}\".format(accuracy))" ] }, { "cell_type": "markdown", "metadata": { "id": "o6Mm0Y9QYQwE" }, "source": [ "正如预期的那样,编码验证集上的模型和原始验证集上的导出模型的损失和准确率相同。" ] }, { "cell_type": "markdown", "metadata": { "id": "Stk2BP8GE-qo" }, "source": [ "### 在新数据上运行推断" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "-w1fQGJPD2Yh", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "inputs = [\n", " \"Join'd to th' Ionians with their flowing robes,\", # Label: 1\n", " \"the allies, and his armour flashed about him so that he seemed to all\", # Label: 2\n", " \"And with loud clangor of his arms he fell.\", # Label: 0\n", "]\n", "\n", "predicted_scores = export_model.predict(inputs)\n", "predicted_labels = tf.math.argmax(predicted_scores, axis=1)\n", "\n", "for input, label in zip(inputs, predicted_labels):\n", " print(\"Question: \", input)\n", " print(\"Predicted label: \", label.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "9eA8TVdnA-3L" }, "source": [ "## 使用 TensorFlow Datasets (TFDS) 下载更多数据集\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2QFSxfZ3Vqsn" }, "source": [ "您可以从 [TensorFlow Datasets](https://tensorflow.google.cn/datasets/catalog/overview) 下载更多数据集。\n", "\n", "在此示例中,您将使用 [IMDB Large Movie Review Dataset](https://tensorflow.google.cn/datasets/catalog/imdb_reviews) 来训练情感分类模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NzC65LOaVw0B", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Training set.\n", "train_ds = tfds.load(\n", " 'imdb_reviews',\n", " split='train[:80%]',\n", " batch_size=BATCH_SIZE,\n", " shuffle_files=True,\n", " as_supervised=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XKGkgPBkFh0k", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Validation set.\n", "val_ds = tfds.load(\n", " 'imdb_reviews',\n", " split='train[80%:]',\n", " batch_size=BATCH_SIZE,\n", " shuffle_files=True,\n", " as_supervised=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "BQjf3YZAb5Ne" }, "source": [ "打印几个样本:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bq1w8MnfWt2C", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for review_batch, label_batch in val_ds.take(1):\n", " for i in range(5):\n", " print(\"Review: \", review_batch[i].numpy())\n", " print(\"Label: \", label_batch[i].numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "q-lVaukyb75k" }, "source": [ "您现在可以像以前一样预处理数据并训练模型。\n", "\n", "注:您将对模型使用 `tf.keras.losses.BinaryCrossentropy` 而不是 `tf.keras.losses.SparseCategoricalCrossentropy`,因为这是一个二元分类问题。" ] }, { "cell_type": "markdown", "metadata": { "id": "ciz2CxAsZw3Z" }, "source": [ "### 准备用于训练的数据集" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UzT_t9ihZLH4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "vectorize_layer = TextVectorization(\n", " max_tokens=VOCAB_SIZE,\n", " output_mode='int',\n", " output_sequence_length=MAX_SEQUENCE_LENGTH)\n", "\n", "# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.\n", "train_text = train_ds.map(lambda text, labels: text)\n", "vectorize_layer.adapt(train_text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zz-Xrd_ZZ4tB", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def vectorize_text(text, label):\n", " text = tf.expand_dims(text, -1)\n", " return vectorize_layer(text), label" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ycn0Itd6g5aF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = train_ds.map(vectorize_text)\n", "val_ds = val_ds.map(vectorize_text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": 
{ "id": "jc11jQTlZ5lj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Configure datasets for performance as before.\n", "train_ds = configure_dataset(train_ds)\n", "val_ds = configure_dataset(val_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "SQzoYkaGZ82Z" }, "source": [ "### 创建、配置和训练模型" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B9IOTLkyZ-a7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=1)\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xLnDs5dhaBAk", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model.compile(\n", " loss=losses.BinaryCrossentropy(from_logits=True),\n", " optimizer='adam',\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rq59QpNzaDMa", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "history = model.fit(train_ds, validation_data=val_ds, epochs=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gCMWCEtyaEbR", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "loss, accuracy = model.evaluate(val_ds)\n", "\n", "print(\"Loss: \", loss)\n", "print(\"Accuracy: {:2.2%}\".format(accuracy))" ] }, { "cell_type": "markdown", "metadata": { "id": "jGtqLXVnaaFy" }, "source": [ "### 导出模型" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yE9WZARZaZr1", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "export_model = tf.keras.Sequential(\n", " [vectorize_layer, model,\n", " layers.Activation('sigmoid')])\n", "\n", "export_model.compile(\n", " loss=losses.SparseCategoricalCrossentropy(from_logits=False),\n", " optimizer='adam',\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bhF8tDH-afoC", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# 0 --> negative review\n", "# 1 --> positive review\n", "inputs = [\n", " \"This is a fantastic movie.\",\n", " \"This is a bad movie.\",\n", " \"This movie was so bad that it was good.\",\n", " \"I will never say yes to watching this movie.\",\n", "]\n", "\n", "predicted_scores = export_model.predict(inputs)\n", "predicted_labels = [int(round(x[0])) for x in predicted_scores]\n", "\n", "for input, label in zip(inputs, predicted_labels):\n", " print(\"Question: \", input)\n", " print(\"Predicted label: \", label)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1KSXDFPWiPN" }, "source": [ "## 结论\n", "\n", "本教程演示了几种加载和预处理文本的方法。接下来,您可以探索其他文本预处理 [TensorFlow Text](https://tensorflow.google.cn/text) 教程,例如:\n", "\n", "- [使用 TF Text 进行 BERT 预处理](https://tensorflow.google.cn/text/guide/bert_preprocessing_guide)\n", "- [使用 TF Text 进行词例化](https://tensorflow.google.cn/text/guide/tokenizers)\n", "- [子词分词器](https://tensorflow.google.cn/text/guide/subwords_tokenizer)\n", "\n", "此外,您还可以在 [TensorFlow Datasets](https://tensorflow.google.cn/datasets/catalog/overview) 上找到新的数据集。而且,要详细了解 `tf.data`,请查看有关[构建输入流水线](../../guide/data.ipynb)的指南。" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "text.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }