{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "g_nWetWWd_ns" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "N_fMsQ-N8I7j" }, "outputs": [], "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": { "id": "pZJ3uY9O17VN" }, "source": [ "# Save and load models" ] }, { "cell_type": "markdown", "metadata": { "id": "M4Ata7_wMul1" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "mBdde4YJeJKF" }, "source": [ "可以在训练期间和之后保存模型进度。这意味着模型可以从停止的地方恢复,避免长时间的训练。此外,保存还意味着您可以分享您的模型,其他人可以重现您的工作。在发布研究模型和技术时,大多数机器学习从业者会分享:\n", "\n", "- 用于创建模型的代码\n", "- 模型的训练权重或形参\n", "\n", "共享数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。\n", "\n", "小心:TensorFlow 模型是代码,对于不受信任的代码,一定要小心。请参阅 [安全使用 TensorFlow](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) 以了解详情。\n", "\n", "### 选项\n", "\n", "根据您使用的 API,可以通过不同的方式保存 TensorFlow 模型。本指南使用 [tf.keras](https://tensorflow.google.cn/guide/keras) – 一种用于在 TensorFlow 中构建和训练模型的高级 API。建议使用本教程中使用的新的高级 `.keras` 格式来保存 Keras 对象,因为它提供了强大、高效的基于名称的保存,通常比低级或旧版格式更容易调试。如需更高级的保存或序列化工作流,尤其是那些涉及自定义对象的工作流,请参阅[保存和加载 Keras 模型指南](https://tensorflow.google.cn/guide/keras/save_and_serialize)。对于其他方式,请参阅[使用 SavedModel 格式指南](../../guide/saved_model.ipynb)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 设置日志级别为ERROR,以减少警告信息\n", "# 禁用 Gemini 的底层库(gRPC 和 Abseil)在初始化日志警告\n", "os.environ[\"GRPC_VERBOSITY\"] = \"ERROR\"\n", "os.environ[\"GLOG_minloglevel\"] = \"3\" # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL\n", "os.environ[\"GLOG_minloglevel\"] = \"true\"\n", "import logging\n", "import tensorflow as tf\n", "tf.get_logger().setLevel(logging.ERROR)\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", "!export TF_FORCE_GPU_ALLOW_GROWTH=true\n", "from pathlib import Path\n", "\n", "temp_dir = Path(\".temp\")\n", "temp_dir.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "xCUREq7WXgvg" }, "source": [ "## 配置\n", "\n", "### 安装并导入" ] }, { "cell_type": "markdown", "metadata": { "id": "7l0MiTOrXtNv" }, "source": [ "安装并导入Tensorflow和依赖项:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RzIOVSdnMYyO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyyaml in 
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages (6.0.2)\n", "Requirement already satisfied: h5py in /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages (3.11.0)\n", "Requirement already satisfied: numpy>=1.17.3 in /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages (from h5py) (1.26.3)\n" ] } ], "source": [ "!pip install pyyaml h5py # Required to save models in HDF5 format" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7Nm7Tyb-gRt-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.17.0\n" ] } ], "source": [ "import os\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "\n", "print(tf.version.VERSION)" ] }, { "cell_type": "markdown", "metadata": { "id": "SbGsznErXWt6" }, "source": [ "### Get an example dataset\n", "\n", "To demonstrate how to save and load weights, you'll use the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). To speed up these runs, use the first 1000 examples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9rGfFwE9XVwz" }, "outputs": [], "source": [ "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", "\n", "train_labels = train_labels[:1000]\n", "test_labels = test_labels[:1000]\n", "\n", "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n", "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0" ] }, { "cell_type": "markdown", "metadata": { "id": "anG3iVoXyZGI" }, "source": [ "### Define a model" ] }, { "cell_type": "markdown", "metadata": { "id": "wynsOBfby0Pa" }, "source": [ "Start by building a simple sequential model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0HZbJIjxyX1S" }, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_11\"\n",
              "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential_11\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
              "┃ Layer (type)                     Output Shape                  Param # ┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
              "│ dense_22 (Dense)                │ ?                      │   0 (unbuilt) │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ dropout_11 (Dropout)            │ ?                      │   0 (unbuilt) │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ dense_23 (Dense)                │ ?                      │   0 (unbuilt) │\n",
              "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
              "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_22 (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_11 (\u001b[38;5;33mDropout\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_23 (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Define a simple sequential model\n", "def create_model():\n", "  model = tf.keras.Sequential([\n", "    keras.layers.Dense(512, activation='relu'),\n", "    keras.layers.Dropout(0.2),\n", "    keras.layers.Dense(10)\n", "  ])\n", "\n", "  model.compile(optimizer='adam',\n", "                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n", "\n", "  return model\n", "\n", "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Display the model's architecture\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "soDE0W_KH8rG" }, "source": [ "## Save checkpoints during training" ] }, { "cell_type": "markdown", "metadata": { "id": "mRyd5qQQIXZm" }, "source": [ "You can use a trained model without having to retrain it, or pick up training where you left off if the training process was interrupted. The `tf.keras.callbacks.ModelCheckpoint` callback lets you continually save the model both *during* and at *the end* of training.\n", "\n", "### Checkpoint callback usage\n", "\n", "Create a `tf.keras.callbacks.ModelCheckpoint` callback that saves only the weights during training. With Keras 3 (the default in TensorFlow 2.16 and later), a weights-only checkpoint path must end in `.weights.h5`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFPuhwntH8VH" }, "outputs": [], "source": [ "# Keras 3 requires weights-only checkpoint paths to end in `.weights.h5`\n", "checkpoint_path = temp_dir/\"training_1/cp.weights.h5\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "# Create a callback that saves the model's weights\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n", "                                                 save_weights_only=True,\n", "                                                 verbose=1)\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", "          train_labels, \n", "          epochs=10,\n", "          validation_data=(test_images, test_labels),\n", "          callbacks=[cp_callback]) # Pass callback to training\n", "\n", "# This may generate warnings related to saving the state of the optimizer.\n", "# These warnings (and similar warnings throughout this notebook)\n", "# are in place to discourage outdated usage, and can be 
ignored." ] }, { "cell_type": "markdown", "metadata": { "id": "rlM-sgyJO084" }, "source": [ "This writes the checkpoint to `checkpoint_dir` and updates it at the end of each epoch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gXG5FVKFOVQ3" }, "outputs": [], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "wlRN_f56Pqa9" }, "source": [ "As long as two models share the same architecture, you can share weights between them. So, when restoring a model from weights only, create a model with the same architecture as the original model and then set its weights.\n", "\n", "Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at chance level (about 10% accuracy):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Fp5gbuiaPqCT" }, "outputs": [], "source": [ "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "1DTKpZssRSo3" }, "source": [ "Then load the weights from the checkpoint and re-evaluate:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2IZxbwiRRSD2" }, "outputs": [], "source": [ "# Loads the weights\n", "model.load_weights(checkpoint_path)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "bpAbKkAyVPV8" }, "source": [ "### Checkpoint callback options\n", "\n", "The callback provides several options to give checkpoints unique names and to adjust the checkpointing frequency.\n", "\n", "Train a new model, and save uniquely named checkpoints once every five epochs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mQF_dlgIVOvq" }, "outputs": [], "source": [ "# Include the epoch in the file name (uses `str.format`)\n", "# Keep the path as a `str`: a `Path` object has no `format` method.\n", "# Keras 3 requires weights-only checkpoint paths to end in `.weights.h5`\n", "checkpoint_path = str(temp_dir/\"training_2/cp-{epoch:04d}.weights.h5\")\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "os.makedirs(checkpoint_dir, exist_ok=True)\n", "\n", "batch_size = 32\n", "\n", "# Calculate the number of batches per epoch\n", "import math\n", "n_batches = len(train_images) / 
batch_size\n", "n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer\n", "\n", "# Create a callback that saves the model's weights every 5 epochs\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", "    filepath=checkpoint_path, \n", "    verbose=1, \n", "    save_weights_only=True,\n", "    save_freq=5*n_batches)\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Save the weights using the `checkpoint_path` format\n", "model.save_weights(str(checkpoint_path).format(epoch=0))\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", "          train_labels,\n", "          epochs=50, \n", "          batch_size=batch_size, \n", "          callbacks=[cp_callback],\n", "          validation_data=(test_images, test_labels),\n", "          verbose=0)" ] }, { "cell_type": "markdown", "metadata": { "id": "1zFrKTjjavWI" }, "source": [ "Now, review the resulting checkpoints and choose the latest one:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p64q3-V4sXt0" }, "outputs": [], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1AN_fnuyR41H" }, "outputs": [], "source": [ "# `tf.train.latest_checkpoint` only finds TensorFlow-format checkpoints, so pick the\n", "# newest file by name instead (zero-padded epoch numbers sort in training order)\n", "latest = os.path.join(checkpoint_dir, sorted(os.listdir(checkpoint_dir))[-1])\n", "latest" ] }, { "cell_type": "markdown", "metadata": { "id": "Zk2ciGbKg561" }, "source": [ "Note: the callback writes a separate file for each saved epoch and never deletes earlier ones, so the directory holds every checkpoint from this run.\n", "\n", "To test, reset the model and load the latest checkpoint:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3M04jyK-H3QK" }, "outputs": [], "source": [ "# Create a new model instance\n", "model = create_model()\n", "\n", "# Build the model so that its variables exist before loading weights\n", "model.build((None, 28 * 28))\n", "\n", "# Load the previously saved weights\n", "model.load_weights(latest)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "c2OxsJOTHxia" }, "source": [ "## What are these files?" 
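] }, { "cell_type": "markdown", "metadata": { "id": "JtdYhvWnH2ib" }, "source": [ "With Keras 3 (TensorFlow 2.16 and later), `Model.save_weights` and weights-only `ModelCheckpoint` callbacks store the trained weights in a single binary HDF5 file whose name ends in `.weights.h5`. The TensorFlow [checkpoint](../../guide/checkpoint.ipynb) format, written by `tf.train.Checkpoint`, instead produces a collection of files:\n", "\n", "- One or more shards that contain the model's weights.\n", "- An index file that indicates which weights are stored in which shard.\n", "\n", "If you train a model on a single machine, that format yields one shard with the suffix: `.data-00000-of-00001`" ] }, { "cell_type": "markdown", "metadata": { "id": "S_FA-ZvxuXQV" }, "source": [ "## Manually save weights\n", "\n", "To save weights manually, use `tf.keras.Model.save_weights`. With Keras 3, the file name must end in `.weights.h5` and the weights are written in HDF5 format. For more on weight and model serialization, refer to the [Save and load models](https://tensorflow.google.cn/guide/keras/save_and_serialize) guide." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R7W5plyZ-u9X" }, "outputs": [], "source": [ "# Save the weights (Keras 3 requires the `.weights.h5` suffix)\n", "weights_path = temp_dir/'checkpoints/my_checkpoint.weights.h5'\n", "weights_path.parent.mkdir(parents=True, exist_ok=True)\n", "model.save_weights(weights_path)\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Build the model so that its variables exist before restoring weights\n", "model.build((None, 28 * 28))\n", "\n", "# Restore the weights\n", "model.load_weights(weights_path)\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "kOGlxPRBEvV1" }, "source": [ "## Save the entire model\n", "\n", "Call `tf.keras.Model.save` to save a model's architecture, weights, and training configuration in a single `model.keras` zip archive.\n", "\n", "An entire model can be saved in three different file formats (the new `.keras` format and two legacy formats: `SavedModel` and `HDF5`). Saving a model as `path/to/model.keras` automatically saves in the latest format.\n", "\n", "**Note:** For Keras objects, the new high-level `.keras` format is recommended for richer, name-based saving and reloading, which is easier to debug. The low-level SavedModel format and the legacy H5 format continue to be supported for existing code.\n", "\n", "With Keras 3 (TensorFlow 2.16 and later), `save()` infers the format from the file extension, and the `save_format` argument is deprecated:\n", "\n", "- To write a SavedModel, call `Model.export()` with a directory path; `save()` rejects a path that has neither a `.keras` nor an `.h5` extension.\n", "- To write the H5 format, pass a filename that ends in `.h5` to `save()`.\n", "\n", "With earlier versions (Keras 2), you can instead switch to the SavedModel format by passing `save_format='tf'` to `save()` or by passing a filename without an extension, and to the H5 format by passing `save_format='h5'` or a filename that ends in `.h5`.\n", "\n", "Saving a fully functional model is very useful: you can load it in TensorFlow.js ([Saved Model](https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model), [HDF5](https://tensorflow.google.cn/js/tutorials/conversion/import_keras)) and then train and run it in a web browser, or convert it with TensorFlow Lite ([Saved Model](https://tensorflow.google.cn/lite/models/convert/#convert_a_savedmodel_recommended_), [HDF5](https://tensorflow.google.cn/lite/models/convert/#convert_a_keras_model_)) to run on mobile devices.\n", "\n", "\\*Custom objects (for example, subclassed models or layers) require special attention when saving and loading. Refer to the **Saving custom objects** section below." ] }, { "cell_type": "markdown", "metadata": { "id": "0fRGnlHMrkI7" }, "source": [ "### New high-level `.keras` format" ] }, { "cell_type": "markdown", "metadata": { "id": "eqO8jj7GsCDn" }, "source": [ "The new Keras v3 saving format, marked by the `.keras` extension, is a simpler, more efficient format that implements name-based saving, ensuring that what you load is exactly what you saved, from Python's perspective. This makes debugging easier, and it is the recommended format for Keras.\n", "\n", "The section below illustrates how to save and restore the model in the `.keras` format." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3f55mAXwukUX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 41ms/step - loss: 1.5505 - sparse_categorical_accuracy: 0.5339\n", "Epoch 2/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.4653 - sparse_categorical_accuracy: 0.8494 \n", "Epoch 3/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2881 - sparse_categorical_accuracy: 0.9291 \n", "Epoch 4/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2701 - sparse_categorical_accuracy: 0.9293 \n", "Epoch 5/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1476 - sparse_categorical_accuracy: 0.9672 \n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", 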
"model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a `.keras` zip archive.\n", "model.save(temp_dir/'my_model.keras')" ] }, { "cell_type": "markdown", "metadata": { "id": "iHqwaun5g8lD" }, "source": [ "从 `.keras` zip 归档重新加载新的 Keras 模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HyfUMOZwux_-" }, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_12\"\n",
              "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential_12\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
              "┃ Layer (type)                     Output Shape                  Param # ┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
              "│ dense_24 (Dense)                │ (None, 512)            │       401,920 │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ dropout_12 (Dropout)            │ (None, 512)            │             0 │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ dense_25 (Dense)                │ (None, 10)             │         5,130 │\n",
              "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
              "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_24 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m401,920\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_12 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_25 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m5,130\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 1,221,152 (4.66 MB)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,221,152\u001b[0m (4.66 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 407,050 (1.55 MB)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Optimizer params: 814,102 (3.11 MB)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Optimizer params: \u001b[0m\u001b[38;5;34m814,102\u001b[0m (3.11 MB)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "new_model = tf.keras.models.load_model(temp_dir/'my_model.keras')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "9Cn3pSBqvJ5f" }, "source": [ "Try running evaluate and predict with the loaded model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8BT4mHNIvMdW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 1s - 46ms/step - loss: 0.4289 - sparse_categorical_accuracy: 0.8560\n", "Restored model, accuracy: 85.60%\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 9ms/step\n", "(1000, 10)\n" ] } ], "source": [ "# Evaluate the restored model\n", "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n", "\n", "print(new_model.predict(test_images).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "kPyhgcoVzqUB" }, "source": [ "### SavedModel format" ] }, { "cell_type": "markdown", "metadata": { "id": "LtcN4VIb7JkK" }, "source": [ "The SavedModel format is another way to serialize models. Models saved in this format are compatible with TensorFlow Serving. With Keras 3 (TensorFlow 2.16 and later), export a model to this format with `Model.export` and reload it for inference with `tf.saved_model.load`; `tf.keras.models.load_model` reads only `.keras` and `.h5` files. The [SavedModel guide](../../guide/saved_model.ipynb) goes into detail about how to serve and inspect a SavedModel. The section below illustrates the steps to save and restore the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sI1YvCDFzpl3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 40ms/step - loss: 1.6375 - sparse_categorical_accuracy: 0.4978\n", "Epoch 2/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.4050 - sparse_categorical_accuracy: 0.9061 \n", "Epoch 3/5\n", 
"\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.3121 - sparse_categorical_accuracy: 0.9231 \n", "Epoch 4/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1988 - sparse_categorical_accuracy: 0.9600 \n", "Epoch 5/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1638 - sparse_categorical_accuracy: 0.9647 \n", "Saved artifact at '.temp/saved_model/my_model'. The following endpoints are available:\n", "\n", "* Endpoint 'serve'\n", " args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 784), dtype=tf.float32, name='keras_tensor_50')\n", "Output Type:\n", " TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)\n", "Captures:\n", " 140591918670928: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 140591918669968: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 140591918670544: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 140591918669776: TensorSpec(shape=(), dtype=tf.resource, name=None)\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a SavedModel.\n", "(temp_dir/\"saved_model\").mkdir(parents=True, exist_ok=True)\n", "model.export(temp_dir/'saved_model/my_model', format='tf_saved_model') " ] }, { "cell_type": "markdown", "metadata": { "id": "iUvT_3qE8hV5" }, "source": [ "SavedModel 格式是一个包含 protobuf 二进制文件和 TensorFlow 检查点的目录。检查保存的模型目录:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sq8fPglI1RWA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "my_model\n", "assets\tfingerprint.pb\tsaved_model.pb\tvariables\n" ] } ], "source": [ "# my_model directory\n", "!ls {temp_dir}/saved_model\n", "\n", "# Contains an assets folder, 
saved_model.pb, and variables folder.\n", "!ls {temp_dir}/saved_model/my_model" ] }, { "cell_type": "markdown", "metadata": { "id": "B7qfpvpY9HCe" }, "source": [ "Reload the exported SavedModel:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0YofwHdN0pxa" }, "outputs": [ { "data": { "text/plain": [ "._UserObject at 0x7fdf731c56d0>" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Keras 3 `load_model` does not read SavedModel directories; use `tf.saved_model.load`\n", "new_model = tf.saved_model.load(temp_dir/'saved_model/my_model')\n", "\n", "# The result is a generic SavedModel object rather than a Keras model,\n", "# so `summary()` is not available\n", "new_model" ] }, { "cell_type": "markdown", "metadata": { "id": "uWwgNaz19TH2" }, "source": [ "The loaded object is not a Keras model, so it has no `evaluate` or `predict` method and does not need to be compiled. Run inference through the exported `serve` endpoint and compute the accuracy from its output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yh5Mu0yOgE5J" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Run inference through the exported `serve` endpoint, which expects float32 inputs\n", "logits = new_model.serve(tf.constant(test_images.astype(np.float32)))\n", "\n", "# Compute the accuracy of the restored model from its predictions\n", "acc = np.mean(np.argmax(logits.numpy(), axis=1) == test_labels)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n", "\n", "print(logits.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "SkGwf-50zLNn" }, "source": [ "### HDF5 format\n", "\n", "Keras provides a basic legacy high-level save format using the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) standard." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m2dkmJVCGUia" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 42ms/step - loss: 1.5758 - sparse_categorical_accuracy: 0.5213\n", "Epoch 2/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.4920 - sparse_categorical_accuracy: 0.8493 \n", "Epoch 3/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2876 - sparse_categorical_accuracy: 0.9243 \n", "Epoch 4/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2156 - sparse_categorical_accuracy: 0.9522 \n", "Epoch 5/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1565 - sparse_categorical_accuracy: 0.9612 \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`. 
\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model to a HDF5 file.\n", "# The '.h5' extension indicates that the model should be saved to HDF5.\n", "model.save(temp_dir/'my_model.h5')" ] }, { "cell_type": "markdown", "metadata": { "id": "GWmttMOqS68S" }, "source": [ "Now, recreate the model from that file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5NDMO_7kS6Do" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "data": { "text/html": [ "
Model: \"sequential_17\"\n",
              "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential_17\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
              "┃ Layer (type)                     Output Shape                  Param # ┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
              "│ dense_34 (Dense)                │ (None, 512)            │       401,920 │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ dropout_17 (Dropout)            │ (None, 512)            │             0 │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ dense_35 (Dense)                │ (None, 10)             │         5,130 │\n",
              "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
              "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_34 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m401,920\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_17 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_35 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m5,130\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 407,052 (1.55 MB)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m407,052\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 407,050 (1.55 MB)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Optimizer params: 2 (12.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Optimizer params: \u001b[0m\u001b[38;5;34m2\u001b[0m (12.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Recreate the exact same model, including its weights and the optimizer\n", "new_model = tf.keras.models.load_model(temp_dir/'my_model.h5')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "JXQpbTicTBwt" }, "source": [ "Check its accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jwEaj9DnTCVA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 2s - 47ms/step - loss: 0.4254 - sparse_categorical_accuracy: 0.8540\n", "Restored model, accuracy: 85.40%\n" ] } ], "source": [ "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "dGXqd4wWJl8O" }, "source": [ "Keras saves models by inspecting their architectures. This technique saves everything:\n", "\n", "- The weight values\n", "- The model's architecture\n", "- The model's training configuration (what you pass to the `.compile()` method)\n", "- The optimizer and its state, if any (this lets you restart training where you left off)\n", "\n", "Keras cannot save the `v1.x` optimizers (from `tf.compat.v1.train`) because they are not compatible with checkpoints. With a v1.x optimizer, the optimizer state is lost on loading, so you need to re-compile the model after loading it.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "kAUKJQyGqTNH" }, "source": [ "### Saving custom objects\n", "\n", "If you are using the SavedModel format, you can skip this section. The key difference between the high-level `.keras`/HDF5 formats and the low-level SavedModel format is that the `.keras`/HDF5 formats use object configs to save the model architecture, while SavedModel saves the execution graph. As a result, SavedModels can save custom objects, such as subclassed models and custom layers, without requiring the original code. However, this also makes low-level SavedModels harder to debug, so we recommend the high-level `.keras` format instead, because it is name-based and native to Keras.\n", "\n", "To save custom objects to `.keras` and HDF5, you must do the following:\n", "\n", "1. Define a `get_config` method in your object, and optionally a `from_config` classmethod.\n", " - `get_config(self)` returns a JSON-serializable dictionary of the parameters needed to recreate the object.\n", " - `from_config(cls, config)` uses the config returned by `get_config` to create a new object. By default, this function uses the config as initialization kwargs (`return cls(**config)`).\n", "2. 
Pass the custom objects to the model in one of three ways:\n", " - Register the custom object with the `@tf.keras.utils.register_keras_serializable` decorator. **(recommended)**\n", " - Pass the object directly to the `custom_objects` argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class. For example: `tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})`\n", " - Use a `tf.keras.utils.custom_object_scope` with the object included in the `custom_objects` dictionary argument, and place a `tf.keras.models.load_model(path)` call within the scope.\n", "\n", "For examples of custom objects and `get_config`, see the [Writing layers and models from scratch](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) tutorial.\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "save_and_load.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }