{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pnn4rDWGqDZL" }, "outputs": [], "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "l534d35Gp68G" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/media/pc/data/lxw/ai/d2py/doc/libs/tf-chaos/guide\n" ] } ], "source": [ "%cd .." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from set_env import temp_dir" ] }, { "cell_type": "markdown", "metadata": { "id": "3TI3Q3XBesaS" }, "source": [ "# 训练检查点" ] }, { "cell_type": "markdown", "metadata": { "id": "yw_a0iGucY8z" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看在 Google Colab 中运行 在 GitHub 上查看源代码下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "LeDp7dovcbus" }, "source": [ "“保存 TensorFlow 模型”这一短语通常表示保存以下两种元素之一:\n", "\n", "1. 检查点,或\n", "2. SavedModel。\n", "\n", "检查点可以捕获模型使用的所有参数(`tf.Variable` 对象)的确切值。检查点不包含对模型所定义计算的任何描述,因此通常仅在将使用保存参数值的源代码可用时才有用。\n", "\n", "另一方面,除了参数值(检查点)之外,SavedModel 格式还包括对模型所定义计算的序列化描述。这种格式的模型独立于创建模型的源代码。因此,它们适合通过 TensorFlow Serving、TensorFlow Lite、TensorFlow.js 或者使用其他编程语言(C、C++、Java、Go、Rust、C# 等 TensorFlow API)编写的程序进行部署。\n", "\n", "本文介绍用于编写和读取检查点的 API。" ] }, { "cell_type": "markdown", "metadata": { "id": "U0nm8k-6xfh2" }, "source": [ "## 设置" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "VEvpMYAKsC4z" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "OEQCseyeC4Ev" }, "outputs": [], "source": [ "class Net(tf.keras.Model):\n", " \"\"\"A simple linear model.\"\"\"\n", "\n", " def __init__(self):\n", " super().__init__()\n", " self.l1 = tf.keras.layers.Dense(5)\n", "\n", " def call(self, x):\n", " return self.l1(x)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "utqeoDADC5ZR" }, "outputs": [], "source": [ "net = Net()" ] }, { "cell_type": "markdown", "metadata": { "id": "5vsq3-pffo1I" }, "source": [ "## 从 `tf.keras` 训练 API 保存\n", "\n", "请参阅 [`tf.keras` 保存和恢复指南](https://tensorflow.google.cn/guide/keras/save_and_serialize)。\n", "\n", "`tf.keras.Model.save_weights` 可以保存一个 TensorFlow 检查点。 " ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "SuhmrYPEl4D_" }, "outputs": [], "source": [ "net.save_weights(temp_dir/'easy_checkpoint.weights.h5')" ] }, { "cell_type": "markdown", "metadata": { "id": "XseWX5jDg4lQ" }, "source": [ "## 编写检查点\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1jpZPz76ZP3K" }, "source": [ "TensorFlow 模型的持久状态存储在 `tf.Variable` 对象中。这些对象可以直接构造,但通常会通过像 `tf.keras.layers` 或 `tf.keras.Model` 这样的高级 API 创建。\n", "\n", "管理变量的最简单方法是将它们附加到 Python 对象,然后引用这些对象。\n", "\n", "`tf.train.Checkpoint`、`tf.keras.layers.Layer` 和 `tf.keras.Model` 的子类会自动跟踪分配给其特性的变量。下面的示例构造了一个简单的线性模型,然后编写检查点,其中包含该模型所有变量的值。" ] }, { "cell_type": "markdown", "metadata": { "id": "x0vFBr_Im73_" }, "source": [ "您可以使用 `Model.save_weights` 轻松保存模型检查点。" ] }, { "cell_type": "markdown", "metadata": { "id": "FHTJ1JzxCi8a" }, "source": [ "### 手动创建检查点" ] }, { "cell_type": "markdown", "metadata": { "id": "6cF9fqYOCrEO" }, "source": [ "#### 设置" ] }, { "cell_type": "markdown", "metadata": { "id": "fNjf9KaLdIRP" }, "source": [ "为了帮助演示 `tf.train.Checkpoint` 的所有功能, 下面定义了一个玩具 (toy) 数据集和优化步骤:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "tSNyP4IJ9nkU" }, "outputs": [], "source": [ "def toy_dataset():\n", " inputs = tf.range(10.)[:, None]\n", " labels = inputs * 5. + tf.range(5.)[None, :]\n", " return tf.data.Dataset.from_tensor_slices(\n", " dict(x=inputs, y=labels)).repeat().batch(2)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "ICm1cufh_JH8" }, "outputs": [], "source": [ "def train_step(net, example, optimizer):\n", " \"\"\"Trains `net` on `example` using `optimizer`.\"\"\"\n", " with tf.GradientTape() as tape:\n", " output = net(example['x'])\n", " loss = tf.reduce_mean(tf.abs(output - example['y']))\n", " variables = net.trainable_variables\n", " gradients = tape.gradient(loss, variables)\n", " optimizer.apply_gradients(zip(gradients, variables))\n", " return loss" ] }, { "cell_type": "markdown", "metadata": { "id": "vxzGpHRbOVO6" }, "source": [ "#### 创建检查点对象\n", "\n", "使用 `tf.train.Checkpoint` 对象手动创建一个检查点,其中要检查的对象设置为对象的特性。\n", "\n", "`tf.train.CheckpointManager` 也有助于管理多个检查点。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "ou5qarOQOWYl" }, "outputs": [], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "dataset = toy_dataset()\n", "iterator = iter(dataset)\n", "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n", "manager = tf.train.CheckpointManager(ckpt, temp_dir/'./tf_ckpts', max_to_keep=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "8ZbYSD4uCy96" }, "source": [ "#### 训练模型并为模型设置检查点" ] }, { "cell_type": "markdown", "metadata": { "id": "NP9IySmCeCkn" }, "source": [ "以下训练循环可创建模型和优化器的实例,然后将它们收集到 `tf.train.Checkpoint` 对象中。它在每批数据上循环调用训练步骤,并定期将检查点写入磁盘。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "BbCS5A6K1VSH" }, "outputs": [], "source": [ "def train_and_checkpoint(net, manager):\n", " ckpt.restore(manager.latest_checkpoint)\n", " if manager.latest_checkpoint:\n", " print(\"Restored from {}\".format(manager.latest_checkpoint))\n", " else:\n", " print(\"Initializing from scratch.\")\n", "\n", " for _ in range(50):\n", " example = next(iterator)\n", " loss = train_step(net, example, opt)\n", " ckpt.step.assign_add(1)\n", " if int(ckpt.step) % 10 == 0:\n", " save_path = manager.save()\n", " print(\"Saved checkpoint for step {}: {}\".format(int(ckpt.step), save_path))\n", " print(\"loss {:1.2f}\".format(loss.numpy()))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "Ik3IBMTdPW41" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initializing from scratch.\n", "Saved checkpoint for step 10: .temp/tf_ckpts/ckpt-1\n", "loss 30.26\n", "Saved checkpoint for step 20: .temp/tf_ckpts/ckpt-2\n", "loss 23.68\n", "Saved checkpoint for step 30: .temp/tf_ckpts/ckpt-3\n", "loss 17.13\n", "Saved checkpoint for step 40: .temp/tf_ckpts/ckpt-4\n", "loss 10.75\n", "Saved checkpoint for step 50: .temp/tf_ckpts/ckpt-5\n", "loss 5.55\n" ] } ], "source": [ "train_and_checkpoint(net, manager)" ] }, { "cell_type": "markdown", "metadata": { "id": "2wzcc1xYN-sH" }, "source": [ "#### 恢复和继续训练" ] }, { "cell_type": "markdown", "metadata": { "id": "lw1QeyRBgsLE" }, "source": [ "在第一个训练周期结束后,您可以传递一个新的模型和管理器,但在您中断的地方继续训练:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "UjilkTOV2PBK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Restored from .temp/tf_ckpts/ckpt-5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py:678: UserWarning: Gradients do not exist for variables ['kernel'] when minimizing the loss. If using `model.compile()`, did you forget to provide a `loss` argument?\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 60: .temp/tf_ckpts/ckpt-6\n", "loss 22.74\n", "Saved checkpoint for step 70: .temp/tf_ckpts/ckpt-7\n", "loss 12.86\n", "Saved checkpoint for step 80: .temp/tf_ckpts/ckpt-8\n", "loss 4.05\n", "Saved checkpoint for step 90: .temp/tf_ckpts/ckpt-9\n", "loss 1.63\n", "Saved checkpoint for step 100: .temp/tf_ckpts/ckpt-10\n", "loss 1.04\n" ] } ], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "net = Net()\n", "dataset = toy_dataset()\n", "iterator = iter(dataset)\n", "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n", "manager = tf.train.CheckpointManager(ckpt, temp_dir/'./tf_ckpts', max_to_keep=3)\n", "\n", "train_and_checkpoint(net, manager)" ] }, { "cell_type": "markdown", "metadata": { "id": "dxJT9vV-2PnZ" }, "source": [ "`tf.train.CheckpointManager` 对象会删除旧的检查点。上面配置为仅保留最近的三个检查点。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "3zmM0a-F5XqC" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['.temp/tf_ckpts/ckpt-8', '.temp/tf_ckpts/ckpt-9', '.temp/tf_ckpts/ckpt-10']\n" ] } ], "source": [ "print(manager.checkpoints) # List the three remaining checkpoints" ] }, { "cell_type": "markdown", "metadata": { "id": "qwlYDyjemY4P" }, "source": [ "这些路径(如 `'./tf_ckpts/ckpt-10'`)不是磁盘上的文件,而是一个 `index` 文件和一个或多个包含变量值的数据文件的前缀。这些前缀被分组到一个单独的 `checkpoint` 文件 (`'./tf_ckpts/checkpoint'`) 中,其中 `CheckpointManager` 保存其状态。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "t1feej9JntV_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ls: cannot access '.temp./tf_ckpts': No such file or directory\n" ] } ], "source": [ "!ls {temp_dir}./tf_ckpts" ] }, { "cell_type": "markdown", "metadata": { "id": "DR2wQc9x6b3X" }, "source": [ "\n", "\n", "## 加载机制\n", "\n", "TensorFlow 通过从加载的对象开始遍历带命名边的有向计算图来将变量与检查点值匹配。边名称通常来自对象中的特性名称,例如 `self.l1 = tf.keras.layers.Dense(5)` 中的 `\"l1\"`。`tf.train.Checkpoint` 使用其关键字参数名称,如 `tf.train.Checkpoint(step=...)` 中的 `\"step\"`。\n", "\n", "上面示例中的依赖图如下所示:\n", "\n", "![Visualization of the dependency graph for the example training loop](https://tensorflow.google.cn/images/guide/whole_checkpoint.svg)\n", "\n", "优化器为红色,常规变量为蓝色,优化器插槽变量为橙色。其他节点(例如,代表 `tf.train.Checkpoint` 的节点)为黑色。\n", "\n", "插槽变量是优化器状态的一部分,不过是为特定变量而创建。例如,上面的 `'m'` 边缘对应于动量,Adam 优化器会针对每个变量跟踪该动量。只有在同时保存变量和优化器时,才会将插槽变量保存到检查点中,并因此保存虚线边缘。" ] }, { "cell_type": "markdown", "metadata": { "id": "VpY5IuanUEQ0" }, "source": [ "在 `tf.train.Checkpoint` 对象上调用 `restore` 会排队处理请求的恢复,一旦有来自 `Checkpoint` 对象的匹配路径,就会恢复变量值。例如,您可以通过重建一个穿过网络和层到达它的路径来仅从上面定义的模型加载偏差。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "wmX2AuyH7TVt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 0. 0. 0. 0.]\n", "[5.5001473 5.835153 7.544729 7.4490175 9.386313 ]\n" ] } ], "source": [ "to_restore = tf.Variable(tf.zeros([5]))\n", "print(to_restore.numpy()) # All zeros\n", "fake_layer = tf.train.Checkpoint(bias=to_restore)\n", "fake_net = tf.train.Checkpoint(l1=fake_layer)\n", "new_root = tf.train.Checkpoint(net=fake_net)\n", "status = new_root.restore(tf.train.latest_checkpoint(temp_dir/'./tf_ckpts/'))\n", "print(to_restore.numpy()) # This gets the restored value." ] }, { "cell_type": "markdown", "metadata": { "id": "GqEW-_pJDAnE" }, "source": [ "这些新对象的依赖关系计算图是您上面所编写较大检查点的一个小得多的子计算图。它仅包括偏差和 `tf.train.Checkpoint` 用于对检查点进行编号的保存计数器。\n", "\n", "![偏差变量的子计算图的可视化](https://tensorflow.google.cn/images/guide/partial_checkpoint.svg)\n", "\n", "`restore` 返回一个具有可选断言的状态对象。在新的 `Checkpoint` 中创建的所有对象都已恢复,因此 `status.assert_existing_objects_matched` 通过。" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "P9TQXl81Dq5r" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "status.assert_existing_objects_matched()" ] }, { "cell_type": "markdown", "metadata": { "id": "GoMwf8CFDu9r" }, "source": [ "检查点中有许多不匹配的对象,包括层的内核和优化器的变量。`status.assert_consumed()` 仅在检查点和程序完全匹配时通过,并在此处抛出异常。" ] }, { "cell_type": "markdown", "metadata": { "id": "KCcmJ-2j9RUP" }, "source": [ "### 延迟恢复\n", "\n", "当输入形状可用时,TensorFlow 中的 `Layer` 对象可能会将变量创建延迟到变量的首次调用。例如,`Dense` 层内核的形状取决于该层的输入和输出形状,因此,作为构造函数参数所需的输出形状没有足够的信息来单独创建变量。由于调用 `Layer` 还会读取变量的值,必须在变量的创建与其首次使用之间进行恢复。\n", "\n", "为支持这种习惯用法,`tf.train.Checkpoint` 会推迟尚不具有匹配变量的恢复。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "TXYUCO3v-I72" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0. 0. 0.]]\n", "[[0. 0. 0. 0. 0.]]\n" ] } ], "source": [ "deferred_restore = tf.Variable(tf.zeros([1, 5]))\n", "print(deferred_restore.numpy()) # Not restored; still zeros\n", "fake_layer.kernel = deferred_restore\n", "print(deferred_restore.numpy()) # Restored" ] }, { "cell_type": "markdown", "metadata": { "id": "-DWhJ3glyobN" }, "source": [ "### 手动检查检查点\n", "\n", "`tf.train.load_checkpoint` 返回一个提供对检查点内容进行较低级别访问权限的 `CheckpointReader`。它包含从每个变量的键到检查点中每个变量的形状和 dtype 的映射。如上面显示的计算图中所示,变量的键是它的对象路径。\n", "\n", "注:检查点没有更高级别的结构。它只知道变量的路径和值,而没有 `models`、`layers` 或它们如何连接的概念。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "RlRsADTezoBD" }, "outputs": [ { "data": { "text/plain": [ "['_CHECKPOINTABLE_OBJECT_GRAPH',\n", " 'iterator/.ATTRIBUTES/ITERATOR_STATE',\n", " 'net/l1/_kernel/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_learning_rate/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_trainable_variables/0/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/2/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/3/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/4/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/5/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/6/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/7/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'step/.ATTRIBUTES/VARIABLE_VALUE']" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reader = tf.train.load_checkpoint(temp_dir/'./tf_ckpts/')\n", "shape_from_key = reader.get_variable_to_shape_map()\n", "dtype_from_key = reader.get_variable_to_dtype_map()\n", "\n", "sorted(shape_from_key.keys())" ] }, { "cell_type": "markdown", "metadata": { "id": "AVrdvbNvgq5V" }, "source": [ "因此,如果您对 `net.l1.kernel` 的值感兴趣,可以使用以下代码获取该值:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "lYhX_XWCgl92" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape: [1, 5]\n", "Dtype: float32\n" ] } ], "source": [ "key = 'net/l1/_kernel/.ATTRIBUTES/VARIABLE_VALUE'\n", "\n", "print(\"Shape:\", shape_from_key[key])\n", "print(\"Dtype:\", dtype_from_key[key].name)" ] }, { "cell_type": "markdown", "metadata": { "id": "2Zk92jM5gRDW" }, "source": [ "此外,它还提供了一个 `get_tensor` 方法,允许您检查变量的值:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "cDJO3cgmecvi" }, "outputs": [ { "data": { "text/plain": [ "array([[3.934003 , 4.0978975, 4.122424 , 4.1820807, 4.423298 ]],\n", " dtype=float32)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reader.get_tensor(key)" ] }, { "cell_type": "markdown", "metadata": { "id": "5fxk_BnZ4W1b" }, "source": [ "### 对象跟踪\n", "\n", "检查点通过“跟踪”一个特性中的任何变量或可跟踪对象集来保存和回复 `tf.Variable` 对象的值。执行保存时,将从所有可访问的跟踪对象递归收集变量。\n", "\n", "对于像 `self.l1 = tf.keras.layers.Dense(5)` 一样的直接特性赋值,将列表和字典分配给特性会跟踪其内容。" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "rfaIbDtDHAr_" }, "outputs": [], "source": [ "save = tf.train.Checkpoint()\n", "save.listed = [tf.Variable(1.)]\n", "save.listed.append(tf.Variable(2.))\n", "save.mapped = {'one': save.listed[0]}\n", "save.mapped['two'] = save.listed[1]\n", "save_path = save.save(temp_dir/'./tf_list_example')\n", "\n", "restore = tf.train.Checkpoint()\n", "v2 = tf.Variable(0.)\n", "assert 0. == v2.numpy() # Not restored yet\n", "restore.mapped = {'two': v2}\n", "restore.restore(save_path)\n", "assert 2. == v2.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "UTKvbxHcI3T2" }, "source": [ "您可能会注意到列表和字典的包装器对象。这些包装器是可设置检查点版本的基础数据结构。就像基于特性的加载一样,这些包装器会在将变量添加到容器后立即恢复它的值。" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "s0Uq1Hv5JCmm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ListWrapper([])\n" ] } ], "source": [ "restore.listed = []\n", "print(restore.listed) # ListWrapper([])\n", "v1 = tf.Variable(0.)\n", "restore.listed.append(v1) # Restores v1, from restore() in the previous cell\n", "assert 1. == v1.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "OxCIf2J6JyQ8" }, "source": [ "可跟踪对象包括 `tf.train.Checkpoint`、`tf.Module` 及其子类(例如 `keras.layers.Layer` 和 `keras.Model`),并识别 Python 容器:\n", "\n", "- `dict`(和 `collections.OrderedDict`)\n", "- `list`\n", "- `tuple`(和 `collections.namedtuple`、`typing.NamedTuple`)\n", "\n", "其他容器类型**不受支持**,包括:\n", "\n", "- `collections.defaultdict`\n", "- `set`\n", "\n", "所有其他 Python 对象都会被**忽略**,包括:\n", "\n", "- `int`\n", "- `string`\n", "- `float`\n" ] }, { "cell_type": "markdown", "metadata": { "id": "knyUFMrJg8y4" }, "source": [ "## 总结\n", "\n", "TensorFlow 对象提供了一种简单的自动机制来保存和恢复它们所使用变量的值。\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "checkpoint.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "xxx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }