{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "Ic4_occAAiAT" }, "outputs": [], "source": [ "import os\n", "import logging\n", "\n", "# Set these before `import tensorflow` so TensorFlow and its native libraries\n", "# (gRPC, Abseil/glog) see them when they load.\n", "# TF_CPP_MIN_LOG_LEVEL: 0 = all, 1 = hide INFO, 2 = also hide WARNING, 3 = also hide ERROR.\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "# Quiet the start-up log messages of the gRPC and Abseil/glog native libraries.\n", "os.environ[\"GRPC_VERBOSITY\"] = \"ERROR\"\n", "os.environ[\"GLOG_minloglevel\"] = \"3\" # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL\n", "# Let TensorFlow allocate GPU memory on demand instead of reserving it all at start-up.\n", "# This must go through os.environ: `!export ...` only affects a throw-away subshell,\n", "# never this kernel process.\n", "os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\"\n", "\n", "import tensorflow as tf\n", "tf.get_logger().setLevel(logging.ERROR)\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "id": "ioaprt5q5US7" }, "outputs": [], "source": [ "# Copyright 2019 The TensorFlow Authors.\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "id": "yCl0eTNH5RS3" }, "outputs": [], "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": { "id": "ItXfxkxvosLH" }, "source": [ "# 电影评论文本分类" ] }, { "cell_type": "markdown", "metadata": { "id": "hKY4XMc9o8iB" }, "source": [ "
![]() | \n",
" ![]() | \n",
" ![]() | \n",
" ![]() | \n",
"
Model: \"sequential\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (Embedding) │ ? │ 0 (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout (Dropout) │ ? │ 0 (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ global_average_pooling1d │ ? │ 0 (unbuilt) │\n", "│ (GlobalAveragePooling1D) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_1 (Dropout) │ ? │ 0 (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (Dense) │ ? │ 0 (unbuilt) │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ global_average_pooling1d │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "│ (\u001b[38;5;33mGlobalAveragePooling1D\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ ? 
│ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# The final Dense(1) layer has no activation, so the model outputs one raw logit per example.\n", "model = tf.keras.Sequential([\n", " layers.Embedding(max_features + 1, embedding_dim),\n", " layers.Dropout(0.2),\n", " layers.GlobalAveragePooling1D(),\n", " layers.Dropout(0.2),\n", " layers.Dense(1)])\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "6PbKQ6mucuKL" }, "source": [ "层按顺序堆叠以构建分类器:\n", "\n", "1. 第一个层是 `Embedding` 层。此层采用整数编码的评论,并查找每个单词索引的嵌入向量。这些向量是通过模型训练学习到的。向量向输出数组增加了一个维度。得到的维度为:`(batch, sequence, embedding)`。要详细了解嵌入向量,请参阅[单词嵌入向量](https://tensorflow.google.cn/text/guide/word_embeddings)教程。\n", "2. 接下来,`GlobalAveragePooling1D` 将通过对序列维度求平均值来为每个样本返回一个定长输出向量。这允许模型以尽可能最简单的方式处理变长输入。\n", "3. 最后一层与单个输出结点密集连接。该 `Dense` 层没有激活函数,因此模型输出的是一个 logit(未经 Sigmoid 变换的原始分数),而不是概率。\n", "\n", "此外,在 `Embedding` 层之后和 `GlobalAveragePooling1D` 层之后各有一个 `Dropout(0.2)` 层,用于正则化以减轻过拟合。" ] }, { "cell_type": "markdown", "metadata": { "id": "L4EqVWg4-llM" }, "source": [ "### 损失函数与优化器\n", "\n", "模型训练需要一个损失函数和一个优化器。由于这是一个二元分类问题,并且模型输出的是 logit(没有激活函数的单一单元层),我们将使用 `losses.BinaryCrossentropy(from_logits=True)` 损失函数,由它在内部对 logit 应用 Sigmoid。同理,`BinaryAccuracy` 的阈值设为 `0.0`,因为 logit 为 0 时对应的概率恰好是 0.5。\n", "\n", "现在,配置模型以使用优化器和损失函数:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "Mr0GP-cQ-llN" }, "outputs": [], "source": [ "# The model outputs logits: the loss applies the sigmoid itself (from_logits=True),\n", "# and the accuracy threshold is 0.0 because sigmoid(0) = 0.5.\n", "model.compile(loss=losses.BinaryCrossentropy(from_logits=True),\n", " optimizer='adam',\n", " metrics=[tf.metrics.BinaryAccuracy(threshold=0.0)])" ] }, { "cell_type": "markdown", "metadata": { "id": "35jv_fzP-llU" }, "source": [ "### 训练模型\n", "\n", "将 `dataset` 对象传递给 fit 方法,对模型进行训练。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "tXSGrjWZ-llW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1729736653.040209 3211559 service.cc:146] XLA service 0x7ff5cc046510 initialized for platform CUDA (this does not guarantee that XLA will be used). 
Devices:\n", "I0000 00:00:1729736653.040274 3211559 service.cc:154] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6\n", "I0000 00:00:1729736653.040291 3211559 service.cc:154] StreamExecutor device (1): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m 62/625\u001b[0m \u001b[32m━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - binary_accuracy: 0.5149 - loss: 0.6927" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1729736655.192179 3211559 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 5ms/step - binary_accuracy: 0.5839 - loss: 0.6812 - val_binary_accuracy: 0.7276 - val_loss: 0.6142\n", "Epoch 2/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.7579 - loss: 0.5812 - val_binary_accuracy: 0.8058 - val_loss: 0.5011\n", "Epoch 3/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.8244 - loss: 0.4678 - val_binary_accuracy: 0.8306 - val_loss: 0.4291\n", "Epoch 4/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.8530 - loss: 0.3968 - val_binary_accuracy: 0.8352 - val_loss: 0.3904\n", "Epoch 5/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.8662 - loss: 0.3499 - val_binary_accuracy: 0.8526 - val_loss: 0.3592\n", "Epoch 6/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.8814 - loss: 0.3168 - val_binary_accuracy: 0.8552 - val_loss: 0.3425\n", "Epoch 7/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.8901 - loss: 0.2914 - val_binary_accuracy: 0.8474 - val_loss: 0.3385\n", "Epoch 8/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.9014 - loss: 0.2706 - val_binary_accuracy: 0.8564 - val_loss: 0.3247\n", "Epoch 9/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.9059 - loss: 0.2541 - val_binary_accuracy: 0.8594 - val_loss: 0.3166\n", "Epoch 10/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - binary_accuracy: 0.9119 - loss: 0.2395 - val_binary_accuracy: 0.8610 - val_loss: 0.3143\n" ] } ], "source": [ "# Keep the returned History object: its .history dict is used later to plot the training curves.\n", "epochs = 10\n", "history = model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=epochs)" ] }, { "cell_type": "markdown", "metadata": { "id": "9EEGuDVuzb5r" }, "source": [ "### 评估模型\n", "\n", "我们来看一下模型的性能如何。`model.evaluate` 会返回两个值:损失值(loss)(一个表示误差的数字,值越低越好)与准确率(accuracy)。" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "zOMKywn4zReN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m782/782\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - binary_accuracy: 0.8558 - loss: 0.3287\n", "Loss: 0.33292046189308167\n", "Accuracy: 0.8543599843978882\n" ] } ], "source": [ "# The model was compiled with one metric, so evaluate() yields the loss followed by the binary accuracy.\n", "loss, accuracy = model.evaluate(test_ds)\n", "\n", "print(\"Loss: \", loss)\n", "print(\"Accuracy: \", accuracy)" ] }, { "cell_type": "markdown", "metadata": { "id": "z1iEXVTR0Z2t" }, "source": [ "这种十分简单的方式在测试集上实现了约 85% 的准确率(验证集上约 86%)。" ] }, { "cell_type": "markdown", "metadata": { "id": 
"ldbQqCw2Xc1W" }, "source": [ "### 创建准确率和损失随时间变化的图表\n", "\n", "`model.fit()` 会返回包含一个字典的 `History` 对象。该字典包含训练过程中产生的所有信息:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "-YcvZsdvWfDf" }, "outputs": [ { "data": { "text/plain": [ "dict_keys(['binary_accuracy', 'loss', 'val_binary_accuracy', 'val_loss'])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "history_dict = history.history\n", "history_dict.keys()" ] }, { "cell_type": "markdown", "metadata": { "id": "1_CH32qJXruI" }, "source": [ "其中有四个条目:每个条目代表训练和验证过程中的一项监测指标。您可以使用这些指标来绘制用于比较的训练损失和验证损失图表,以及训练准确率和验证准确率图表:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "2SEMeQ5YXs8z" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "