{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "wJcYs_ERTnnI" }, "outputs": [], "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "HMUDt0CiUJk9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# 将 SessionRunHook 迁移到 Keras 回调\n", "\n", "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 运行\n", " 在 Github 上查看源代码\n", " 下载笔记本
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "KZHPY55aFyXT" }, "source": [ "在 TensorFlow 1 中,要自定义训练的行为,可以使用 `tf.estimator.SessionRunHook` 和 `tf.estimator.Estimator`。本指南演示了如何使用 `tf.keras.callbacks.Callback` API 从 `SessionRunHook` 迁移到 TensorFlow 2 的自定义回调,此 API 与 Keras `Model.fit` 一起用于训练(以及 `Model.evaluate` 和 `Model.predict`)。您将通过实现 `SessionRunHook` 和 `Callback` 任务来学习如何做到这一点,此任务会在训练期间测量每秒的样本数。\n", "\n", "回调的示例为检查点保存 (`tf.keras.callbacks.ModelCheckpoint`) 和 [TensorBoard](%60tf.keras.callbacks.TensorBoard%60) 摘要编写。Keras [回调](https://tensorflow.google.cn/guide/keras/custom_callback)是在内置 Keras `Model.fit`/`Model.evaluate`/`Model.predict` API 中的训练/评估/预测期间的不同点调用的对象。可以在 `tf.keras.callbacks.Callback` API 文档以及[编写自己的回调](https://tensorflow.google.cn/guide/keras/custom_callback.ipynb/)和[使用内置方法进行训练和评估](https://tensorflow.google.cn/guide/keras/train_and_evaluate)(*使用回调*部分)指南中详细了解回调。" ] }, { "cell_type": "markdown", "metadata": { "id": "29da56bf859d" }, "source": [ "## 安装\n", "\n", "从导入和用于演示目的的简单数据集开始:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "296d8b0DoKpV", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1\n", "\n", "import time\n", "from datetime import datetime\n", "from absl import flags" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xVGYtUXyXNuE", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n", "labels = [[0.3], [0.5], [0.7]]\n", "eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]\n", "eval_labels = [[0.8], [0.9], [1.]]" ] }, { "cell_type": "markdown", "metadata": { "id": "ON4zQifT0Vec" }, "source": [ "## TensorFlow 1:使用 tf.estimator API 创建自定义 SessionRunHook\n", "\n", "下面的 TensorFlow 1 示例展示了如何设置自定义 `SessionRunHook` 以在训练期间测量每秒的样本数。创建钩子 (`LoggerHook`) 后,将其传递给 `tf.estimator.Estimator.train` 的 `hooks` 参数。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "S-myEclbXUL7", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def _input_fn():\n", " return tf1.data.Dataset.from_tensor_slices(\n", " (features, labels)).batch(1).repeat(100)\n", "\n", "def _model_fn(features, labels, mode):\n", " logits = tf1.layers.Dense(1)(features)\n", " loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)\n", " optimizer = tf1.train.AdagradOptimizer(0.05)\n", " train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n", " return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xd9sPTkO0ZTD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class LoggerHook(tf1.train.SessionRunHook):\n", " \"\"\"Logs loss and runtime.\"\"\"\n", "\n", " def begin(self):\n", " self._step = -1\n", " self._start_time = time.time()\n", " self.log_frequency = 10\n", "\n", " def before_run(self, run_context):\n", " self._step += 1\n", "\n", " def after_run(self, run_context, run_values):\n", " if self._step % self.log_frequency == 0:\n", " current_time = time.time()\n", " duration = current_time - self._start_time\n", " self._start_time = current_time\n", " examples_per_sec = self.log_frequency / duration\n", " print('Time:', datetime.now(), ', Step #:', self._step,\n", " ', Examples per second:', examples_per_sec)\n", "\n", "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n", "\n", "# Begin training.\n", "estimator.train(_input_fn, hooks=[LoggerHook()])" ] }, { "cell_type": "markdown", "metadata": { "id": "3uZCDMrM2CEg" }, "source": [ "## TensorFlow 2:为 Model.fit 创建自定义 Keras 回调\n", "\n", "在 TensorFlow 2 中,当您使用内置 Keras `Model.fit`(或 `Model.evaluate`)进行训练/评估时,可以配置自定义 `tf.keras.callbacks.Callback`,然后将其传递给 `Model.fit`(或 `Model.evaluate`)的 `callbacks` 参数。(在[编写自己的回调](../..guide/keras/custom_callback.ipynb)指南中了解详情。)\n", "\n", "在下面的示例中,您将编写一个自定义 `tf.keras.callbacks.Callback` 来记录各种指标 – 它将测量每秒的样本数,这应该与前面的 `SessionRunHook` 示例中的指标相当。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UbMPoiB92KRG", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class CustomCallback(tf.keras.callbacks.Callback):\n", "\n", " def on_train_begin(self, logs = None):\n", " self._step = -1\n", " self._start_time = time.time()\n", " self.log_frequency = 10\n", "\n", " def on_train_batch_begin(self, batch, logs = None):\n", " self._step += 1\n", "\n", " def on_train_batch_end(self, batch, logs = None):\n", " if self._step % self.log_frequency == 0:\n", " current_time = time.time()\n", " duration = current_time - self._start_time\n", " self._start_time = current_time\n", " examples_per_sec = self.log_frequency / duration\n", " print('Time:', datetime.now(), ', Step #:', self._step,\n", " ', Examples per second:', examples_per_sec)\n", "\n", "callback = CustomCallback()\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices(\n", " (features, labels)).batch(1).repeat(100)\n", "\n", "model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n", "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n", "\n", "model.compile(optimizer, \"mse\")\n", "\n", "# Begin training.\n", "result = model.fit(dataset, callbacks=[callback], verbose = 0)\n", "# Provide the results of training metrics.\n", "result.history" ] }, { "cell_type": "markdown", "metadata": { "id": "EFqFi21Ftskq" }, "source": [ "## 后续步骤\n", "\n", "通过下列方式详细了解回调:\n", "\n", "- API 文档:`tf.keras.callbacks.Callback`\n", "- 指南:[编写自己的回调](../..guide/keras/custom_callback.ipynb/)\n", "- 指南:[使用内置方法进行训练和评估](https://tensorflow.google.cn/guide/keras/train_and_evaluate)(*使用回调*部分)\n", "\n", "此外,您可能还会发现下列与迁移相关的资源十分有用:\n", "\n", "- [提前停止迁移指南](early_stopping.ipynb):`tf.keras.callbacks.EarlyStopping` 是一个内置的提前停止回调\n", "- [TensorBoard 迁移指南](tensorboard.ipynb):TensorBoard 支持跟踪和显示指标\n", "- [LoggingTensorHook 和 StopAtStepHook 到 Keras 回调迁移指南](logging_stop_hook.ipynb)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "sessionrunhook_callback.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }