{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "wJcYs_ERTnnI" }, "outputs": [], "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "HMUDt0CiUJk9" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# Migrate checkpoint saving" ] },
{ "cell_type": "markdown", "metadata": { "id": "hIo_p2FWFIRx" }, "source": [ "Continually saving the \"best\" model or model weights/parameters has many benefits, including being able to track training progress and to load saved models from different saved states.\n", "\n", "In TensorFlow 1, to configure checkpoint saving during training/validation with the `tf.estimator.Estimator` APIs, you specify a schedule in `tf.estimator.RunConfig` or use `tf.estimator.CheckpointSaverHook`. This guide demonstrates how to migrate from this workflow to the TensorFlow 2 Keras APIs.\n", "\n", "In TensorFlow 2, you can configure `tf.keras.callbacks.ModelCheckpoint` in a number of ways:\n", "\n", "- Save the \"best\" version according to a metric monitored with the `save_best_only=True` parameter, where `monitor` can be, for example, `'loss'`, `'val_loss'`, `'accuracy'`, or `'val_accuracy'`.\n", "- Save continually at a certain frequency (using the `save_freq` argument).\n", "- Save only the weights/parameters instead of the whole model by setting `save_weights_only` to `True`.\n", "\n", "For more details, refer to the {class}`tensorflow.keras.callbacks.ModelCheckpoint` API docs and the *Save checkpoints during training* section of the [Save and load models](../../tutorials/keras/save_and_load.ipynb) tutorial. Learn more about the checkpoint format in the *TF Checkpoint format* section of the [Save and load Keras models](https://tensorflow.google.cn/guide/keras/save_and_serialize) guide. In addition, to add fault tolerance, you can use `tf.keras.callbacks.BackupAndRestore` or `tf.train.Checkpoint` for manual checkpointing. Learn more in the [Fault tolerance migration guide](fault_tolerance.ipynb).\n", "\n", "Keras [callbacks](https://tensorflow.google.cn/guide/keras/custom_callback) are objects that are called at different points during training/evaluation/prediction in the built-in Keras `Model.fit`/`Model.evaluate`/`Model.predict` APIs. Learn more in the *Next steps* section at the end of this guide." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "# Hide TensorFlow C++ INFO, WARNING and ERROR log messages\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "# Silence initialization log warnings from the underlying gRPC and Abseil libraries\n", "os.environ[\"GRPC_VERBOSITY\"] = \"ERROR\"\n", "os.environ[\"GLOG_minloglevel\"] = \"3\"  # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL\n", "# Let TensorFlow allocate GPU memory on demand. This must be set in the kernel's own\n", "# environment before TensorFlow initializes the GPU (`!export` only affects a subshell)\n", "os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\"\n", "\n", "import logging\n", "import tensorflow as tf\n", "tf.get_logger().setLevel(logging.ERROR)\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", "\n", "from pathlib import Path\n", "\n", "# Local directory that holds the temporary model directories created in this guide\n", "temp_dir = Path(\".temp\")\n", "temp_dir.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "f55c103999de" }, 
"source": [ "## Setup\n", "\n", "Start with imports and a simple dataset for demonstration purposes:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "X74yjOb-e18w" }, "outputs": [], "source": [ "import tensorflow.compat.v1 as tf1\n", "import tensorflow as tf\n", "import numpy as np\n", "import tempfile" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "2r8r4d8FfMny" }, "outputs": [], "source": [ "mnist = tf.keras.datasets.mnist\n", "\n", "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "x_train, x_test = x_train / 255.0, x_test / 255.0" ] }, { "cell_type": "markdown", "metadata": { "id": "wrqBkG4RFLP_" }, "source": [ "## TensorFlow 1: Save checkpoints with tf.estimator APIs\n", "\n", "This TensorFlow 1 example shows how to configure `tf.estimator.RunConfig` to save checkpoints at every step during training/evaluation with the `tf.estimator.Estimator` APIs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "upA8nuf3FEq5" }, "outputs": [], "source": [ "feature_columns = [tf1.feature_column.numeric_column(\"x\", shape=[28, 28])]\n", "\n", "# Write summaries and save a checkpoint at every training step\n", "config = tf1.estimator.RunConfig(save_summary_steps=1,\n", "                                 save_checkpoints_steps=1)\n", "\n", "path = tempfile.mkdtemp(dir=temp_dir)\n", "\n", "classifier = tf1.estimator.DNNClassifier(\n", "    feature_columns=feature_columns,\n", "    hidden_units=[256, 32],\n", "    optimizer=tf1.train.AdamOptimizer(0.001),\n", "    n_classes=10,\n", "    dropout=0.2,\n", "    model_dir=path,\n", "    config=config\n", ")\n", "\n", "train_input_fn = tf1.estimator.inputs.numpy_input_fn(\n", "    x={\"x\": x_train},\n", "    y=y_train.astype(np.int32),\n", "    num_epochs=10,\n", "    batch_size=50,\n", "    shuffle=True,\n", ")\n", "\n", "test_input_fn = tf1.estimator.inputs.numpy_input_fn(\n", "    x={\"x\": x_test},\n", "    y=y_test.astype(np.int32),\n", "    num_epochs=10,\n", "    shuffle=False\n", ")\n", "\n", "train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)\n", "eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,\n", "                                   steps=10,\n", "                                   throttle_secs=0)\n", "\n", 
"tf1.estimator.train_and_evaluate(estimator=classifier,\n", "                                 train_spec=train_spec,\n", "                                 eval_spec=eval_spec)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3u96G4MtRVqU" }, "outputs": [], "source": [ "%ls {classifier.model_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "QvE_uxDJFUX-" }, "source": [ "## TensorFlow 2: Save checkpoints with a Keras callback for Model.fit\n", "\n", "In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`) for training/evaluation, you can configure `tf.keras.callbacks.ModelCheckpoint` and then pass it to the `callbacks` parameter of `Model.fit` (or `Model.evaluate`). (Learn more in the API docs and the *Using callbacks* section of the [Training and evaluation with the built-in methods](https://tensorflow.google.cn/guide/keras/train_and_evaluate) guide.)\n", "\n", "In the following example, you will use a `tf.keras.callbacks.ModelCheckpoint` callback to store the checkpoints in a temporary directory:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9FLBhT2BFX2H" }, "outputs": [], "source": [ "def create_model():\n", "    return tf.keras.models.Sequential([\n", "        tf.keras.layers.Flatten(),\n", "        tf.keras.layers.Dense(512, activation='relu'),\n", "        tf.keras.layers.Dropout(0.2),\n", "        tf.keras.layers.Dense(10, activation='softmax')\n", "    ])\n", "\n", "model = create_model()\n", "model.compile(optimizer='adam',\n", "              loss='sparse_categorical_crossentropy',\n", "              metrics=['accuracy'],\n", "              steps_per_execution=10)\n", "\n", "log_dir = tempfile.mkdtemp(dir=temp_dir)\n", "\n", "# With the default save_freq='epoch', the whole model is saved to this\n", "# .keras file at the end of every epoch, overwriting the previous save\n", "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", "    filepath=f\"{log_dir}/test.keras\")\n", "\n", "model.fit(x=x_train,\n", "          y=y_train,\n", "          epochs=10,\n", "          validation_data=(x_test, y_test),\n", "          callbacks=[model_checkpoint_callback])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SROSmhyyLBA-" }, "outputs": [], "source": [ "%ls {model_checkpoint_callback.filepath}" ] }, { "cell_type": "markdown", "metadata": { "id": "rQUS8nO9FZlH" }, "source": [ "## Next steps\n", "\n", "Learn more about checkpointing in:\n", "\n", "- API docs: `tf.keras.callbacks.ModelCheckpoint`\n", "- Tutorial: [Save and load models](../../tutorials/keras/save_and_load.ipynb) (the *Save checkpoints during training* section)\n", "- Guide: [Save and load Keras models](https://tensorflow.google.cn/guide/keras/save_and_serialize) (the *TF Checkpoint format* section)\n", "\n", "Learn more about callbacks in:\n", "\n", "- API docs: `tf.keras.callbacks.Callback`\n", "- Guide: [Writing your own callbacks](https://tensorflow.google.cn/guide/keras/custom_callback)\n", "- Guide: [Training and evaluation with the built-in methods](https://tensorflow.google.cn/guide/keras/train_and_evaluate) (the *Using callbacks* section)\n", "\n", "You may also find the following migration-related resources useful:\n", "\n", "- The [Fault tolerance migration guide](fault_tolerance.ipynb): `tf.keras.callbacks.BackupAndRestore` for `Model.fit`, or the `tf.train.Checkpoint` and `tf.train.CheckpointManager` APIs for custom training loops\n", "- The [Early stopping migration guide](early_stopping.ipynb): `tf.keras.callbacks.EarlyStopping` is a built-in early stopping callback\n", "- The [TensorBoard migration guide](tensorboard.ipynb): TensorBoard enables tracking and displaying metrics\n", "- The [LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide](logging_stop_hook.ipynb)\n", "- The [SessionRunHook to Keras callbacks guide](sessionrunhook_callback.ipynb)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "checkpoint_saver.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "xxx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }