{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "g_nWetWWd_ns" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "N_fMsQ-N8I7j" }, "outputs": [], "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": { "id": "pZJ3uY9O17VN" }, "source": [ "# Save and load models" ] }, { "cell_type": "markdown", "metadata": { "id": "M4Ata7_wMul1" }, "source": [ "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [
"
Model: \"sequential_11\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_11\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_22 (Dense) │ ? │ 0 (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_11 (Dropout) │ ? │ 0 (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_23 (Dense) │ ? │ 0 (unbuilt) │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_22 (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_11 (\u001b[38;5;33mDropout\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_23 (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Define a simple sequential model\n", "def create_model():\n", "  model = tf.keras.Sequential([\n", "    keras.layers.Dense(512, activation='relu'),\n", "    keras.layers.Dropout(0.2),\n", "    keras.layers.Dense(10)\n", "  ])\n", "\n", "  model.compile(optimizer='adam',\n", "                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n", "\n", "  return model\n", "\n", "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Display the model's architecture\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "soDE0W_KH8rG" }, "source": [ "## Save checkpoints during training" ] }, { "cell_type": "markdown", "metadata": { "id": "mRyd5qQQIXZm" }, "source": [ "You can use a trained model without retraining it, or pick up training where you left off if the training process was interrupted. The `tf.keras.callbacks.ModelCheckpoint` callback lets you continually save the model both *during* and at *the end* of training.\n", "\n", "### Checkpoint callback usage\n", "\n", "Create a `tf.keras.callbacks.ModelCheckpoint` callback that saves only the weights during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFPuhwntH8VH" }, "outputs": [], "source": [ "checkpoint_path = temp_dir/\"training_1/cp.ckpt\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "# Create a callback that saves the model's weights\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n", "                                                 save_weights_only=True,\n", "                                                 verbose=1)\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images,\n", "          train_labels,\n", "          epochs=10,\n", "          validation_data=(test_images, test_labels),\n", "          callbacks=[cp_callback])  # Pass callback to training\n", "\n", "# This may generate warnings related to saving the state of the optimizer.\n", "# These warnings (and similar warnings throughout this notebook)\n", "# are in place to 
discourage outdated usage, and can be ignored." ] }, { "cell_type": "markdown", "metadata": { "id": "rlM-sgyJO084" }, "source": [ "This creates a single collection of TensorFlow checkpoint files that are updated at the end of each epoch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gXG5FVKFOVQ3" }, "outputs": [], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "wlRN_f56Pqa9" }, "source": [ "As long as two models share the same architecture, you can share weights between them. So, when restoring a model from weights only, create a model with the same architecture as the original model and then set its weights.\n", "\n", "Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at chance level (about 10% accuracy):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Fp5gbuiaPqCT" }, "outputs": [], "source": [ "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "1DTKpZssRSo3" }, "source": [ "Then load the weights from the checkpoint and re-evaluate:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2IZxbwiRRSD2" }, "outputs": [], "source": [ "# Loads the weights\n", "model.load_weights(checkpoint_path)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "bpAbKkAyVPV8" }, "source": [ "### Checkpoint callback options\n", "\n", "The callback provides several options for giving checkpoints unique names and adjusting the checkpointing frequency.\n", "\n", "Train a new model, and save uniquely named checkpoints once every five epochs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mQF_dlgIVOvq" }, "outputs": [], "source": [ "# Include the epoch in the file name (uses `str.format`)\n", "checkpoint_path = temp_dir/\"training_2/cp-{epoch:04d}.ckpt\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "batch_size = 32\n", "\n", "# Calculate the number of batches per epoch\n", "import math\n", 
"n_batches = len(train_images) / batch_size\n", "n_batches = math.ceil(n_batches)  # round up the number of batches to the nearest whole integer\n", "\n", "# Create a callback that saves the model's weights every 5 epochs\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", "    filepath=checkpoint_path,\n", "    verbose=1,\n", "    save_weights_only=True,\n", "    save_freq=5*n_batches)\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Save the weights using the `checkpoint_path` format\n", "# (`checkpoint_path` is a `pathlib.Path`, so convert it to `str` before calling `format`)\n", "model.save_weights(str(checkpoint_path).format(epoch=0))\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images,\n", "          train_labels,\n", "          epochs=50,\n", "          batch_size=batch_size,\n", "          callbacks=[cp_callback],\n", "          validation_data=(test_images, test_labels),\n", "          verbose=0)" ] }, { "cell_type": "markdown", "metadata": { "id": "1zFrKTjjavWI" }, "source": [ "Now, review the resulting checkpoints and choose the latest one:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p64q3-V4sXt0" }, "outputs": [], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1AN_fnuyR41H" }, "outputs": [], "source": [ "latest = tf.train.latest_checkpoint(checkpoint_dir)\n", "latest" ] }, { "cell_type": "markdown", "metadata": { "id": "Zk2ciGbKg561" }, "source": [ "Note: The default TensorFlow format only saves the 5 most recent checkpoints.\n", "\n", "To test, reset the model and load the latest checkpoint:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3M04jyK-H3QK" }, "outputs": [], "source": [ "# Create a new model instance\n", "model = create_model()\n", "\n", "# Load the previously saved weights\n", "model.load_weights(latest)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "c2OxsJOTHxia" }, "source": [ "## What are these files?" 
] }, { "cell_type": "markdown", "metadata": { "id": "JtdYhvWnH2ib" }, "source": [ "The above code stores the weights in a collection of [checkpoint](../../guide/checkpoint.ipynb)-formatted files that contain only the trained weights in a binary format. Checkpoints contain:\n", "\n", "- One or more shards that contain your model's weights.\n", "- An index file that indicates which weights are stored in which shard.\n", "\n", "If you are training a model on a single machine, you will have one shard with the suffix: `.data-00000-of-00001`" ] }, { "cell_type": "markdown", "metadata": { "id": "S_FA-ZvxuXQV" }, "source": [ "## Manually save weights\n", "\n", "To save weights manually, use `tf.keras.Model.save_weights`. By default, `tf.keras` (and the `Model.save_weights` method in particular) uses the TensorFlow [Checkpoint](../../guide/checkpoint.ipynb) format with a `.ckpt` extension. To save in the HDF5 format with a `.h5` extension, refer to the [Save and load models](https://tensorflow.google.cn/guide/keras/save_and_serialize) guide." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R7W5plyZ-u9X" }, "outputs": [], "source": [ "# Save the weights\n", "model.save_weights(temp_dir/'./checkpoints/my_checkpoint')\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Restore the weights\n", "model.load_weights(temp_dir/'./checkpoints/my_checkpoint')\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "kOGlxPRBEvV1" }, "source": [ "## Save the entire model\n", "\n", "Call `tf.keras.Model.save` to save a model's architecture, weights, and training configuration in a single `model.keras` zip archive.\n", "\n", "An entire model can be saved in three different file formats (the new `.keras` format and two legacy formats: `SavedModel` and `HDF5`). Saving a model as `path/to/model.keras` automatically saves in the latest format.\n", "\n", "**Note**: For Keras objects, the new high-level `.keras` format is recommended for richer, name-based saving and reloading, which is easier to debug. The low-level SavedModel format and the legacy H5 format continue to be supported for existing code.\n", "\n", "You can switch to the SavedModel format by:\n", "\n", "- Passing `save_format='tf'` to `save()`\n", "- Passing a filename without an extension\n", "\n", "You can switch to the H5 format by:\n", "\n", "- Passing `save_format='h5'` to `save()`\n", "- Passing a filename that ends in `.h5`\n", "\n", "Saving a fully functional model is very useful: you can load it in TensorFlow.js ([Saved 
Model](https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model), [HDF5](https://tensorflow.google.cn/js/tutorials/conversion/import_keras)) and then train and run it in web browsers, or convert it to run on mobile devices using TensorFlow Lite ([Saved Model](https://tensorflow.google.cn/lite/models/convert/#convert_a_savedmodel_recommended_), [HDF5](https://tensorflow.google.cn/lite/models/convert/#convert_a_keras_model_)).\n", "\n", "\\*Custom objects (for example, subclassed models or layers) require special attention when saving and loading. Refer to the **Saving custom objects** section below." ] }, { "cell_type": "markdown", "metadata": { "id": "0fRGnlHMrkI7" }, "source": [ "### New high-level `.keras` format" ] }, { "cell_type": "markdown", "metadata": { "id": "eqO8jj7GsCDn" }, "source": [ "The new Keras v3 saving format, marked by the `.keras` extension, is a simpler, more efficient format that implements name-based saving, ensuring that what you load is exactly what you saved from Python's perspective. This makes debugging easier, and it is the recommended format for Keras.\n", "\n", "The section below illustrates how to save and restore the model in the `.keras` format." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3f55mAXwukUX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 41ms/step - loss: 1.5505 - sparse_categorical_accuracy: 0.5339\n", "Epoch 2/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.4653 - sparse_categorical_accuracy: 0.8494 \n", "Epoch 3/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2881 - sparse_categorical_accuracy: 0.9291 \n", "Epoch 4/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2701 - sparse_categorical_accuracy: 0.9293 \n", "Epoch 5/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1476 - sparse_categorical_accuracy: 0.9672 \n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", 
"model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a `.keras` zip archive.\n", "model.save(temp_dir/'my_model.keras')" ] }, { "cell_type": "markdown", "metadata": { "id": "iHqwaun5g8lD" }, "source": [ "Reload a fresh Keras model from the `.keras` zip archive:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HyfUMOZwux_-" }, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_12\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_12\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_24 (Dense) │ (None, 512) │ 401,920 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_12 (Dropout) │ (None, 512) │ 0 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_25 (Dense) │ (None, 10) │ 5,130 │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_24 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m401,920\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_12 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_25 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m5,130\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 1,221,152 (4.66 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,221,152\u001b[0m (4.66 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 407,050 (1.55 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Optimizer params: 814,102 (3.11 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Optimizer params: \u001b[0m\u001b[38;5;34m814,102\u001b[0m (3.11 MB)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "new_model = tf.keras.models.load_model(temp_dir/'my_model.keras')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "9Cn3pSBqvJ5f" }, "source": [ "Try running evaluation and prediction with the loaded model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8BT4mHNIvMdW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 1s - 46ms/step - loss: 0.4289 - sparse_categorical_accuracy: 0.8560\n", "Restored model, accuracy: 85.60%\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 9ms/step\n", "(1000, 10)\n" ] } ], "source": [ "# Evaluate the restored model\n", "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n", "\n", "print(new_model.predict(test_images).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "kPyhgcoVzqUB" }, "source": [ "### SavedModel format" ] }, { "cell_type": "markdown", "metadata": { "id": "LtcN4VIb7JkK" }, "source": [ "The SavedModel format is another way to serialize models. Models saved in this format can be restored using `tf.keras.models.load_model` and are compatible with TensorFlow Serving. The [SavedModel guide](../../guide/saved_model.ipynb) goes into detail about how to `serve/inspect` the SavedModel. The section below illustrates the steps to save and restore the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sI1YvCDFzpl3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 40ms/step - loss: 1.6375 - sparse_categorical_accuracy: 0.4978\n", "Epoch 2/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.4050 - 
sparse_categorical_accuracy: 0.9061 \n", "Epoch 3/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.3121 - sparse_categorical_accuracy: 0.9231 \n", "Epoch 4/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1988 - sparse_categorical_accuracy: 0.9600 \n", "Epoch 5/5\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1638 - sparse_categorical_accuracy: 0.9647 \n", "Saved artifact at '.temp/saved_model/my_model'. The following endpoints are available:\n", "\n", "* Endpoint 'serve'\n", " args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 784), dtype=tf.float32, name='keras_tensor_50')\n", "Output Type:\n", " TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)\n", "Captures:\n", " 140591918670928: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 140591918669968: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 140591918670544: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 140591918669776: TensorSpec(shape=(), dtype=tf.resource, name=None)\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a SavedModel.\n", "(temp_dir/\"saved_model\").mkdir(parents=True, exist_ok=True)\n", "model.export(temp_dir/'saved_model/my_model', format='tf_saved_model')" ] }, { "cell_type": "markdown", "metadata": { "id": "iUvT_3qE8hV5" }, "source": [ "The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sq8fPglI1RWA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "my_model\n", "assets\tfingerprint.pb\tsaved_model.pb\tvariables\n" ] } ], "source": [ "# my_model directory\n", "!ls 
{temp_dir}/saved_model\n", "\n", "# Contains an assets folder, saved_model.pb, and variables folder.\n", "!ls {temp_dir}/saved_model/my_model" ] }, { "cell_type": "markdown", "metadata": { "id": "B7qfpvpY9HCe" }, "source": [ "Reload a fresh Keras model from the saved model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0YofwHdN0pxa" }, "outputs": [ { "data": { "text/plain": [ "
Model: \"sequential_17\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_17\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_34 (Dense) │ (None, 512) │ 401,920 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_17 (Dropout) │ (None, 512) │ 0 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_35 (Dense) │ (None, 10) │ 5,130 │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_34 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m401,920\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_17 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_35 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m5,130\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 407,052 (1.55 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m407,052\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 407,050 (1.55 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Optimizer params: 2 (12.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Optimizer params: \u001b[0m\u001b[38;5;34m2\u001b[0m (12.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Recreate the exact same model, including its weights and the optimizer\n", "new_model = tf.keras.models.load_model(temp_dir/'my_model.h5')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "JXQpbTicTBwt" }, "source": [ "Check its accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jwEaj9DnTCVA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 2s - 47ms/step - loss: 0.4254 - sparse_categorical_accuracy: 0.8540\n", "Restored model, accuracy: 85.40%\n" ] } ], "source": [ "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "dGXqd4wWJl8O" }, "source": [ "Keras saves models by inspecting their architectures. This technique saves everything:\n", "\n", "- The weight values\n", "- The model's architecture\n", "- The model's training configuration (what you pass to the `.compile()` method)\n", "- The optimizer and its state, if any (this lets you restart training where you left off)\n", "\n", "Keras cannot save the `v1.x` optimizers (from `tf.compat.v1.train`) because they are not compatible with checkpoints. For v1.x optimizers, you need to re-compile the model after loading, which loses the state of the optimizer.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "kAUKJQyGqTNH" }, "source": [ "### Saving custom objects\n", "\n", "If you are using the SavedModel format, you can skip this section. The key difference between the high-level `.keras`/HDF5 formats and the low-level SavedModel format is that the `.keras`/HDF5 formats use object configs to save the model architecture, while SavedModel saves the execution graph. Thus, SavedModels are able to save custom objects such as subclassed models and custom layers without requiring the original code. However, debugging low-level SavedModels can be more difficult as a result, so the high-level `.keras` format is recommended instead because it is name-based and native to Keras.\n", "\n", "To save custom objects to `.keras` and HDF5, you must do the following:\n", "\n", "1. Define a `get_config` method in your object, and optionally a `from_config` classmethod.\n", "    - `get_config(self)` returns a JSON-serializable dictionary of the parameters needed to recreate the object.\n", "    - `from_config(cls, config)` uses the config returned from `get_config` to create a new object. By default, this function uses the config as initialization kwargs (`return cls(**config)`).\n", "2. 
Pass the custom objects to the model in one of three ways:\n", "    - Register the custom object with the `@tf.keras.utils.register_keras_serializable` decorator. **(recommended)**\n", "    - Pass the object directly to the `custom_objects` argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class. For example, `tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})`\n", "    - Use a `tf.keras.utils.custom_object_scope` with the object included in the `custom_objects` dictionary argument, and place a `tf.keras.models.load_model(path)` call within the scope.\n", "\n", "Refer to the [Writing layers and models from scratch](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) tutorial for examples of custom objects and `get_config`.\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "save_and_load.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }