{ "cells": [ { "cell_type": "markdown", "metadata": { "cellView": "form", "id": "3jTMb1dySr3V" }, "source": [ "````{admonition} Copyright 2020 The TensorFlow Authors.\n", "```\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "```\n", "````" ] }, { "cell_type": "markdown", "metadata": { "id": "6DWfyNThSziV" }, "source": [ "# Introduction to modules, layers, and models\n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from set_env import temp_dir" ] }, { "cell_type": "markdown", "metadata": { "id": "v0DdlfacAdTZ" }, "source": [ "To do machine learning in TensorFlow, you are likely to need to define, save, and restore a model.\n", "\n", "Abstractly, a model is:\n", "\n", "- A function that computes something on tensors (a **forward pass**)\n", "- Some variables that can be updated in response to training\n", "\n", "In this guide, you will look beneath the surface of Keras to see how TensorFlow models are defined. It covers how TensorFlow collects variables and models, as well as how they are saved and restored.\n", "\n", "Note: If you instead want to get started with Keras right away, see the [collection of Keras guides](./keras/).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "VSa6ayJmfZxZ" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "goZwOXp_xyQj" }, "outputs": [], "source": [ "import tensorflow as tf\n", "from datetime import datetime\n", "\n", "%load_ext tensorboard" ] }, { "cell_type": "markdown", "metadata": { "id": "yt5HEbsYAbw1" }, "source": [ "## TensorFlow modules\n", "\n", "Most models are made of layers. Layers are functions with a known mathematical structure that can be reused and have trainable variables. In TensorFlow, most high-level implementations of layers and models, such as Keras or [Sonnet](https://github.com/deepmind/sonnet), are built on the same foundational class: `tf.Module`.\n", "\n", "### Building modules\n", "\n", "Here is an example of a very simple `tf.Module` that operates on a scalar tensor:\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "alhYPVEtAiSy" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class SimpleModule(tf.Module):\n", "  def __init__(self, name=None):\n", "    super().__init__(name=name)\n", "    self.a_variable = tf.Variable(5.0, name=\"train_me\")\n", "    self.non_trainable_variable = tf.Variable(5.0, trainable=False, name=\"do_not_train_me\")\n", "  def __call__(self, x):\n", "    return self.a_variable * x + self.non_trainable_variable\n", "\n", "simple_module = SimpleModule(name=\"simple\")\n", "\n", "simple_module(tf.constant(5.0))" ] }, { "cell_type": "markdown", "metadata": { "id": "JwMc_zu5Ant8" }, "source": [ "Modules and, by extension, layers are deep-learning terminology for \"objects\": they have internal state, and methods that use that state.\n", "\n", "There is nothing special about `__call__` except that it makes the module behave like a [Python callable](https://stackoverflow.com/questions/111234/what-is-a-callable); you can invoke your models with whatever functions you wish.\n", "\n", "You can turn the trainability of variables on and off for any reason, including freezing layers and variables during fine-tuning.\n", "\n", "Note: `tf.Module` is the base class for both `tf.keras.layers.Layer` and `tf.keras.Model`, so everything you come across here also applies in Keras. For historical compatibility reasons, Keras layers do not collect variables from modules, so your models should use only modules or only Keras layers. However, the methods shown below for inspecting variables are the same in either case.\n", "\n", "When you subclass `tf.Module`, any `tf.Variable` or `tf.Module` instances assigned to this object's attributes are collected automatically. This lets you save and load variables, and also create collections of `tf.Module`s." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "CyzYy4A_CbVf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable variables: (,)\n", "all variables: (, )\n" ] } ], "source": [ "# All trainable variables\n", "print(\"trainable variables:\", simple_module.trainable_variables)\n", "# Every variable\n", "print(\"all variables:\", simple_module.variables)" ] }, { "cell_type": "markdown", "metadata": { "id": "nuSFrRUNCaaW" }, "source": [ "Here is an example of a two-layer linear model made out of modules.\n", "\n", "First, a dense (linear) layer:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "Efb2p2bzAn-V" }, "outputs": [], "source": [ "class Dense(tf.Module):\n", "  def __init__(self, in_features, out_features, name=None):\n", "    super().__init__(name=name)\n", "    self.w = tf.Variable(\n", "      tf.random.normal([in_features, out_features]), name='w')\n", "    self.b = tf.Variable(tf.zeros([out_features]), name='b')\n", "  def __call__(self, x):\n", "    y = tf.matmul(x, self.w) + self.b\n", "    return tf.nn.relu(y)" ] }, { "cell_type": "markdown", "metadata": { "id": "bAhMuC-UpnhX" }, "source": [ "And then the complete model, which creates two layer instances and applies them:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "QQ7qQf-DFw74" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model results: tf.Tensor([[0. 
0.]], shape=(1, 2), dtype=float32)\n" ] } ], "source": [ "class SequentialModule(tf.Module):\n", "  def __init__(self, name=None):\n", "    super().__init__(name=name)\n", "\n", "    self.dense_1 = Dense(in_features=3, out_features=3)\n", "    self.dense_2 = Dense(in_features=3, out_features=2)\n", "\n", "  def __call__(self, x):\n", "    x = self.dense_1(x)\n", "    return self.dense_2(x)\n", "\n", "# You have made a model!\n", "my_model = SequentialModule(name=\"the_model\")\n", "\n", "# Call it, with random results\n", "print(\"Model results:\", my_model(tf.constant([[2.0, 2.0, 2.0]])))" ] }, { "cell_type": "markdown", "metadata": { "id": "d1oUzasJHHXf" }, "source": [ "`tf.Module` instances automatically and recursively collect any `tf.Variable` or `tf.Module` instances assigned to them. This lets you manage collections of `tf.Module`s with a single model instance, and save and load whole models." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "JLFA5_PEGb6C" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submodules: (<__main__.Dense object at 0x7f5aac17bd40>, <__main__.Dense object at 0x7f5aad17e1e0>)\n" ] } ], "source": [ "print(\"Submodules:\", my_model.submodules)\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "6lzoB8pcRN12" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n" ] } ], "source": [ "for var in my_model.variables:\n", "  print(var, \"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hoaxL3zzm0vK" }, "source": [ "### Waiting to create variables\n", "\n", "You may have noticed that you have to define both the input and output sizes of the layer. This is so the `w` variable has a known shape and can be allocated.\n", "\n", "By deferring variable creation until the first time the module is called with a specific input shape, you no longer need to specify the input size up front." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "XsGCLFXlnPum" }, "outputs": [], "source": [ "class FlexibleDenseModule(tf.Module):\n", "  # Note: No need for `in_features`\n", "  def __init__(self, out_features, name=None):\n", "    super().__init__(name=name)\n", "    self.is_built = False\n", "    self.out_features = out_features\n", "\n", "  def __call__(self, x):\n", "    # Create variables on first call.\n", "    if not self.is_built:\n", "      self.w = tf.Variable(\n", "        tf.random.normal([x.shape[-1], self.out_features]), name='w')\n", "      self.b = tf.Variable(tf.zeros([self.out_features]), name='b')\n", "      self.is_built = True\n", "\n", "    y = tf.matmul(x, self.w) + self.b\n", "    return tf.nn.relu(y)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "8bjOWax9LOkP" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model results: tf.Tensor([[0. 0.]], shape=(1, 2), dtype=float32)\n" ] } ], "source": [ "# Used in a module\n", "class MySequentialModule(tf.Module):\n", "  def __init__(self, name=None):\n", "    super().__init__(name=name)\n", "\n", "    self.dense_1 = FlexibleDenseModule(out_features=3)\n", "    self.dense_2 = FlexibleDenseModule(out_features=2)\n", "\n", "  def __call__(self, x):\n", "    x = self.dense_1(x)\n", "    return self.dense_2(x)\n", "\n", "my_model = MySequentialModule(name=\"the_model\")\n", "print(\"Model results:\", my_model(tf.constant([[2.0, 2.0, 2.0]])))" ] }, { "cell_type": "markdown", "metadata": { "id": "49JfbhVrpOLH" }, "source": [ "This flexibility is why TensorFlow layers often only need you to specify the shape of their outputs, as in `tf.keras.layers.Dense`, rather than both the input and output sizes." ] }, { "cell_type": "markdown", "metadata": { "id": "JOLVVBT8J_dl" }, "source": [ "### Saving weights\n", "\n", "You can save a `tf.Module` as both a [checkpoint](./checkpoint.ipynb) and a [SavedModel](./saved_model.ipynb).\n", "\n", "Checkpoints are just the weights (that is, the values of the set of variables inside the module and its submodules)." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "pHXKRDk7OLHA" }, "outputs": [ { "data": { "text/plain": [ "'.temp/my_checkpoint'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chkp_path = temp_dir/\"my_checkpoint\"\n", "checkpoint = tf.train.Checkpoint(model=my_model)\n", "checkpoint.write(chkp_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "WXOPMBR4T4ZR" }, "source": [ "Checkpoints consist of two kinds of files: the data itself and an index file for metadata. The index file keeps track of what is actually saved and the numbering of checkpoints, while the checkpoint data contains the variable values and their attribute lookup paths." ] }, { "cell_type": "code", "execution_count": 12, 
"metadata": { "id": "jBV3fprlTWqJ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ".temp/my_checkpoint.data-00000-of-00001 .temp/my_checkpoint.index\n" ] } ], "source": [ "!ls {temp_dir}/my_checkpoint*" ] }, { "cell_type": "markdown", "metadata": { "id": "CowCuBTvXgUu" }, "source": [ "You can look inside a checkpoint to confirm that the whole collection of variables is saved, sorted by the Python object that contains them." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "o2QAdfpvS8tB" }, "outputs": [ { "data": { "text/plain": [ "[('_CHECKPOINTABLE_OBJECT_GRAPH', []),\n", " ('model/dense_1/b/.ATTRIBUTES/VARIABLE_VALUE', [3]),\n", " ('model/dense_1/w/.ATTRIBUTES/VARIABLE_VALUE', [3, 3]),\n", " ('model/dense_2/b/.ATTRIBUTES/VARIABLE_VALUE', [2]),\n", " ('model/dense_2/w/.ATTRIBUTES/VARIABLE_VALUE', [3, 2])]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.train.list_variables(chkp_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "4eGaNiQWcK4j" }, "source": [ "During distributed (multi-machine) training, checkpoints can be sharded, which is why they are numbered (for example, '00000-of-00001'). In this case, though, there is only one shard.\n", "\n", "When you load a model back in, you overwrite the values in your Python object." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "UV8rdDzcwVVg" }, "outputs": [], "source": [ "new_model = MySequentialModule()\n", "new_checkpoint = tf.train.Checkpoint(model=new_model)\n", "# Restore from the same path the checkpoint was written to above\n", "new_checkpoint.restore(str(chkp_path))\n", "\n", "# Should be the same result as above\n", "new_model(tf.constant([[2.0, 2.0, 2.0]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "BnPwDRwamdfq" }, "source": [ "Note: Because checkpoints are at the heart of long training workflows, `tf.train.CheckpointManager` is a helper class that makes checkpoint management much easier. Refer to the [Training checkpoints guide](./checkpoint.ipynb) for more details." ] }, { "cell_type": "markdown", "metadata": { "id": "pSZebVuWxDXu" }, "source": [ "### Saving functions\n", "\n", "TensorFlow can run models without the original Python objects, as demonstrated by [TensorFlow Serving](https://tensorflow.org/tfx) and [TensorFlow Lite](https://tensorflow.org/lite), and also when you download a trained model from [TensorFlow Hub](https://tensorflow.org/hub).\n", "\n", "TensorFlow needs to know how to perform the computations described in Python, but **without the original code**. To do this, you can make a **graph**, as described in the [Introduction to graphs and functions guide](./intro_to_graphs.ipynb).\n", "\n", "This graph contains the operations, or *ops*, that implement the function.\n", "\n", "You can define a graph in the model above by adding the `@tf.function` decorator to indicate that this code should run as a graph." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "WQTvkapUh7lk" }, "outputs": [], "source": [ "class MySequentialModule(tf.Module):\n", "  def __init__(self, name=None):\n", "    super().__init__(name=name)\n", "\n", "    self.dense_1 = Dense(in_features=3, out_features=3)\n", "    self.dense_2 = Dense(in_features=3, out_features=2)\n", "\n", "  @tf.function\n", "  def __call__(self, x):\n", "    x = self.dense_1(x)\n", "    return self.dense_2(x)\n", "\n", "# You have made a model with a graph!\n", "my_model = MySequentialModule(name=\"the_model\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hW66YXBziLo9" }, "source": [ "The module you have made works exactly the same as before. Each unique signature passed into the function creates a separate graph. See the [Introduction to graphs and functions guide](./intro_to_graphs.ipynb) for details." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "H5zUfti3iR52" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([[8.677328 0. ]], shape=(1, 2), dtype=float32)\n", "tf.Tensor(\n", "[[[8.677328 0. ]\n", " [8.677328 0. 
]]], shape=(1, 2, 2), dtype=float32)\n" ] } ], "source": [ "print(my_model([[2.0, 2.0, 2.0]]))\n", "print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "lbGlU1kgyDo7" }, "source": [ "You can visualize the graph by tracing it within a TensorBoard summary." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "zmy-T67zhp-S" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([[0.6444056 0. ]], shape=(1, 2), dtype=float32)\n" ] } ], "source": [ "# Set up logging.\n", "stamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", "logdir = f\"{temp_dir}/logs/func/{stamp}\"\n", "writer = tf.summary.create_file_writer(logdir)\n", "\n", "# Create a new model to get a fresh trace\n", "# Otherwise the summary will not see the graph.\n", "new_model = MySequentialModule()\n", "\n", "# Bracket the function call with\n", "# tf.summary.trace_on() and tf.summary.trace_export().\n", "tf.summary.trace_on(graph=True)\n", "tf.profiler.experimental.start(logdir)\n", "# Call only one tf.function when tracing.\n", "z = new_model(tf.constant([[2.0, 2.0, 2.0]]))\n", "print(z)\n", "with writer.as_default():\n", "  tf.summary.trace_export(\n", "      name=\"my_func_trace\",\n", "      step=0,\n", "      profiler_outdir=logdir)" ] }, { "cell_type": "markdown", "metadata": { "id": "gz4lwNZ9hR79" }, "source": [ "Launch TensorBoard to view the resulting trace:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "V4MXDbgBnkJu" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#docs_infra: no_execute\n", "%tensorboard --logdir {temp_dir}/logs/func" ] }, { "cell_type": "markdown", "metadata": { "id": "Gjattu0AhYUl" }, "source": [ "![A screenshot of the graph, in tensorboard](https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/images/tensorboard_graph.png?raw=true)" ] }, { "cell_type": "markdown", "metadata": { "id": "SQu3TVZecmL7" }, 
"source": [ "### Creating a `SavedModel`\n", "\n", "The recommended way to share completely trained models is to use `SavedModel`. A `SavedModel` contains both a collection of functions and a collection of weights.\n", "\n", "You can save the model you have just built as follows:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "Awv_Tw__WK7a" }, "outputs": [], "source": [ "tf.saved_model.save(my_model, temp_dir/\"the_saved_model\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "SXv3mEKsefGj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total 24\n", "drwxr-xr-x 2 ai ai 10 10月 25 10:03 assets\n", "-rw-rw-r-- 1 ai ai 56 10月 25 10:03 fingerprint.pb\n", "-rw-rw-r-- 1 ai ai 17643 10月 25 10:03 saved_model.pb\n", "drwxr-xr-x 2 ai ai 78 10月 25 10:03 variables\n" ] } ], "source": [ "# Inspect the SavedModel in the directory\n", "!ls -l {temp_dir}/the_saved_model" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "vQQ3hEvHYdoR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total 8\n", "-rw-rw-r-- 1 ai ai 490 10月 25 10:03 variables.data-00000-of-00001\n", "-rw-rw-r-- 1 ai ai 356 10月 25 10:03 variables.index\n" ] } ], "source": [ "# The variables/ directory contains a checkpoint of the variables\n", "!ls -l {temp_dir}/the_saved_model/variables" ] }, { "cell_type": "markdown", "metadata": { "id": "xBqPop7ZesBU" }, "source": [ "The `saved_model.pb` file is a [protocol buffer](https://developers.google.com/protocol-buffers) describing the functional `tf.Graph`.\n", "\n", "Models and layers can be loaded from this representation without actually making an instance of the class that created it. This is useful when you do not have (or do not want) a Python interpreter, such as when serving at scale or on an edge device, or when the original Python code is not available or practical to use.\n", "\n", "You can load the model as a new object:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "zRFcA5wIefv4" }, "outputs": [], "source": [ "new_model = tf.saved_model.load(temp_dir/\"the_saved_model\")" ] }, { "cell_type": "markdown", "metadata": { "id": "-9EF3mT7i3qN" }, "source": [ "`new_model`, created by loading a saved model, is an internal TensorFlow user object that carries none of the original class knowledge. It is not of type `SequentialModule`." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "EC_eQj7yi54G" }, "outputs": [ 
{ "data": { "text/plain": [ "False" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "isinstance(new_model, SequentialModule)" ] }, { "cell_type": "markdown", "metadata": { "id": "-OrOX1zxiyhR" }, "source": [ "This new model works on the input signatures that have already been defined. You cannot add more signatures to a model restored like this." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "_23BYYBWfKnc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([[8.677328 0. ]], shape=(1, 2), dtype=float32)\n", "tf.Tensor(\n", "[[[8.677328 0. ]\n", " [8.677328 0. ]]], shape=(1, 2, 2), dtype=float32)\n" ] } ], "source": [ "print(new_model([[2.0, 2.0, 2.0]]))\n", "print(new_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "qSFhoMtTjSR6" }, "source": [ "Thus, using `SavedModel`, you can save TensorFlow weights and graphs using `tf.Module`, and then load them again." ] }, { "cell_type": "markdown", "metadata": { "id": "Rb9IdN7hlUZK" }, "source": [ "## Keras models and layers\n", "\n", "Note that up until this point, there has been no mention of Keras. You can build your own high-level API on top of `tf.Module`, and people have.\n", "\n", "In this section, you will examine how Keras uses `tf.Module`. A complete user guide to Keras models can be found in the [Keras guide](https://tensorflow.google.cn/guide/keras/sequential_model).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ds08u3touwe4t" }, "source": [ "Keras layers and models have many extra features, including:\n", "\n", "- Optional losses\n", "- Support for metrics\n", "- Built-in support for an optional `training` argument to differentiate between training and inference use\n", "- Saving and restoring Python objects instead of just black-box functions\n", "- `get_config` and `from_config` methods that let you store configurations accurately so a model can be cloned in Python\n", "\n", "These features allow for far more complex models through subclassing, such as a custom GAN or a variational autoencoder (VAE) model. Read about them in the [full guide](./keras/custom_layers_and_models.ipynb) to custom layers and models.\n", "\n", "Keras models also come with extra functionality that makes them easy to train, evaluate, load, save, and even train on multiple machines." ] }, { "cell_type": "markdown", "metadata": { "id": "uigsVGPreE-D" }, "source": [ "### Keras layers\n", "\n", "`tf.keras.layers.Layer` is the base class of all Keras layers, and it inherits from `tf.Module`.\n", "\n", "You can convert a module into a Keras layer just by swapping out the parent and then changing `__call__` to `call`:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "88YOGquhnQRd" }, "outputs": [], "source": [ "class MyDense(tf.keras.layers.Layer):\n", "  # Adding **kwargs to support base Keras layer arguments\n", "  def __init__(self, in_features, out_features, **kwargs):\n", "    super().__init__(**kwargs)\n", "\n", "    # This will soon move to the build step; see below\n", "    self.w = tf.Variable(\n", "      tf.random.normal([in_features, out_features]), name='w')\n", "    self.b = tf.Variable(tf.zeros([out_features]), name='b')\n", "  def call(self, x):\n", "    y = tf.matmul(x, self.w) + self.b\n", "    return tf.nn.relu(y)\n", "\n", "simple_layer = MyDense(name=\"simple\", in_features=3, out_features=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "nYGmAsPrws--" }, "source": [ "Keras layers have their own `__call__` that does some bookkeeping described in the next section and then calls `call()`. You should notice no change in functionality." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "nIqE8wOznYKG" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "simple_layer(tf.constant([[2.0, 2.0, 2.0]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "tmN5vb1K18U1" }, "source": [ "### The `build` step\n", "\n", "As noted above, in many cases it is convenient to wait to create variables until you are sure of the input shape.\n", "\n", "Keras layers come with an extra lifecycle step that gives you more flexibility in how you define your layers. This step is defined in the `build()` function.\n", "\n", "`build` is called exactly once, and it is called with the shape of the input. It is usually used to create variables (weights).\n", "\n", "You can rewrite the `MyDense` layer above to be flexible to the size of its inputs:\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "4YTfrlgdsURp" }, "outputs": [], "source": [ "class FlexibleDense(tf.keras.layers.Layer):\n", "  # Note the added `**kwargs`, as Keras supports many arguments\n", "  def __init__(self, out_features, **kwargs):\n", "    super().__init__(**kwargs)\n", "    self.out_features = out_features\n", "\n", "  def build(self, input_shape):  # Create the state of the layer (weights)\n", "    self.w = tf.Variable(\n", "      tf.random.normal([input_shape[-1], self.out_features]), name='w')\n", "    self.b = tf.Variable(tf.zeros([self.out_features]), name='b')\n", "\n", "  def call(self, inputs):  # Defines the computation from inputs to outputs\n", "    return tf.matmul(inputs, self.w) + self.b\n", "\n", "# Create the instance of the layer\n", "flexible_dense = FlexibleDense(out_features=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "Koc_uSqt2PRh" }, "source": [ "At this point, the model has not been built, so there are no variables:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "DgyTyUD32Ln4" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flexible_dense.variables" ] }, { "cell_type": "markdown", "metadata": { "id": "-KdamIVl2W8Y" }, "source": [ "Calling the layer allocates appropriately sized variables." ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "IkLyEx7uAoTK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model results: tf.Tensor(\n", "[[ 7.2248173 2.6235514 3.2833693]\n", " [10.837226 3.9353273 4.9250536]], shape=(2, 3), dtype=float32)\n" ] } ], "source": [ "# Call it, with predictably random results\n", "print(\"Model results:\", flexible_dense(tf.constant([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])))" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "Swofpkrd2YDd" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flexible_dense.variables" ] }, { "cell_type": "markdown", "metadata": { "id": "7PuNUnf0OIpF" }, "source": [ "Since `build` is only called once, inputs are rejected if their shape is not compatible with the layer's variables." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "caYWDrHSAy_j" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Failed: Exception encountered when calling FlexibleDense.call().\n", "\n", "\u001b[1m{{function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:GPU:0}} Matrix size-incompatible: In[0]: [1,4], In[1]: [3,3] [Op:MatMul] name: \u001b[0m\n", "\n", "Arguments received by FlexibleDense.call():\n", " • inputs=tf.Tensor(shape=(1, 4), dtype=float32)\n" ] } ], "source": [ 
"try:\n", " print(\"Model results:\", flexible_dense(tf.constant([[2.0, 2.0, 2.0, 2.0]])))\n", "except tf.errors.InvalidArgumentError as e:\n", " print(\"Failed:\", e)" ] }, { "cell_type": "markdown", "metadata": { "id": "L2kds2IHw2KD" }, "source": [ "### Keras 模型\n", "\n", "您可以将模型定义为嵌套的 Keras 层。\n", "\n", "不过,Keras 还提供了称为 `tf.keras.Model` 的全功能模型类。它继承自 `tf.keras.layers.Layer`,因此 Keras 模型支持以与 Keras 层相同的方式使用和嵌套。Keras 模型还具有额外的功能,这使它们可以轻松训练、评估、加载、保存,甚至在多台机器上进行训练。\n", "\n", "您可以使用几乎相同的代码定义上面的 `SequentialModule`,再次将 `__call__` 转换为 `call()` 并更改父项。" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "Hqjo1DiyrHrn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model results: tf.Tensor([[ 2.20927 -1.5668545]], shape=(1, 2), dtype=float32)\n" ] } ], "source": [ "class MySequentialModel(tf.keras.Model):\n", " def __init__(self, name=None, **kwargs):\n", " super().__init__(**kwargs)\n", "\n", " self.dense_1 = FlexibleDense(out_features=3)\n", " self.dense_2 = FlexibleDense(out_features=2)\n", " def call(self, x):\n", " x = self.dense_1(x)\n", " return self.dense_2(x)\n", "\n", "# You have made a Keras model!\n", "my_sequential_model = MySequentialModel(name=\"the_model\")\n", "\n", "# Call it on a tensor, with random results\n", "print(\"Model results:\", my_sequential_model(tf.constant([[2.0, 2.0, 2.0]])))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "8i-CR_h2xw3z" }, "source": [ "所有相同的功能都可用,包括跟踪变量和子模块。\n", "\n", "注:嵌套在 Keras 层或模型中的原始 `tf.Module` 将不会收集其变量以用于训练或保存。相反,它会在 Keras 层内嵌套 Keras 层。" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "hdLQFNdMsOz1" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_sequential_model.variables" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "JjVAMrAJsQ7G" }, "outputs": [ { "ename": "AttributeError", "evalue": "'MySequentialModel' object has no 
attribute 'submodules'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[36], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m my_sequential_model\u001b[38;5;241m.\u001b[39msubmodules\n", "\u001b[0;31mAttributeError\u001b[0m: 'MySequentialModel' object has no attribute 'submodules'" ] } ], "source": [ "my_sequential_model.submodules" ] }, { "cell_type": "markdown", "metadata": { "id": "FhP8EItC4oac" }, "source": [ "Overriding `tf.keras.Model` is a very Pythonic approach to building TensorFlow models. If you are migrating models from other frameworks, this can be very straightforward.\n", "\n", "If you are constructing models that are simple assemblages of existing layers and inputs, you can save time and space by using the [functional API](./keras/functional.ipynb), which comes with additional features around model reconstruction and architecture.\n", "\n", "Here is the same model built with the functional API:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "id": "jJiZZiJ0fyqQ" }, "outputs": [ { "data": { "text/html": [ "
Model: \"functional\"\n",
              "
\n" ], "text/plain": [ "\u001b[1mModel: \"functional\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
              "┃ Layer (type)                     Output Shape                  Param # ┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
              "│ input_layer (InputLayer)        │ (None, 3)              │             0 │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ flexible_dense_3                │ (None, 3)              │             0 │\n",
              "│ (FlexibleDense)                 │                        │               │\n",
              "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
              "│ flexible_dense_4                │ (None, 2)              │             0 │\n",
              "│ (FlexibleDense)                 │                        │               │\n",
              "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
              "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ input_layer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ flexible_dense_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "│ (\u001b[38;5;33mFlexibleDense\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ flexible_dense_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "│ (\u001b[38;5;33mFlexibleDense\u001b[0m) │ │ │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
              "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "inputs = tf.keras.Input(shape=[3,])\n", "\n", "x = FlexibleDense(3)(inputs)\n", "x = FlexibleDense(2)(x)\n", "\n", "my_functional_model = tf.keras.Model(inputs=inputs, outputs=x)\n", "\n", "my_functional_model.summary()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "kg-xAZw5gaG6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_functional_model(tf.constant([[2.0, 2.0, 2.0]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "s_BK9XH5q9cq" }, "source": [ "The major difference here is that the input shape is specified up front as part of the functional construction process. The `input_shape` argument in this case does not have to be completely specified; you can leave some dimensions as `None`.\n", "\n", "Note: You do not need to specify `input_shape` or an `InputLayer` in a subclassed model; these arguments and layers will be ignored." ] }, { "cell_type": "markdown", "metadata": { "id": "qI9aXLnaHEFF" }, "source": [ "### Saving Keras models\n", "\n", "Keras models have their own specialized zip archive saving format, marked by the `.keras` extension. When calling `tf.keras.Model.save`, add a `.keras` extension to the filename. For example:" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "id": "SAz-KVZlzAJu" }, "outputs": [], "source": [ "my_sequential_model.save(temp_dir/\"exname_of_file.keras\")" ] }, { "cell_type": "markdown", "metadata": { "id": "C2urAeR-omns" }, "source": [ "They can be loaded back in just as easily. Note that Keras must be able to locate any custom class such as `MySequentialModel`, for example by decorating the class with `@keras.saving.register_keras_serializable()`; otherwise `load_model` raises a `TypeError`:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "id": "Wj5DW-LCopry" }, "outputs": [ { "ename": "TypeError", "evalue": "Could not locate class 'MySequentialModel'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. 
Full object config: {'module': None, 'class_name': 'MySequentialModel', 'config': {'name': 'my_sequential_model', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}}, 'registered_name': 'MySequentialModel'}", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[41], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m reconstructed_model \u001b[38;5;241m=\u001b[39m tf\u001b[38;5;241m.\u001b[39mkeras\u001b[38;5;241m.\u001b[39mmodels\u001b[38;5;241m.\u001b[39mload_model(temp_dir\u001b[38;5;241m/\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexname_of_file.keras\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/saving/saving_api.py:187\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(filepath, custom_objects, compile, safe_mode)\u001b[0m\n\u001b[1;32m 184\u001b[0m is_keras_zip \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_keras_zip \u001b[38;5;129;01mor\u001b[39;00m is_keras_dir:\n\u001b[0;32m--> 187\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m saving_lib\u001b[38;5;241m.\u001b[39mload_model(\n\u001b[1;32m 188\u001b[0m filepath,\n\u001b[1;32m 189\u001b[0m custom_objects\u001b[38;5;241m=\u001b[39mcustom_objects,\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28mcompile\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mcompile\u001b[39m,\n\u001b[1;32m 191\u001b[0m safe_mode\u001b[38;5;241m=\u001b[39msafe_mode,\n\u001b[1;32m 192\u001b[0m )\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mstr\u001b[39m(filepath)\u001b[38;5;241m.\u001b[39mendswith((\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.h5\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.hdf5\u001b[39m\u001b[38;5;124m\"\u001b[39m)):\n\u001b[1;32m 194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m legacy_h5_format\u001b[38;5;241m.\u001b[39mload_model_from_hdf5(\n\u001b[1;32m 195\u001b[0m filepath, custom_objects\u001b[38;5;241m=\u001b[39mcustom_objects, \u001b[38;5;28mcompile\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mcompile\u001b[39m\n\u001b[1;32m 196\u001b[0m )\n", "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:365\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(filepath, custom_objects, compile, safe_mode)\u001b[0m\n\u001b[1;32m 360\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 361\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid filename: expected a `.keras` extension. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 362\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mReceived: filepath=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfilepath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 363\u001b[0m )\n\u001b[1;32m 364\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(filepath, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m--> 365\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _load_model_from_fileobj(\n\u001b[1;32m 366\u001b[0m f, custom_objects, \u001b[38;5;28mcompile\u001b[39m, safe_mode\n\u001b[1;32m 367\u001b[0m )\n", "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:442\u001b[0m, in \u001b[0;36m_load_model_from_fileobj\u001b[0;34m(fileobj, custom_objects, compile, safe_mode)\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m zf\u001b[38;5;241m.\u001b[39mopen(_CONFIG_FILENAME, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 440\u001b[0m config_json \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mread()\n\u001b[0;32m--> 442\u001b[0m model \u001b[38;5;241m=\u001b[39m _model_from_config(\n\u001b[1;32m 443\u001b[0m config_json, custom_objects, \u001b[38;5;28mcompile\u001b[39m, safe_mode\n\u001b[1;32m 444\u001b[0m )\n\u001b[1;32m 446\u001b[0m all_filenames \u001b[38;5;241m=\u001b[39m zf\u001b[38;5;241m.\u001b[39mnamelist()\n\u001b[1;32m 447\u001b[0m extract_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/saving/saving_lib.py:431\u001b[0m, in \u001b[0;36m_model_from_config\u001b[0;34m(config_json, custom_objects, compile, 
safe_mode)\u001b[0m\n\u001b[1;32m 429\u001b[0m \u001b[38;5;66;03m# Construct the model from the configuration file in the archive.\u001b[39;00m\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ObjectSharingScope():\n\u001b[0;32m--> 431\u001b[0m model \u001b[38;5;241m=\u001b[39m deserialize_keras_object(\n\u001b[1;32m 432\u001b[0m config_dict, custom_objects, safe_mode\u001b[38;5;241m=\u001b[39msafe_mode\n\u001b[1;32m 433\u001b[0m )\n\u001b[1;32m 434\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model\n", "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/saving/serialization_lib.py:694\u001b[0m, in \u001b[0;36mdeserialize_keras_object\u001b[0;34m(config, custom_objects, safe_mode, **kwargs)\u001b[0m\n\u001b[1;32m 691\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m obj \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 692\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj\n\u001b[0;32m--> 694\u001b[0m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m _retrieve_class_or_fn(\n\u001b[1;32m 695\u001b[0m class_name,\n\u001b[1;32m 696\u001b[0m registered_name,\n\u001b[1;32m 697\u001b[0m module,\n\u001b[1;32m 698\u001b[0m obj_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 699\u001b[0m full_config\u001b[38;5;241m=\u001b[39mconfig,\n\u001b[1;32m 700\u001b[0m custom_objects\u001b[38;5;241m=\u001b[39mcustom_objects,\n\u001b[1;32m 701\u001b[0m )\n\u001b[1;32m 703\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mcls\u001b[39m, types\u001b[38;5;241m.\u001b[39mFunctionType):\n\u001b[1;32m 704\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\n", "File \u001b[0;32m/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/saving/serialization_lib.py:812\u001b[0m, in 
\u001b[0;36m_retrieve_class_or_fn\u001b[0;34m(name, registered_name, module, obj_type, full_config, custom_objects)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m obj \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj\n\u001b[0;32m--> 812\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 813\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not locate \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mobj_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 814\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMake sure custom classes are decorated with \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 815\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`@keras.saving.register_keras_serializable()`. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 816\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFull object config: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfull_config\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 817\u001b[0m )\n", "\u001b[0;31mTypeError\u001b[0m: Could not locate class 'MySequentialModel'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. 
Full object config: {'module': None, 'class_name': 'MySequentialModel', 'config': {'name': 'my_sequential_model', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}}, 'registered_name': 'MySequentialModel'}" ] } ], "source": [ "reconstructed_model = tf.keras.models.load_model(temp_dir/\"exname_of_file.keras\")" ] }, { "cell_type": "markdown", "metadata": { "id": "EA7P_MNvpviZ" }, "source": [ "Keras zip archives (`.keras` files) also save metric, loss, and optimizer states.\n", "\n", "Once loading succeeds, the reconstructed model can be used and will produce the same result when called on the same data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P_wGfQo5pe6T" }, "outputs": [], "source": [ "reconstructed_model(tf.constant([[2.0, 2.0, 2.0]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "seLIUG2354s" }, "source": [ "### Checkpointing Keras models\n", "\n", "Keras models can also be checkpointed, and that looks the same as it does for `tf.Module`." ] }, { "cell_type": "markdown", "metadata": { "id": "xKyjlkceqjwD" }, "source": [ "For more on saving and serializing Keras models, including providing configuration methods for custom layers so that they support these features, see the [guide to saving and serialization](https://tensorflow.google.cn/guide/keras/save_and_serialize)." ] }, { "cell_type": "markdown", "metadata": { "id": "kcdMMPYv7Krz" }, "source": [ "# What's next\n", "\n", "If you want to know more details about Keras, you can follow the existing Keras guides [here](./keras/).\n", "\n", "Another example of a high-level API built on `tf.Module` is Sonnet from DeepMind, which is covered on [their site](https://github.com/deepmind/sonnet)." ] } ], "metadata": { "colab": { "collapsed_sections": [ "ISubpr_SSsiM" ], "name": "intro_to_modules.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "xxx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }