{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "wJcYs_ERTnnI" }, "outputs": [], "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "HMUDt0CiUJk9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# 将 LoggingTensorHook 和 StopAtStepHook 迁移到 Keras 回调\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "meUTrR4I6m1C" }, "source": [ "在 TensorFlow 1 中,可以使用 `tf.estimator.LoggingTensorHook` 监视和记录张量,而 `tf.estimator.StopAtStepHook` 则在使用 `tf.estimator.Estimator` 进行训练时有助于在指定步骤停止训练。本笔记本演示了如何使用带有 `Model.fit` 的自定义 Keras 回调 (`tf.keras.callbacks.Callback`) 从这些 API 迁移到 TensorFlow 2 中的对应项。\n", "\n", "Keras [回调](https://tensorflow.google.cn/guide/keras/custom_callback)是在内置 Keras `Model.fit`/`Model.evaluate`/`Model.predict` API 中的训练/评估/预测期间的不同点调用的对象。可以在 `tf.keras.callbacks.Callback` API 文档以及[编写自己的回调](../..guide/keras/custom_callback.ipynb/)和[使用内置方法进行训练和评估](https://tensorflow.google.cn/guide/keras/train_and_evaluate)(*使用回调* 部分)指南中详细了解回调。要从 TensorFlow 1 中的 `SessionRunHook` 迁移到 TensorFlow 2 中的 Keras 回调,请查看[迁移使用辅助逻辑的训练](sessionrunhook_callback.ipynb)指南。" ] }, { "cell_type": "markdown", "metadata": { "id": "YdZSoIXEbhg-" }, "source": [ "## 安装\n", "\n", "从导入和用于演示目的的简单数据集开始:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iE0vSfMXumKI", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m7rnGxsXtDkV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n", "labels = [[0.3], [0.5], [0.7]]\n", "\n", "# Define an input function.\n", "def _input_fn():\n", " return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)" ] }, { "cell_type": "markdown", "metadata": { "id": "4uXff1BEssdE" }, "source": [ "## TensorFlow 1:使用 tf.estimator API 记录张量和停止训练" ] }, { "cell_type": "markdown", "metadata": { "id": "zW-X5cmzmkuw" }, "source": [ "在 TensorFlow 1 中,定义各种钩子来控制训练行为。随后,将这些钩子传递给 `tf.estimator.EstimatorSpec`。\n", "\n", "在下面的示例中:\n", "\n", "- 要监视/记录张量(例如模型权重或损失),可以使用 `tf.estimator.LoggingTensorHook`(`tf.train.LoggingTensorHook` 是它的别名)。\n", "- 要在特定步骤停止训练,请使用 
`tf.estimator.StopAtStepHook` (`tf.train.StopAtStepHook` is its alias)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lqe9obf7suIj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def _model_fn(features, labels, mode):\n", "  dense = tf1.layers.Dense(1)\n", "  logits = dense(features)\n", "  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)\n", "  optimizer = tf1.train.AdagradOptimizer(0.05)\n", "  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n", "\n", "  # Define the stop hook.\n", "  stop_hook = tf1.train.StopAtStepHook(num_steps=2)\n", "\n", "  # Access tensors to be logged by names.\n", "  kernel_name = tf.identity(dense.weights[0]).name\n", "  bias_name = tf.identity(dense.weights[1]).name\n", "  logging_weight_hook = tf1.train.LoggingTensorHook(\n", "      tensors=[kernel_name, bias_name],\n", "      every_n_iter=1)\n", "  # Log the training loss by the tensor object.\n", "  logging_loss_hook = tf1.train.LoggingTensorHook(\n", "      {'loss from LoggingTensorHook': loss},\n", "      every_n_secs=3)\n", "\n", "  # Pass all hooks to `EstimatorSpec`.\n", "  return tf1.estimator.EstimatorSpec(\n", "      mode,\n", "      loss=loss,\n", "      train_op=train_op,\n", "      training_hooks=[stop_hook, logging_weight_hook, logging_loss_hook])\n", "\n", "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n", "\n", "# Begin training.\n", "# The training will stop after 2 steps, and the weights/loss will also be logged.\n", "estimator.train(_input_fn)" ] }, { "cell_type": "markdown", "metadata": { "id": "KEmzBjfnsxwT" }, "source": [ "## TensorFlow 2: Log tensors and stop training with custom callbacks and Model.fit" ] }, { "cell_type": "markdown", "metadata": { "id": "839R9i4xheI5" }, "source": [ "In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`) for training/evaluation, you can configure tensor monitoring and training stopping by defining custom Keras `tf.keras.callbacks.Callback` subclasses. Then, you pass them to the `callbacks` parameter of `Model.fit` (or `Model.evaluate`). (Learn more in the [Writing your own callbacks](../..guide/keras/custom_callback.ipynb) guide.)\n", "\n", "In the example below:\n", 
"\n", "- 要重新创建 `StopAtStepHook` 的功能,请定义一个自定义回调(下称 `StopAtStepCallback`),可以在其中重写 `on_batch_end` 方法以在一定数量的步骤后停止训练。\n", "- 要重新创建 `LoggingTensorHook` 行为,请定义一个自定义回调 (`LoggingTensorCallback`),可以在其中手动记录和输出记录的张量,因为不支持按名称访问张量。此外,您还可以在自定义回调中实现记录频率。下面的示例将每两步打印一次权重。每 N 秒记录一次之类的其他策略也是可行的。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "atVciNgPs0fw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class StopAtStepCallback(tf.keras.callbacks.Callback):\n", " def __init__(self, stop_step=None):\n", " super().__init__()\n", " self._stop_step = stop_step\n", "\n", " def on_batch_end(self, batch, logs=None):\n", " if self.model.optimizer.iterations >= self._stop_step:\n", " self.model.stop_training = True\n", " print('\\nstop training now')\n", "\n", "class LoggingTensorCallback(tf.keras.callbacks.Callback):\n", " def __init__(self, every_n_iter):\n", " super().__init__()\n", " self._every_n_iter = every_n_iter\n", " self._log_count = every_n_iter\n", "\n", " def on_batch_end(self, batch, logs=None):\n", " if self._log_count > 0:\n", " self._log_count -= 1\n", " print(\"Logging Tensor Callback: dense/kernel:\",\n", " model.layers[0].weights[0])\n", " print(\"Logging Tensor Callback: dense/bias:\",\n", " model.layers[0].weights[1])\n", " print(\"Logging Tensor Callback loss:\", logs[\"loss\"])\n", " else:\n", " self._log_count -= self._every_n_iter" ] }, { "cell_type": "markdown", "metadata": { "id": "30a8b71263e0" }, "source": [ "完成后,将新回调(`StopAtStepCallback` 和 `LoggingTensorCallback`)传递给 `Model.fit` 的 `callbacks` 参数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Kip65sYBlKiu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n", "model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n", "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n", "model.compile(optimizer, \"mse\")\n", "\n", "# Begin 
training.\n", "# The training will stop after 2 steps, and the weights/loss will also be logged.\n", "model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),\n", "                              LoggingTensorCallback(every_n_iter=2)])" ] }, { "cell_type": "markdown", "metadata": { "id": "19508f4720f5" }, "source": [ "## Next steps\n", "\n", "Learn more about callbacks in:\n", "\n", "- API docs: `tf.keras.callbacks.Callback`\n", "- Guide: [Writing your own callbacks](../..guide/keras/custom_callback.ipynb/)\n", "- Guide: [Training and evaluation with the built-in methods](https://tensorflow.google.cn/guide/keras/train_and_evaluate) (the *Using callbacks* section)\n", "\n", "You may also find the following migration-related resources useful:\n", "\n", "- The [Early stopping migration guide](early_stopping.ipynb): `tf.keras.callbacks.EarlyStopping` is a built-in early stopping callback\n", "- The [TensorBoard migration guide](tensorboard.ipynb): TensorBoard enables tracking and displaying metrics\n", "- The [Training with assisted logic migration guide](sessionrunhook_callback.ipynb): from `SessionRunHook` in TensorFlow 1 to Keras callbacks in TensorFlow 2" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "logging_stop_hook.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }