{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "b518b04cbfe0" }, "outputs": [], "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "906e07f6e562", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "394e705afdd5" }, "source": [ "# Save and load Keras models" ] }, { "cell_type": "markdown", "metadata": { "id": "60de82f6bcea" }, "source": [ "
\n",
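"- External losses and metrics added via `model.add_loss()` and `model.add_metric()` are not saved (unlike SavedModel). If your model has such losses and metrics and you want to resume training, you need to add them back yourself after loading the model. Note that this does not apply to losses/metrics created inside layers via `self.add_loss()` and `self.add_metric()`. As long as the layer is loaded, these losses and metrics are kept, since they are part of the layer's `call` method.\n",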
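"- The **computation graph of custom objects** (such as custom layers) is not included in the saved file. At loading time, Keras needs access to the Python classes/functions of these objects to reconstruct the model. See [Custom objects](#custom-objects).\n"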
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bf78706009bf"
},
"source": [
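"## Saving the architecture\n",
"\n",
"The model's configuration (or architecture) specifies which layers the model contains and how these layers are connected*. If you have the configuration of a model, the model can be created with freshly initialized weights and no compilation information.\n",
"\n",
"*Note that this only applies to models defined using the Functional or Sequential APIs, not to subclassed models."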
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "58a708dbb5da"
},
"source": [
"### Configuration of a Sequential model or Functional API model\n",
"\n",
"These types of models are explicit graphs of layers: their configuration is always available in a structured form.\n",
"\n",
"#### APIs\n",
"\n",
"- `get_config()` and `from_config()`\n",
"- `to_json()` and `tf.keras.models.model_from_json()`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3d8b20812b50"
},
"source": [
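"#### `get_config()` and `from_config()`\n",
"\n",
"Calling `config = model.get_config()` returns a Python dict containing the configuration of the model. The same model can then be reconstructed via `Sequential.from_config(config)` (for a `Sequential` model) or `Model.from_config(config)` (for a Functional API model).\n",
"\n",
"The same workflow also works for any serializable layer.\n",
"\n",
"**Layer example:**"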
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4f26b94e879a",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"layer = keras.layers.Dense(3, activation=\"relu\")\n",
"layer_config = layer.get_config()\n",
"new_layer = keras.layers.Dense.from_config(layer_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7e5dd2a439c"
},
"source": [
"**Sequential model example:**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ae0842be8a2a",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])\n",
"config = model.get_config()\n",
"new_model = keras.Sequential.from_config(config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1e97ca5f73d7"
},
"source": [
"**Functional model example:**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "da001f34e412",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"inputs = keras.Input((32,))\n",
"outputs = keras.layers.Dense(1)(inputs)\n",
"model = keras.Model(inputs, outputs)\n",
"config = model.get_config()\n",
"new_model = keras.Model.from_config(config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d7c08fae3eef"
},
"source": [
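"#### `to_json()` and `tf.keras.models.model_from_json()`\n",
"\n",
"This is similar to `get_config` / `from_config`, except that it turns the model into a JSON string, which can then be loaded without the original model class. It is also specific to models and is not meant for layers.\n",
"\n",
"**Example:**"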
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "12885447bd35",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])\n",
"json_config = model.to_json()\n",
"new_model = keras.models.model_from_json(json_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "edcae4bf461c"
},
"source": [
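"### Custom objects\n",
"\n",
"**Models and layers**\n",
"\n",
"The architecture of subclassed models and layers is defined in the `__init__` and `call` methods. They are considered Python bytecode, which cannot be serialized into a JSON-compatible config. You could try serializing the bytecode (e.g. via `pickle`), but doing so is completely unsafe and means the model cannot be loaded on a different system.\n",
"\n",
"To save/load a model with custom-defined layers, or a subclassed model, you should override the `get_config` and, optionally, `from_config` methods. Additionally, you should register the custom object so that Keras is aware of it.\n",
"\n",
"**Custom functions**\n",
"\n",
"Custom-defined functions (e.g. an activation, loss, or initializer) do not need a `get_config` method. Registering the function name as a custom object is enough for loading.\n",
"\n",
"**Loading the TensorFlow graph only**\n",
"\n",
"It is possible to load the TensorFlow graph generated by Keras. If you do so, you don't need to provide any `custom_objects`. You can load it like this:"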
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1651c6825106",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model.save(\"my_model\")\n",
"tensorflow_graph = tf.saved_model.load(\"my_model\")\n",
"x = np.random.uniform(size=(4, 32)).astype(np.float32)\n",
"predicted = tensorflow_graph(x).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b15faa16734b"
},
"source": [
"Note that this approach has a drawback:\n",
"\n",
"- The object returned by `tf.saved_model.load` isn't a Keras model, so it's not as easy to use. For example, you won't have access to `.predict()` or `.fit()`.\n",
"\n",
"Even though its use is discouraged, it can help if you're in a tight spot, for example, if you lost the code of your custom objects or have issues loading the model with `tf.keras.models.load_model()`.\n",
"\n",
"You can find out more on [the page about `tf.saved_model.load`](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d308bc27a04d"
},
"source": [
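"#### Defining the config methods\n",
"\n",
"Specifications:\n",
"\n",
"- `get_config` should return a JSON-serializable dictionary in order to be compatible with the Keras architecture- and model-saving APIs.\n",
"- `from_config(config)` (`classmethod`) should return a new layer or model object that is created from the config. The default implementation returns `cls(**config)`.\n",
"\n",
"**Example:**"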
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e18c4668dadc",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"class CustomLayer(keras.layers.Layer):\n",
"    def __init__(self, a):\n",
"        # Initialize the base Layer before attaching state to it.\n",
"        super().__init__()\n",
"        self.var = tf.Variable(a, name=\"var_a\")\n",
"\n",
"    def call(self, inputs, training=False):\n",
"        if training:\n",
"            return inputs * self.var\n",
"        else:\n",
"            return inputs\n",
"\n",
"    def get_config(self):\n",
"        return {\"a\": self.var.numpy()}\n",
"\n",
"    # There's actually no need to define `from_config` here, since returning\n",
"    # `cls(**config)` is the default behavior.\n",
"    @classmethod\n",
"    def from_config(cls, config):\n",
"        return cls(**config)\n",
"\n",
"\n",
"layer = CustomLayer(5)\n",
"layer.var.assign(2)\n",
"\n",
"serialized_layer = keras.layers.serialize(layer)\n",
"new_layer = keras.layers.deserialize(\n",
" serialized_layer, custom_objects={\"CustomLayer\": CustomLayer}\n",
")"
]
},
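{
"cell_type": "markdown",
"metadata": {},
"source": [
"The specification above (a JSON-serializable dictionary from `get_config`, and `cls(**config)` as the default `from_config`) can be exercised without TensorFlow. The following is a minimal sketch using a hypothetical plain-Python class, `ScaleConfigDemo`, which is not a Keras class; it only illustrates the round trip through JSON."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"\n",
"class ScaleConfigDemo:\n",
"    \"\"\"Hypothetical stand-in for a layer; not a Keras class.\"\"\"\n",
"\n",
"    def __init__(self, factor=1.0, name=\"scale\"):\n",
"        self.factor = factor\n",
"        self.name = name\n",
"\n",
"    def get_config(self):\n",
"        # Return only plain Python types so that json.dumps succeeds.\n",
"        return {\"factor\": float(self.factor), \"name\": self.name}\n",
"\n",
"    @classmethod\n",
"    def from_config(cls, config):\n",
"        # Mirrors the default behavior described above: cls(**config).\n",
"        return cls(**config)\n",
"\n",
"\n",
"demo_original = ScaleConfigDemo(factor=2.5, name=\"demo\")\n",
"demo_json = json.dumps(demo_original.get_config())\n",
"demo_restored = ScaleConfigDemo.from_config(json.loads(demo_json))\n",
"print(demo_json)"
]
},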
{
"cell_type": "markdown",
"metadata": {
"id": "425a9baa574e"
},
"source": [
"#### 注册自定义对象\n",
"\n",
"Keras 会记录哪个类生成了配置。在上面的示例中,`tf.keras.layers.serialize` 会生成自定义层的序列化形式:\n",
"\n",
"```\n",
"{'class_name': 'CustomLayer', 'config': {'a': 2}}\n",
"```\n",
"\n",
"Keras 会维护一份所有内置层、模型、优化器和指标类的主列表,用于查找正确的类以调用 `from_config`。如果找不到该类,则会引发错误 (`Value Error: Unknown layer`)。可以通过几种方式将自定义类注册到此列表中:\n",
"\n",
"1. 在加载函数中设置 `custom_objects` 参数。(请参阅上文”定义配置方法“部分中的示例)\n",
"2. `tf.keras.utils.custom_object_scope` 或者 `tf.keras.utils.CustomObjectScope`\n",
"3. `tf.keras.utils.register_keras_serializable`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a047be0ba572"
},
"source": [
"#### Custom layer and function example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "04a82ec30b5c",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"class CustomLayer(keras.layers.Layer):\n",
"    def __init__(self, units=32, **kwargs):\n",
"        super(CustomLayer, self).__init__(**kwargs)\n",
"        self.units = units\n",
"\n",
"    def build(self, input_shape):\n",
"        self.w = self.add_weight(\n",
"            shape=(input_shape[-1], self.units),\n",
"            initializer=\"random_normal\",\n",
"            trainable=True,\n",
"        )\n",
"        self.b = self.add_weight(\n",
"            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
"        )\n",
"\n",
"    def call(self, inputs):\n",
"        return tf.matmul(inputs, self.w) + self.b\n",
"\n",
"    def get_config(self):\n",
"        config = super(CustomLayer, self).get_config()\n",
"        config.update({\"units\": self.units})\n",
"        return config\n",
"\n",
"\n",
"def custom_activation(x):\n",
"    return tf.nn.tanh(x) ** 2\n",
"\n",
"\n",
"# Make a model with the CustomLayer and custom_activation\n",
"inputs = keras.Input((32,))\n",
"x = CustomLayer(32)(inputs)\n",
"outputs = keras.layers.Activation(custom_activation)(x)\n",
"model = keras.Model(inputs, outputs)\n",
"\n",
"# Retrieve the config\n",
"config = model.get_config()\n",
"\n",
"# At loading time, register the custom objects with a `custom_object_scope`:\n",
"custom_objects = {\"CustomLayer\": CustomLayer, \"custom_activation\": custom_activation}\n",
"with keras.utils.custom_object_scope(custom_objects):\n",
" new_model = keras.Model.from_config(config)"
]
},
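{
"cell_type": "markdown",
"metadata": {},
"source": [
"The class lookup described under \"Registering the custom object\" can be pictured with a small plain-Python sketch. Everything here (`DemoDense`, `DemoCustom`, `DEMO_BUILT_IN`, `demo_serialize`, `demo_deserialize`) is hypothetical and only mimics the idea of a name-to-class registry; Keras's real lookup is more involved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DemoDense:\n",
"    \"\"\"Hypothetical 'built-in' class; not a Keras layer.\"\"\"\n",
"\n",
"    def __init__(self, units):\n",
"        self.units = units\n",
"\n",
"    def get_config(self):\n",
"        return {\"units\": self.units}\n",
"\n",
"    @classmethod\n",
"    def from_config(cls, config):\n",
"        return cls(**config)\n",
"\n",
"\n",
"class DemoCustom(DemoDense):\n",
"    \"\"\"Hypothetical user-defined class that the registry does not know.\"\"\"\n",
"\n",
"\n",
"# Stand-in for the master list of built-in classes.\n",
"DEMO_BUILT_IN = {\"DemoDense\": DemoDense}\n",
"\n",
"\n",
"def demo_serialize(obj):\n",
"    # Record which class produced the config, as in the serialized form above.\n",
"    return {\"class_name\": type(obj).__name__, \"config\": obj.get_config()}\n",
"\n",
"\n",
"def demo_deserialize(data, custom_objects=None):\n",
"    # Merge the built-in registry with any user-supplied classes.\n",
"    registry = dict(DEMO_BUILT_IN)\n",
"    registry.update(custom_objects or {})\n",
"    name = data[\"class_name\"]\n",
"    if name not in registry:\n",
"        raise ValueError(\"Unknown layer: \" + name)\n",
"    return registry[name].from_config(data[\"config\"])\n",
"\n",
"\n",
"demo_payload = demo_serialize(DemoCustom(4))\n",
"try:\n",
"    demo_deserialize(demo_payload)\n",
"    demo_lookup_failed = False\n",
"except ValueError:\n",
"    demo_lookup_failed = True\n",
"\n",
"demo_loaded = demo_deserialize(\n",
"    demo_payload, custom_objects={\"DemoCustom\": DemoCustom}\n",
")"
]
},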
{
"cell_type": "markdown",
"metadata": {
"id": "13c7f2a1be03"
},
"source": [
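"### In-memory model cloning\n",
"\n",
"You can also clone a model in memory via `tf.keras.models.clone_model()`. This is equivalent to getting the config and then recreating the model from it (so it does not preserve compilation information or layer weight values).\n",
"\n",
"**Example:**"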
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "93056ffe6eb4",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"with keras.utils.custom_object_scope(custom_objects):\n",
" new_model = keras.models.clone_model(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "05c91a5a23e3"
},
"source": [
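"## Saving and loading only the model's weight values\n",
"\n",
"You can choose to save and load only a model's weights. This can be useful if:\n",
"\n",
"- You only need the model for inference: in this case you won't need to restart training, so you don't need the compilation information or optimizer state.\n",
"- You are doing transfer learning: in this case you will be training a new model that reuses the state of a prior model, so you don't need the compilation information of the prior model."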
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c5229f4014f2"
},
"source": [
"### APIs for in-memory weight transfer\n",
"\n",
"Weights can be copied between different objects by using `get_weights` and `set_weights`:\n",
"\n",
"- `tf.keras.layers.Layer.get_weights()`: Returns a list of NumPy arrays.\n",
"- `tf.keras.layers.Layer.set_weights()`: Sets the weights to the values in the `weights` argument.\n",
"\n",
"Examples below.\n",
"\n",
"***Transferring weights from one layer to another, in memory***"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c9124df19cb2",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def create_layer():\n",
" layer = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")\n",
" layer.build((None, 784))\n",
" return layer\n",
"\n",
"\n",
"layer_1 = create_layer()\n",
"layer_2 = create_layer()\n",
"\n",
"# Copy weights from layer 1 to layer 2\n",
"layer_2.set_weights(layer_1.get_weights())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ff7945516c7d"
},
"source": [
"***Transferring weights from one model to another model with a compatible architecture, in memory***"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "11005d4023d4",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Create a simple functional model\n",
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
"functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n",
"\n",
"# Define a subclassed model with the same architecture\n",
"class SubclassedModel(keras.Model):\n",
"    def __init__(self, output_dim, name=None):\n",
"        super(SubclassedModel, self).__init__(name=name)\n",
"        self.output_dim = output_dim\n",
"        self.dense_1 = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")\n",
"        self.dense_2 = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")\n",
"        self.dense_3 = keras.layers.Dense(output_dim, name=\"predictions\")\n",
"\n",
"    def call(self, inputs):\n",
"        x = self.dense_1(inputs)\n",
"        x = self.dense_2(x)\n",
"        x = self.dense_3(x)\n",
"        return x\n",
"\n",
"    def get_config(self):\n",
"        return {\"output_dim\": self.output_dim, \"name\": self.name}\n",
"\n",
"\n",
"subclassed_model = SubclassedModel(10)\n",
"# Call the subclassed model once to create the weights.\n",
"subclassed_model(tf.ones((1, 784)))\n",
"\n",
"# Copy weights from functional_model to subclassed_model.\n",
"subclassed_model.set_weights(functional_model.get_weights())\n",
"\n",
"assert len(functional_model.weights) == len(subclassed_model.weights)\n",
"for a, b in zip(functional_model.weights, subclassed_model.weights):\n",
" np.testing.assert_allclose(a.numpy(), b.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bd4d08bff725"
},
"source": [
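"***The case of stateless layers***\n",
"\n",
"Because stateless layers do not change the order or number of weights, models can have compatible architectures even if there are extra/missing stateless layers."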
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "927dc7934d44",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
"functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n",
"\n",
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"\n",
"# Add a dropout layer, which does not contain any weights.\n",
"x = keras.layers.Dropout(0.5)(x)\n",
"outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
"functional_model_with_dropout = keras.Model(\n",
" inputs=inputs, outputs=outputs, name=\"3_layer_mlp\"\n",
")\n",
"\n",
"functional_model_with_dropout.set_weights(functional_model.get_weights())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "199e984872d3"
},
"source": [
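"### APIs for saving weights to disk and loading them back\n",
"\n",
"Weights can be saved to disk by calling `model.save_weights` in the following formats:\n",
"\n",
"- TensorFlow Checkpoint\n",
"- HDF5\n",
"\n",
"The default format for `model.save_weights` is TensorFlow Checkpoint. There are two ways to specify the save format:\n",
"\n",
"1. `save_format` argument: Set the value to `save_format=\"tf\"` or `save_format=\"h5\"`.\n",
"2. `path` argument: If the path ends with `.h5` or `.hdf5`, then the HDF5 format is used. Other suffixes result in a TensorFlow Checkpoint unless `save_format` is set.\n",
"\n",
"There is also an option of retrieving weights as in-memory NumPy arrays. Each API has its pros and cons, which are detailed below."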
]
},
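{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two selection rules above can be summarized as a small decision function. This is only a sketch of the rules as stated in this guide: `pick_weights_format` is a hypothetical helper, not a Keras function, and its handling of an HDF5-looking path combined with `save_format=\"tf\"` (raising an error) is an assumption rather than something the text above specifies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def pick_weights_format(path, save_format=None):\n",
"    \"\"\"Hypothetical helper; returns 'tf' or 'h5' for a save_weights call.\"\"\"\n",
"    looks_like_hdf5 = path.endswith((\".h5\", \".hdf5\"))\n",
"    if save_format is None:\n",
"        # Rule 2: the path suffix decides; anything else is a TF checkpoint.\n",
"        return \"h5\" if looks_like_hdf5 else \"tf\"\n",
"    if save_format not in (\"tf\", \"h5\"):\n",
"        raise ValueError(\"save_format must be 'tf' or 'h5'\")\n",
"    if save_format == \"tf\" and looks_like_hdf5:\n",
"        # Assumption: treat the conflicting combination as an error.\n",
"        raise ValueError(\"HDF5-looking path combined with save_format='tf'\")\n",
"    # Rule 1: an explicit save_format is used as given.\n",
"    return save_format\n",
"\n",
"\n",
"demo_formats = {\n",
"    \"ckpt\": pick_weights_format(\"ckpt\"),\n",
"    \"weights.h5\": pick_weights_format(\"weights.h5\"),\n",
"    \"weights.hdf5\": pick_weights_format(\"weights.hdf5\"),\n",
"    \"weights, forced h5\": pick_weights_format(\"weights\", save_format=\"h5\"),\n",
"}\n",
"print(demo_formats)"
]
},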
{
"cell_type": "markdown",
"metadata": {
"id": "3505dc65d6c1"
},
"source": [
"### TF Checkpoint format\n",
"\n",
"**Example:**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f92053377391",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Runnable example\n",
"sequential_model = keras.Sequential(\n",
" [\n",
" keras.Input(shape=(784,), name=\"digits\"),\n",
" keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\"),\n",
" keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\"),\n",
" keras.layers.Dense(10, name=\"predictions\"),\n",
" ]\n",
")\n",
"sequential_model.save_weights(\"ckpt\")\n",
"load_status = sequential_model.load_weights(\"ckpt\")\n",
"\n",
"# `assert_consumed` can be used as validation that all variable values have been\n",
"# restored from the checkpoint. See `tf.train.Checkpoint.restore` for other\n",
"# methods in the Status object.\n",
"load_status.assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "87f1145ac846"
},
"source": [
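"#### Format details\n",
"\n",
"The TensorFlow Checkpoint format saves and restores the weights using object attribute names. For instance, consider the `tf.keras.layers.Dense` layer. The layer contains two weights: `dense.kernel` and `dense.bias`. When the layer is saved to the `tf` format, the resulting checkpoint contains the keys `\"kernel\"` and `\"bias\"` and their corresponding weight values. For more information, see [\"Loading mechanics\" in the TF Checkpoint guide](https://tensorflow.google.cn/guide/checkpoint#loading_mechanics).\n",
"\n",
"Note that an attribute/graph edge is named after **the name used in the parent object, not the name of the variable**. Consider the `CustomLayer` in the example below. The variable `CustomLayer.var` is saved with `\"var\"` as part of the key, not `\"var_a\"`."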
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c919189b3697",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
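"class CustomLayer(keras.layers.Layer):\n",
"    def __init__(self, a):\n",
"        # Initialize the base Layer so that `var` is tracked as an attribute.\n",
"        super().__init__()\n",
"        self.var = tf.Variable(a, name=\"var_a\")\n",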
"\n",
"\n",
"layer = CustomLayer(5)\n",
"layer_ckpt = tf.train.Checkpoint(layer=layer).save(\"custom_layer\")\n",
"\n",
"ckpt_reader = tf.train.load_checkpoint(layer_ckpt)\n",
"\n",
"ckpt_reader.get_variable_to_dtype_map()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4e5a7162b13"
},
"source": [
"#### Transfer learning example\n",
"\n",
"Essentially, as long as two models have the same architecture, they are able to share the same checkpoint.\n",
"\n",
"**Example:**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "78d08199d27f",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
"functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n",
"\n",
"# Extract a portion of the functional model defined in the Setup section.\n",
"# The following lines produce a new model that excludes the final output\n",
"# layer of the functional model.\n",
"pretrained = keras.Model(\n",
" functional_model.inputs, functional_model.layers[-1].input, name=\"pretrained_model\"\n",
")\n",
"# Randomly assign \"trained\" weights.\n",
"for w in pretrained.weights:\n",
" w.assign(tf.random.normal(w.shape))\n",
"pretrained.save_weights(\"pretrained_ckpt\")\n",
"pretrained.summary()\n",
"\n",
"# Assume this is a separate program where only 'pretrained_ckpt' exists.\n",
"# Create a new functional model with a different output dimension.\n",
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"outputs = keras.layers.Dense(5, name=\"predictions\")(x)\n",
"model = keras.Model(inputs=inputs, outputs=outputs, name=\"new_model\")\n",
"\n",
"# Load the weights from pretrained_ckpt into model.\n",
"model.load_weights(\"pretrained_ckpt\")\n",
"\n",
"# Check that all of the pretrained weights have been loaded.\n",
"for a, b in zip(pretrained.weights, model.weights):\n",
" np.testing.assert_allclose(a.numpy(), b.numpy())\n",
"\n",
"print(\"\\n\", \"-\" * 50)\n",
"model.summary()\n",
"\n",
"# Example 2: Sequential model\n",
"# Recreate the pretrained model, and load the saved weights.\n",
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"pretrained_model = keras.Model(inputs=inputs, outputs=x, name=\"pretrained\")\n",
"\n",
"# Sequential example:\n",
"model = keras.Sequential([pretrained_model, keras.layers.Dense(5, name=\"predictions\")])\n",
"model.summary()\n",
"\n",
"pretrained_model.load_weights(\"pretrained_ckpt\")\n",
"\n",
"# Warning! Calling `model.load_weights('pretrained_ckpt')` won't throw an error,\n",
"# but will *not* work as expected. If you inspect the weights, you'll see that\n",
"# none of the weights will have loaded. `pretrained_model.load_weights()` is the\n",
"# correct method to call."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7b07ad5fe5b0"
},
"source": [
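"It is generally recommended to stick to the same API for building models. If you switch between Sequential and Functional, or Functional and subclassed, etc., then always rebuild the pre-trained model and load the pre-trained weights into that model."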
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2ab83c542e2d"
},
"source": [
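"The next question is: how can weights be saved and loaded into different models if the model architectures are quite different? The solution is to use `tf.train.Checkpoint` to save and restore the exact layers/variables.\n",
"\n",
"**Example:**"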
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "97037b9ea265",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Create a subclassed model that essentially uses functional_model's first\n",
"# and last layers.\n",
"# First, save the weights of functional_model's first and last dense layers.\n",
"first_dense = functional_model.layers[1]\n",
"last_dense = functional_model.layers[-1]\n",
"ckpt_path = tf.train.Checkpoint(\n",
" dense=first_dense, kernel=last_dense.kernel, bias=last_dense.bias\n",
").save(\"ckpt\")\n",
"\n",
"# Define the subclassed model.\n",
"class ContrivedModel(keras.Model):\n",
"    def __init__(self):\n",
"        super(ContrivedModel, self).__init__()\n",
"        self.first_dense = keras.layers.Dense(64)\n",
"        # `add_weight` replaces the deprecated `add_variable` alias.\n",
"        self.kernel = self.add_weight(name=\"kernel\", shape=(64, 10))\n",
"        self.bias = self.add_weight(name=\"bias\", shape=(10,))\n",
"\n",
"    def call(self, inputs):\n",
"        x = self.first_dense(inputs)\n",
"        return tf.matmul(x, self.kernel) + self.bias\n",
"\n",
"\n",
"model = ContrivedModel()\n",
"# Call model on inputs to create the variables of the dense layer.\n",
"_ = model(tf.ones((1, 784)))\n",
"\n",
"# Create a Checkpoint with the same structure as before, and load the weights.\n",
"tf.train.Checkpoint(\n",
" dense=model.first_dense, kernel=model.kernel, bias=model.bias\n",
").restore(ckpt_path).assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "18356461e7dd"
},
"source": [
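"### HDF5 format\n",
"\n",
"The HDF5 format contains weights grouped by layer names. The weights are lists ordered by concatenating the list of trainable weights to the list of non-trainable weights (same as `layer.weights`). Thus, a model can use an HDF5 checkpoint if it has the same layers and trainable statuses as those saved in the checkpoint.\n",
"\n",
"**Example:**"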
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "43aec1e07913",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Runnable example\n",
"sequential_model = keras.Sequential(\n",
" [\n",
" keras.Input(shape=(784,), name=\"digits\"),\n",
" keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\"),\n",
" keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\"),\n",
" keras.layers.Dense(10, name=\"predictions\"),\n",
" ]\n",
")\n",
"sequential_model.save_weights(\"weights.h5\")\n",
"sequential_model.load_weights(\"weights.h5\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dc63aef6e0d3"
},
"source": [
"Note that changing `layer.trainable` may result in a different `layer.weights` ordering when the model contains nested layers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "83b70826944a",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"class NestedDenseLayer(keras.layers.Layer):\n",
"    def __init__(self, units, name=None):\n",
"        super(NestedDenseLayer, self).__init__(name=name)\n",
"        self.dense_1 = keras.layers.Dense(units, name=\"dense_1\")\n",
"        self.dense_2 = keras.layers.Dense(units, name=\"dense_2\")\n",
"\n",
"    def call(self, inputs):\n",
"        return self.dense_2(self.dense_1(inputs))\n",
"\n",
"\n",
"nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, \"nested\")])\n",
"variable_names = [v.name for v in nested_model.weights]\n",
"print(\"variables: {}\".format(variable_names))\n",
"\n",
"print(\"\\nChanging trainable status of one of the nested layers...\")\n",
"nested_model.get_layer(\"nested\").dense_1.trainable = False\n",
"\n",
"variable_names_2 = [v.name for v in nested_model.weights]\n",
"print(\"\\nvariables: {}\".format(variable_names_2))\n",
"print(\"variable ordering changed:\", variable_names != variable_names_2)"
]
},
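{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reordering shown above follows from the rule stated for the HDF5 format: trainable weights come first, followed by non-trainable weights. The plain-Python sketch below uses a hypothetical helper, `demo_ordered_weights`, to mimic that concatenation for two sub-layers that each own a kernel and a bias. It only illustrates the ordering rule and is not how Keras computes `layer.weights` internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def demo_ordered_weights(sub_layers):\n",
"    \"\"\"sub_layers: list of (name, trainable) pairs; hypothetical helper.\"\"\"\n",
"    trainable, non_trainable = [], []\n",
"    for name, is_trainable in sub_layers:\n",
"        bucket = trainable if is_trainable else non_trainable\n",
"        bucket.extend([name + \"/kernel\", name + \"/bias\"])\n",
"    # Trainable weights first, then non-trainable ones.\n",
"    return trainable + non_trainable\n",
"\n",
"\n",
"demo_before = demo_ordered_weights([(\"dense_1\", True), (\"dense_2\", True)])\n",
"demo_after = demo_ordered_weights([(\"dense_1\", False), (\"dense_2\", True)])\n",
"print(demo_before)\n",
"print(demo_after)"
]
},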
{
"cell_type": "markdown",
"metadata": {
"id": "cc261c1a31ee"
},
"source": [
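"#### Transfer learning example\n",
"\n",
"When loading pretrained weights from HDF5, it is recommended to load the weights into the original checkpointed model, and then extract the desired weights/layers into a new model.\n",
"\n",
"**Example:**"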
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "06cabc31494a",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def create_functional_model():\n",
" inputs = keras.Input(shape=(784,), name=\"digits\")\n",
" x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
" x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
" outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
" return keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n",
"\n",
"\n",
"functional_model = create_functional_model()\n",
"functional_model.save_weights(\"pretrained_weights.h5\")\n",
"\n",
"# In a separate program:\n",
"pretrained_model = create_functional_model()\n",
"pretrained_model.load_weights(\"pretrained_weights.h5\")\n",
"\n",
"# Create a new model by extracting layers from the original model:\n",
"extracted_layers = pretrained_model.layers[:-1]\n",
"extracted_layers.append(keras.layers.Dense(5, name=\"dense_3\"))\n",
"model = keras.Sequential(extracted_layers)\n",
"model.summary()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "save_and_serialize.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}