{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "1l8bWGmIJuQa" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "CPSnXS88KFEo", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "89xNCIO5hiCj" }, "source": [ "# Save and load a model using a distribution strategy" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "A0lG6qgThxAS" }, "source": [ "## Overview\n", "\n", "This tutorial demonstrates how you can save and load models with `tf.distribute.Strategy` during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (`tf.keras.Model.save` and `tf.keras.models.load_model`) and low-level (`tf.saved_model.save` and `tf.saved_model.load`).\n", "\n", "To learn about SavedModel and serialization in general, read the [SavedModel guide](../../guide/saved_model.ipynb) and the [Keras model serialization guide](https://tensorflow.google.cn/guide/keras/save_and_serialize). Let's start with a simple example.\n", "\n", "Caution: TensorFlow models are code, so be careful with untrusted code. Learn more in [Using TensorFlow securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md).\n" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "FITHltVKQ4eZ" }, "source": [ "Import dependencies:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RWG5HchAiOrZ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow_datasets as tfds\n", "\n", "import tensorflow as tf\n" ] }, { "cell_type": "markdown", "metadata": { "id": "qqapWj98ptNV" }, "source": [ "Load and prepare the data with TensorFlow Datasets and `tf.data`, and create the model using `tf.distribute.MirroredStrategy`:" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "yrYiAf_ziRyw", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "def get_data():\n", "  datasets = tfds.load(name='mnist', as_supervised=True)\n", "  mnist_train, mnist_test = datasets['train'], datasets['test']\n", "\n", "  BUFFER_SIZE = 10000\n", "\n", "  BATCH_SIZE_PER_REPLICA = 64\n", "  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync\n", "\n", "  def scale(image, label):\n", "    image = tf.cast(image, tf.float32)\n", "    image /= 255\n", "\n", "    return image, label\n", "\n", "  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n", "  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)\n", "\n", "  return train_dataset, eval_dataset\n", "\n", "def get_model():\n", "  with mirrored_strategy.scope():\n", "    model = tf.keras.Sequential([\n", "        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),\n", "        tf.keras.layers.MaxPooling2D(),\n", "        tf.keras.layers.Flatten(),\n", "        tf.keras.layers.Dense(64, activation='relu'),\n", "        tf.keras.layers.Dense(10)\n", "    ])\n", "\n", "    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "                  optimizer=tf.keras.optimizers.Adam(),\n", "                  metrics=[tf.metrics.SparseCategoricalAccuracy()])\n", "    return model" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "qmU4Y3feS9Na" }, "source": [ "Train the model with `tf.keras.Model.fit`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zmGurbJmS_vN", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = get_model()\n", "train_dataset, eval_dataset = get_data()\n", "model.fit(train_dataset, epochs=2)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "L01wjgvRizHS" }, "source": [ "## Save and load the model\n", "\n", "Now that you have a simple model to work with, let's explore the saving/loading APIs. There are two kinds of APIs available:\n", "\n", "- High-level (Keras): `Model.save` and `tf.keras.models.load_model` (`.keras` zip archive format)\n", "- Low-level: `tf.saved_model.save` and `tf.saved_model.load` (TF SavedModel format)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FX_IF2F1tvFs" }, "source": [ "### Keras API" ] }, { "cell_type": "markdown", "metadata": { "id": "O8xfceg4Z3H_" }, "source": [ "Here is an example of saving and loading a model with the Keras API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LYOStjV5knTQ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "keras_model_path = '/tmp/keras_save.keras'\n", "model.save(keras_model_path)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "yvQIdQp3zNMp" }, "source": [ "Restore the model without `tf.distribute.Strategy`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrXAAVtrzRgv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "restored_keras_model = tf.keras.models.load_model(keras_model_path)\n", "restored_keras_model.fit(train_dataset, epochs=2)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "gYAnskzorda-" }, "source": [ "After restoring the model, you can continue training on it without calling `Model.compile` again, since it was already compiled before saving. The model is saved in the Keras zip archive format, marked by the `.keras` extension. For more information, refer to the [guide on Keras saving](https://tensorflow.google.cn/guide/keras/save_and_serialize).\n", "\n", "Now, restore the model and train it using a `tf.distribute.Strategy`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wROPrJaAqBQz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')\n", "with another_strategy.scope():\n", "  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)\n", "  restored_keras_model_ds.fit(train_dataset, epochs=2)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "PdiiPmL5tQk5" }, "source": [ "As the `Model.fit` output shows, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving." ] }, { "cell_type": "markdown", "metadata": { "id": "3CrXIbmFt0f6" }, "source": [ "### `tf.saved_model` API" ] }, { "cell_type": "markdown", "metadata": { "id": "HtGzPp6et4Em" }, "source": [ "Saving the model with the lower-level API is similar to the Keras API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4y6T31APuCqK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = get_model()  # get a fresh model\n", "saved_model_path = '/tmp/tf_save'\n", "tf.saved_model.save(model, saved_model_path)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "q1QNRYcwuRll" }, "source": [ "Loading can be done with `tf.saved_model.load`. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aaEKqBSPwAuM", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "DEFAULT_FUNCTION_KEY = 'serving_default'\n", "loaded = tf.saved_model.load(saved_model_path)\n", "inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]" ] }, { 
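"cell_type": "markdown", "metadata": { "id": "sVq3LmR8tNk2" }, "source": [ "Optionally, before running inference you can inspect what the loaded object exposes: the keys available in `loaded.signatures`, and the input and output structure of the serving function. This is a minimal sketch that reuses `loaded` and `inference_func` from the previous cell. The input and output names it prints are derived from the Keras layer names, so they may differ in your run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pX7wDd4hJcE9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Optional: list the signature keys exposed by the loaded SavedModel\n", "print(list(loaded.signatures.keys()))\n", "\n", "# Inspect the inputs the serving signature expects and the outputs it returns\n", "print(inference_func.structured_input_signature)\n", "print(inference_func.structured_outputs)" ] }, { 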
"cell_type": "markdown", "metadata": { "id": "x65l7AaHUZCA" }, "source": [ "The loaded object may contain multiple functions, each associated with a key. The `\"serving_default\"` key is the default key for the inference function of a saved Keras model. To do inference with this function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Ore5q8-UjW1", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "predict_dataset = eval_dataset.map(lambda image, label: image)\n", "for batch in predict_dataset.take(1):\n", "  print(inference_func(batch))" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "osB1LY8WwUJZ" }, "source": [ "You can also load and do inference in a distributed manner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iDYvu12zYTmT", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "another_strategy = tf.distribute.MirroredStrategy()\n", "with another_strategy.scope():\n", "  loaded = tf.saved_model.load(saved_model_path)\n", "  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]\n", "\n", "  dist_predict_dataset = another_strategy.experimental_distribute_dataset(\n", "      predict_dataset)\n", "\n", "  # Calling the function in a distributed manner\n", "  for batch in dist_predict_dataset:\n", "    result = another_strategy.run(inference_func, args=(batch,))\n", "    print(result)\n", "    break" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "hWGSukoyw3fF" }, "source": [ "Calling the restored function is just a forward pass on the saved model (like `tf.keras.Model.predict`). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, [TF Hub](https://tensorflow.google.cn/hub) provides [`hub.KerasLayer`](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, as shown here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "clfk3hQoyKu6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow_hub as hub\n", "\n", "def build_model(loaded):\n", "  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')\n", "  # Wrap what's loaded in a KerasLayer\n", "  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)\n", "  model = tf.keras.Model(x, keras_layer)\n", "  return model\n", "\n", "another_strategy = tf.distribute.MirroredStrategy()\n", "with another_strategy.scope():\n", "  loaded = tf.saved_model.load(saved_model_path)\n", "  model = build_model(loaded)\n", "\n", "  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "                optimizer=tf.keras.optimizers.Adam(),\n", "                metrics=[tf.metrics.SparseCategoricalAccuracy()])\n", "  model.fit(train_dataset, epochs=2)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "Oe1z_OtSJlu2" }, "source": [ "In the above example, TensorFlow Hub's `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load` into a Keras layer that can be used to build another model. This is very useful for transfer learning." ] }, { "cell_type": "markdown", "metadata": { "id": "KFDOZpK5Wa3W" }, "source": [ "### Which API should I use?" ] }, { "cell_type": "markdown", "metadata": { "id": "GC6GQ9HDLxD6" }, "source": [ "For saving, if you are working with a Keras model, use the Keras `Model.save` API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, `tf.saved_model.save`, is your only choice.\n", "\n", "For loading, your API choice depends on what you want to get back. If you cannot (or do not want to) get a Keras model, use `tf.saved_model.load`. Otherwise, use `tf.keras.models.load_model`. Note that you can get a Keras model back only if you saved a Keras model.\n", "\n", "It is possible to mix and match the APIs. You can save a Keras model with `Model.save`, and load it back as a non-Keras object with the low-level API, `tf.saved_model.load`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ktwg2GwnXE8v", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = get_model()\n", "\n", "# Saving the model using Keras `Model.save`\n", "model.save(saved_model_path)\n", "\n", "another_strategy = tf.distribute.MirroredStrategy()\n", "# Loading the model using the lower-level API\n", "with another_strategy.scope():\n", "  loaded = tf.saved_model.load(saved_model_path)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "0Z7lSj8nZiW5" }, "source": [ "### Saving/Loading from a local device" ] }, { "cell_type": "markdown", "metadata": { "id": "NVAjWcosZodw" }, "source": [ "When saving and loading from a local I/O device while training on remote devices (for example, when using a Cloud TPU), you must use the option `experimental_io_device` in `tf.saved_model.SaveOptions` and `tf.saved_model.LoadOptions` to set the I/O device to `localhost`. For example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jFcuzsI94bNA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "model = get_model()\n", "\n", "# Saving the model to a path on localhost.\n", "saved_model_path = '/tmp/tf_save'\n", "save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')\n", "model.save(saved_model_path, options=save_options)\n", "\n", "# Loading the model from a path on localhost.\n", "another_strategy = tf.distribute.MirroredStrategy()\n", "with another_strategy.scope():\n", "  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')\n", "  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "hJTWOnC9iuA3" }, "source": [ "### Caveats" ] }, { "cell_type": "markdown", "metadata": { "id": "2cCSZrD7VCxe" }, "source": [ "One special case is when you create a Keras model in a way that does not define its input shape (for example, by subclassing `tf.keras.Model`) and then save it before training. For example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gurSIbDFjOBc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class SubclassedModel(tf.keras.Model):\n", "  \"\"\"Example model defined by subclassing `tf.keras.Model`.\"\"\"\n", "\n", "  output_name = 'output_layer'\n", "\n", "  def __init__(self):\n", "    super(SubclassedModel, self).__init__()\n", "    self._dense_layer = tf.keras.layers.Dense(\n", "        5, dtype=tf.dtypes.float32, name=self.output_name)\n", "\n", "  def call(self, inputs):\n", "    return self._dense_layer(inputs)\n", "\n", "my_model = SubclassedModel()\n", "try:\n", "  my_model.save(saved_model_path)\n", "except ValueError as e:\n", "  print(f'{type(e).__name__}: ', *e.args)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "D4qMyXFDSPDO" }, "source": [ "A SavedModel saves the `tf.types.experimental.ConcreteFunction` objects generated when you trace a `tf.function` (check *When is a Function tracing?* in the [Introduction to graphs and tf.function](../../guide/intro_to_graphs.ipynb) guide to learn more). If you get a `ValueError` like this, it is because `Model.save` was not able to find or create a traced `ConcreteFunction`.\n", "\n", "**Caution:** You shouldn't save a model without at least one `ConcreteFunction`, since the low-level API will otherwise generate a SavedModel with no `ConcreteFunction` signatures ([learn more](../../guide/saved_model.ipynb) about the SavedModel format). For example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "064SE47mYDj8", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "tf.saved_model.save(my_model, saved_model_path)\n", "x = tf.saved_model.load(saved_model_path)\n", "x.signatures" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "LRTxlASJX-cY" }, "source": [ "Usually, the model's forward pass (the `call` method) is traced automatically when the model is called for the first time, often via the Keras `Model.fit` method. The Keras [Sequential](https://tensorflow.google.cn/guide/keras/sequential_model) and [Functional](https://tensorflow.google.cn/guide/keras/functional) APIs can also generate a `ConcreteFunction` if you set the input shape, for example, by making the first layer a `tf.keras.layers.InputLayer`, or another layer type to which you pass the `input_shape` keyword argument.\n", "\n", "To verify whether your model has any traced `ConcreteFunction`s, check whether `Model.save_spec()` returns `None`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xAXise4eR0YJ", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print(my_model.save_spec() is None)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "G2G_FQrWJAO3" }, "source": [ "Now train the model with `tf.keras.Model.fit`, and notice that `save_spec` gets defined and saving the model works:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cv5LTi0zDkKS", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "BATCH_SIZE_PER_REPLICA = 4\n", "BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync\n", "\n", "dataset_size = 100\n", "dataset = tf.data.Dataset.from_tensors(\n", "    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))\n", "    ).repeat(dataset_size).batch(BATCH_SIZE)\n", "\n", "my_model.compile(optimizer='adam', loss='mean_squared_error')\n", "my_model.fit(dataset, epochs=2)\n", "\n", "print(my_model.save_spec() is 
None)\n", "my_model.save(saved_model_path)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "save_and_load.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }