{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "wJcYs_ERTnnI" }, "outputs": [], "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "HMUDt0CiUJk9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# 迁移提前停止\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "meUTrR4I6m1C" }, "source": [ "本笔记本演示了如何使用提前停止设置模型训练。首先,在 TensorFlow 1 中使用 `tf.estimator.Estimator` 和提前停止钩子,然后在 TensorFlow 2 中使用 Keras API 或自定义训练循环。 提前停止是一种正则化技术,可在验证损失达到特定阈值时停止训练。\n", "\n", "在 TensorFlow 2 中,可以通过三种方式实现提前停止:\n", "\n", "- 使用内置的 Keras 回调 `tf.keras.callbacks.EarlyStopping` 并将其传递给 `Model.fit`。\n", "- 定义自定义回调并将其传递给 Keras `Model.fit`。\n", "- 在[自定义训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch)中编写自定义提前停止规则(使用 `tf.GradientTape`)。" ] }, { "cell_type": "markdown", "metadata": { "id": "YdZSoIXEbhg-" }, "source": [ "## 安装" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iE0vSfMXumKI", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "4uXff1BEssdE" }, "source": [ "## TensorFlow 1:使用提前停止钩子和 tf.estimator 提前停止" ] }, { "cell_type": "markdown", "metadata": { "id": "JaHhhhW5o8lL" }, "source": [ "首先,定义用于 MNIST 数据集加载和预处理的函数,以及与 `tf.estimator.Estimator` 一起使用的模型定义:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lqe9obf7suIj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def normalize_img(image, label):\n", " return tf.cast(image, tf.float32) / 255., label\n", "\n", "def _input_fn():\n", " ds_train = tfds.load(\n", " name='mnist',\n", " split='train',\n", " shuffle_files=True,\n", " as_supervised=True)\n", "\n", " ds_train = ds_train.map(\n", " normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", " ds_train = ds_train.batch(128)\n", " ds_train = ds_train.repeat(100)\n", " return ds_train\n", "\n", "def _eval_input_fn():\n", " ds_test = tfds.load(\n", " name='mnist',\n", " split='test',\n", " shuffle_files=True,\n", " as_supervised=True)\n", " ds_test = ds_test.map(\n", " normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", " ds_test = ds_test.batch(128)\n", " return ds_test\n", "\n", "def _model_fn(features, labels, mode):\n", " flatten = tf1.layers.Flatten()(features)\n", " features = tf1.layers.Dense(128, 'relu')(flatten)\n", " logits = tf1.layers.Dense(10)(features)\n", "\n", " loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)\n", " optimizer = tf1.train.AdagradOptimizer(0.005)\n", " train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n", "\n", " return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)" ] }, { "cell_type": "markdown", "metadata": { "id": "hC_AY7KwqD0p" }, "source": [ "在 TensorFlow 1 中,提前停止的工作方式是使用 `tf.estimator.experimental.make_early_stopping_hook` 设置提前停止钩子。将钩子传递给 `make_early_stopping_hook` 方法作为 `should_stop_fn` 的参数,它可以接受不带任何参数的函数。一旦 `should_stop_fn` 返回 `True`,训练就会停止。\n", "\n", "下面的示例演示了如何实现将训练时间限制为最多 20 秒的提前停止技术:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HsOpjW5plH9Q", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n", "\n", "start_time = time.time()\n", "max_train_seconds = 20\n", "\n", "def should_stop_fn():\n", " return time.time() - start_time > max_train_seconds\n", "\n", "early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(\n", " estimator=estimator,\n", " should_stop_fn=should_stop_fn,\n", " run_every_secs=1,\n", " run_every_steps=None)\n", "\n", "train_spec = tf1.estimator.TrainSpec(\n", " 
 { "cell_type": "markdown", "metadata": { "id": "KEmzBjfnsxwT" }, "source": [
  "## TensorFlow 2: Early stopping with a built-in callback and Model.fit"
 ] },
 { "cell_type": "markdown", "metadata": { "id": "GKwxnkIksPFW" }, "source": [
  "Prepare the MNIST dataset and a simple Keras model:"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "atVciNgPs0fw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [
  "(ds_train, ds_test), ds_info = tfds.load(\n",
  "    'mnist',\n",
  "    split=['train', 'test'],\n",
  "    shuffle_files=True,\n",
  "    as_supervised=True,\n",
  "    with_info=True,\n",
  ")\n",
  "\n",
  "ds_train = ds_train.map(\n",
  "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
  "ds_train = ds_train.batch(128)\n",
  "\n",
  "ds_test = ds_test.map(\n",
  "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
  "ds_test = ds_test.batch(128)\n",
  "\n",
  "model = tf.keras.models.Sequential([\n",
  "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
  "  tf.keras.layers.Dense(128, activation='relu'),\n",
  "  tf.keras.layers.Dense(10)\n",
  "])\n",
  "\n",
  "model.compile(\n",
  "    optimizer=tf.keras.optimizers.Adam(0.005),\n",
  "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
  "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
  ")"
 ] },
 { "cell_type": "markdown", "metadata": { "id": "559Goxp3tOMl" }, "source": [
  "In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`), you can configure early stopping by passing the built-in callback `tf.keras.callbacks.EarlyStopping` to the `callbacks` parameter of `Model.fit`.\n",
  "\n",
  "The `EarlyStopping` callback monitors a user-specified metric and ends training when that metric stops improving. (Check [Training and evaluation with the built-in methods](https://tensorflow.google.cn/guide/keras/train_and_evaluate#using_callbacks) or the [API docs](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/EarlyStopping) for more information.)\n",
  "\n",
  "Below is an example of an early stopping callback that monitors the loss and stops training after three epochs with no improvement (`patience=3`):"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "Kip65sYBlKiu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [
  "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n",
  "\n",
  "# Only around 25 epochs are run during training, instead of 100.\n",
  "history = model.fit(\n",
  "    ds_train,\n",
  "    epochs=100,\n",
  "    validation_data=ds_test,\n",
  "    callbacks=[callback]\n",
  ")\n",
  "\n",
  "len(history.history['loss'])"
 ] },
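 { "cell_type": "markdown", "metadata": {}, "source": [
  "In practice you will often monitor a validation metric rather than the training loss. The cell below is a minimal sketch of such a configuration using documented `EarlyStopping` arguments (`monitor`, `min_delta`, `patience`, and `restore_best_weights`); the specific values are illustrative, and the `Model.fit` call is left commented out so that running this cell does not retrain the model here:"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Sketch: monitor the validation loss instead of the training loss.\n",
  "val_loss_callback = tf.keras.callbacks.EarlyStopping(\n",
  "    monitor='val_loss',          # validation loss reported by Model.fit\n",
  "    min_delta=1e-3,              # ignore improvements smaller than this\n",
  "    patience=3,                  # stop after 3 epochs without improvement\n",
  "    restore_best_weights=True)   # roll back to the best weights seen\n",
  "\n",
  "# Pass it to the `callbacks` argument exactly as above, for example:\n",
  "# model.fit(ds_train, epochs=100, validation_data=ds_test,\n",
  "#           callbacks=[val_loss_callback])"
 ] },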
"metadata": { "id": "s5mIzDOAkUKA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Limit the training time to 30 seconds.\n", "callback = LimitTrainingTime(30)\n", "history = model.fit(\n", " ds_train,\n", " epochs=100,\n", " validation_data=ds_test,\n", " callbacks=[callback]\n", ")\n", "len(history.history['loss'])" ] }, { "cell_type": "markdown", "metadata": { "id": "kro_lKyEu60-" }, "source": [ "## TensorFlow 2:使用自定义训练循环提前停止" ] }, { "cell_type": "markdown", "metadata": { "id": "g5LU0lebvuIk" }, "source": [ "在 TensorFlow 2 中,如果您不使用[内置 Keras 方法](https://tensorflow.google.cn/guide/keras/train_and_evaluate)进行训练和评估,则可以在[自定义训练循环](https://tensorflow.google.cn/tutorials/customization/custom_training_walkthrough#training_loop)中实现提前停止。\n", "\n", "首先,使用 Keras API 定义另一个简单的模型、优化器、损失函数和指标:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oTGxr0PwAiQ4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = tf.keras.models.Sequential([\n", " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", " tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(10)\n", "])\n", "\n", "optimizer = tf.keras.optimizers.Adam(0.005)\n", "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", "\n", "train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n", "train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()\n", "val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n", "val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()" ] }, { "cell_type": "markdown", "metadata": { "id": "zecsnqRxvy0Q" }, "source": [ "[使用 tf.GradientTape](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch) 和 `@tf.function` 装饰器定义参数更新函数以[加快速度](https://tensorflow.google.cn/guide/function):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s3w_55n0Ah7L", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "@tf.function\n", "def train_step(x, y):\n", " with tf.GradientTape() as tape:\n", " logits = model(x, training=True)\n", " loss_value = loss_fn(y, logits)\n", " grads = tape.gradient(loss_value, model.trainable_weights)\n", " optimizer.apply_gradients(zip(grads, model.trainable_weights))\n", " train_acc_metric.update_state(y, logits)\n", " train_loss_metric.update_state(y, logits)\n", " return loss_value\n", "\n", "@tf.function\n", "def test_step(x, y):\n", " logits = model(x, training=False)\n", " val_acc_metric.update_state(y, logits)\n", " val_loss_metric.update_state(y, logits)" ] }, { "cell_type": "markdown", "metadata": { "id": "-ZKS9ePGwd9r" }, "source": [ "接下来,编写一个自定义训练循环,可以在其中手动实现提前停止规则。\n", "\n", "下面的示例显示了当验证损失在一定数量的周期内没有改进时如何停止训练:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iZOzHqqSAkpK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "epochs = 100\n", "patience = 5\n", "wait = 0\n", "best = float('inf')\n", "\n", "for epoch in range(epochs):\n", " print(\"\\nStart of epoch %d\" % (epoch,))\n", " start_time = time.time()\n", "\n", " for step, (x_batch_train, y_batch_train) in enumerate(ds_train):\n", " loss_value = train_step(x_batch_train, y_batch_train)\n", " if step % 200 == 0:\n", " print(\"Training loss at step %d: %.4f\" % (step, loss_value.numpy()))\n", " print(\"Seen so far: %s samples\" % ((step + 1) * 128)) \n", " train_acc = train_acc_metric.result()\n", " train_loss = train_loss_metric.result()\n", " train_acc_metric.reset_states()\n", " 
 { "cell_type": "markdown", "metadata": { "id": "kro_lKyEu60-" }, "source": [
  "## TensorFlow 2: Early stopping with a custom training loop"
 ] },
 { "cell_type": "markdown", "metadata": { "id": "g5LU0lebvuIk" }, "source": [
  "In TensorFlow 2, you can implement early stopping in a [custom training loop](https://tensorflow.google.cn/tutorials/customization/custom_training_walkthrough#training_loop) if you're not training and evaluating with the [built-in Keras methods](https://tensorflow.google.cn/guide/keras/train_and_evaluate).\n",
  "\n",
  "Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "oTGxr0PwAiQ4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [
  "model = tf.keras.models.Sequential([\n",
  "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
  "  tf.keras.layers.Dense(128, activation='relu'),\n",
  "  tf.keras.layers.Dense(10)\n",
  "])\n",
  "\n",
  "optimizer = tf.keras.optimizers.Adam(0.005)\n",
  "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
  "\n",
  "train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
  "train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)\n",
  "val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
  "val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)"
 ] },
 { "cell_type": "markdown", "metadata": { "id": "zecsnqRxvy0Q" }, "source": [
  "Define the parameter update functions [with tf.GradientTape](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch) and the `@tf.function` decorator [for a speedup](https://tensorflow.google.cn/guide/function):"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "s3w_55n0Ah7L", "vscode": { "languageId": "python" } }, "outputs": [], "source": [
  "@tf.function\n",
  "def train_step(x, y):\n",
  "  with tf.GradientTape() as tape:\n",
  "    logits = model(x, training=True)\n",
  "    loss_value = loss_fn(y, logits)\n",
  "  grads = tape.gradient(loss_value, model.trainable_weights)\n",
  "  optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
  "  train_acc_metric.update_state(y, logits)\n",
  "  train_loss_metric.update_state(y, logits)\n",
  "  return loss_value\n",
  "\n",
  "@tf.function\n",
  "def test_step(x, y):\n",
  "  logits = model(x, training=False)\n",
  "  val_acc_metric.update_state(y, logits)\n",
  "  val_loss_metric.update_state(y, logits)"
 ] },
 { "cell_type": "markdown", "metadata": { "id": "-ZKS9ePGwd9r" }, "source": [
  "Next, write a custom training loop in which you can implement your early stopping rule manually.\n",
  "\n",
  "The example below shows how to stop training when the validation loss does not improve over a certain number of epochs:"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "iZOzHqqSAkpK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [
  "epochs = 100\n",
  "patience = 5\n",
  "wait = 0\n",
  "best = float('inf')\n",
  "\n",
  "for epoch in range(epochs):\n",
  "  print(\"\\nStart of epoch %d\" % (epoch,))\n",
  "  start_time = time.time()\n",
  "\n",
  "  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):\n",
  "    loss_value = train_step(x_batch_train, y_batch_train)\n",
  "    if step % 200 == 0:\n",
  "      print(\"Training loss at step %d: %.4f\" % (step, loss_value.numpy()))\n",
  "      print(\"Seen so far: %s samples\" % ((step + 1) * 128))\n",
  "  train_acc = train_acc_metric.result()\n",
  "  train_loss = train_loss_metric.result()\n",
  "  train_acc_metric.reset_states()\n",
  "  train_loss_metric.reset_states()\n",
  "  print(\"Training acc over epoch: %.4f\" % (train_acc.numpy()))\n",
  "\n",
  "  for x_batch_val, y_batch_val in ds_test:\n",
  "    test_step(x_batch_val, y_batch_val)\n",
  "  val_acc = val_acc_metric.result()\n",
  "  val_loss = val_loss_metric.result()\n",
  "  val_acc_metric.reset_states()\n",
  "  val_loss_metric.reset_states()\n",
  "  print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
  "  print(\"Time taken: %.2fs\" % (time.time() - start_time))\n",
  "\n",
  "  # The early stopping strategy: stop the training if `val_loss` does not\n",
  "  # decrease over a certain number of epochs.\n",
  "  wait += 1\n",
  "  if val_loss < best:\n",
  "    best = val_loss\n",
  "    wait = 0\n",
  "  if wait >= patience:\n",
  "    break"
 ] },
 { "cell_type": "markdown", "metadata": { "id": "e85558980a4b" }, "source": [
  "## Next steps\n",
  "\n",
  "- Learn more about the Keras built-in early stopping callback API in the [API docs](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/EarlyStopping).\n",
  "- Learn to [write custom Keras callbacks](https://tensorflow.google.cn/guide/keras/custom_callback), including [early stopping at a minimum loss](https://tensorflow.google.cn/guide/keras/custom_callback/#early_stopping_at_minimum_loss).\n",
  "- Learn about [training and evaluation with the Keras built-in methods](https://tensorflow.google.cn/guide/keras/train_and_evaluate#using_callbacks).\n",
  "- Explore common regularization techniques in the [Overfit and underfit](https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit) tutorial, which uses the `EarlyStopping` callback."
 ] }
 ], "metadata": { "colab": { "collapsed_sections": [], "name": "early_stopping.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }