{ "cells": [
{ "cell_type": "markdown", "metadata": { "id": "W7rEsKyWcxmu" }, "source": [ "##### Copyright 2023 The TF-Agents Authors.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "G6aOV15Wc4HP" }, "source": [ "# Checkpointer and PolicySaver\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "M3HE5S3wsMEh" }, "source": [ "## Introduction\n", "\n", "`tf_agents.utils.common.Checkpointer` is a utility to save and load the training state, policy state, and replay_buffer state to and from local storage.\n", "\n", "`tf_agents.policies.policy_saver.PolicySaver` is a tool to save/load only the policy, and is lighter than `Checkpointer`. You can also use `PolicySaver` to deploy the model without any knowledge of the code that created the policy.\n", "\n", "In this tutorial, we will use DQN to train a model, then use `Checkpointer` and `PolicySaver` to show how we can store and load the states and model in an interactive way. Note that we will use TF2.0's new saved_model tooling and format for `PolicySaver`.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "vbTrDrX4dkP_" }, "source": [ "## Setup" ] },
{ "cell_type": "markdown", "metadata": { "id": "Opk_cVDYdgct" }, "source": [ "If you haven't installed the following dependencies, run:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Jv668dKvZmka" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "!sudo apt-get update\n", "!sudo apt-get install -y xvfb ffmpeg python-opengl\n", "!pip install pyglet\n", "!pip install 'imageio==2.4.0'\n", "!pip install 'xvfbwrapper==0.2.9'\n", "!pip install tf-agents[reverb]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "bQMULMo1dCEn" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import base64\n", "import imageio\n", "import io\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import os\n", "import shutil\n", "import tempfile\n", "import tensorflow as tf\n", "import zipfile\n", "import IPython\n", "\n", "try:\n", "  from google.colab import files\n", "except ImportError:\n", "  files = None\n", "from tf_agents.agents.dqn import dqn_agent\n", "from tf_agents.drivers import dynamic_step_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.eval import metric_utils\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.networks import q_network\n", "from tf_agents.policies import policy_saver\n", "from tf_agents.policies import py_tf_eager_policy\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.utils import common\n", "\n", "tempdir = os.getenv(\"TEST_TMPDIR\", tempfile.gettempdir())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "AwIqiLdDCX9Q" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "# Set up a virtual display for rendering OpenAI gym environments.\n", "import xvfbwrapper\n", "xvfbwrapper.Xvfb(1400, 900, 24).start()" ] },
{ "cell_type": "markdown", "metadata": { "id": "AOv_kofIvWnW" }, "source": [ "## DQN agent\n", "\n", "As in the previous colab, we set up a DQN agent. The details are hidden by default since they are not a core part of this colab, but you can click 'SHOW CODE' to see them." ] },
{ "cell_type": "markdown", "metadata": { "id": "cStmaxredFSW" }, "source": [ "### Hyperparameters" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "both", "id": "yxFs6QU0dGI_" }, "outputs": [], "source": [ "env_name = \"CartPole-v1\"\n", "\n", "collect_steps_per_iteration = 100\n", "replay_buffer_capacity = 100000\n", "\n", "fc_layer_params = (100,)\n", "\n", "batch_size = 64\n", "learning_rate = 1e-3\n", "log_interval = 5\n", "\n", "num_eval_episodes = 10\n", "eval_interval = 1000" ] },
{ "cell_type": "markdown", "metadata": { "id": "w4GR7RDndIOR" }, "source": [ "### Environment" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "fZwK4d-bdI7Z" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)\n", "\n", "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] },
{ "cell_type": "markdown", "metadata": { "id": "0AvYRwfkeMvo" }, "source": [ "### Agent" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "cUrFl83ieOvV" }, "outputs": [], "source": [ "#@title\n", "q_net = q_network.QNetwork(\n", "    train_env.observation_spec(),\n", "    train_env.action_spec(),\n", "    fc_layer_params=fc_layer_params)\n", "\n", "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n", "\n", "global_step = tf.compat.v1.train.get_or_create_global_step()\n", "\n", "agent = dqn_agent.DqnAgent(\n", "    train_env.time_step_spec(),\n", "    train_env.action_spec(),\n", "    q_network=q_net,\n", "    optimizer=optimizer,\n", "    td_errors_loss_fn=common.element_wise_squared_loss,\n", "    train_step_counter=global_step)\n", "agent.initialize()" ] },
{ "cell_type": "markdown", "metadata": { "id": "p8ganoJhdsbn" }, "source": [ "### Data collection" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "XiT1p78HdtSe" }, "outputs": [], "source": [ "#@title\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", "    data_spec=agent.collect_data_spec,\n", "    batch_size=train_env.batch_size,\n", "    max_length=replay_buffer_capacity)\n", "\n", "collect_driver = dynamic_step_driver.DynamicStepDriver(\n", "    train_env,\n", "    agent.collect_policy,\n", "    observers=[replay_buffer.add_batch],\n", "    num_steps=collect_steps_per_iteration)\n", "\n", "# Initial data collection\n", "collect_driver.run()\n", "\n", "# Dataset generates trajectories with shape [BxTx...] where\n", "# T = n_step_update + 1.\n", "dataset = replay_buffer.as_dataset(\n", "    num_parallel_calls=3, sample_batch_size=batch_size,\n", "    num_steps=2).prefetch(3)\n", "\n", "iterator = iter(dataset)" ] },
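{ "cell_type": "markdown", "metadata": {}, "source": [ "As the comment in the previous cell notes, the dataset yields trajectory slices shaped `[B x T x ...]`. As a quick sanity check (a sketch added here, not part of the original tutorial), you can draw one sample and confirm the layout: with `sample_batch_size=64`, `num_steps=2`, and CartPole's 4-dimensional observation, the observation tensor should come out as `(64, 2, 4)`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sanity-check sketch: draw one batch and inspect the [B x T x ...] layout.\n", "# Note that this consumes one sample from the iterator.\n", "sample_experience, _ = next(iterator)\n", "print(sample_experience.observation.shape)  # Expected: (64, 2, 4) for CartPole-v1" ] },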
{ "cell_type": "markdown", "metadata": { "id": "8V8bojrKdupW" }, "source": [ "### Train the agent" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "-rDC3leXdvm_" }, "outputs": [], "source": [ "#@title\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "agent.train = common.function(agent.train)\n", "\n", "def train_one_iteration():\n", "\n", "  # Collect a few steps using collect_policy and save to the replay buffer.\n", "  collect_driver.run()\n", "\n", "  # Sample a batch of data from the buffer and update the agent's network.\n", "  experience, unused_info = next(iterator)\n", "  train_loss = agent.train(experience)\n", "\n", "  iteration = agent.train_step_counter.numpy()\n", "  print('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))" ] },
{ "cell_type": "markdown", "metadata": { "id": "vgqVaPnUeDAn" }, "source": [ "### Video generation" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ZY6w-fcieFDW" }, "outputs": [], "source": [ "#@title\n", "def embed_gif(gif_buffer):\n", "  \"\"\"Embeds a gif file in the notebook.\"\"\"\n", "  tag = '<img src=\"data:image/gif;base64,{0}\"/>'.format(base64.b64encode(gif_buffer).decode())\n", "  return IPython.display.HTML(tag)\n", "\n", "def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):\n", "  num_episodes = 3\n", "  frames = []\n", "  for _ in range(num_episodes):\n", "    time_step = eval_tf_env.reset()\n", "    frames.append(eval_py_env.render())\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = eval_tf_env.step(action_step.action)\n", "      frames.append(eval_py_env.render())\n", "  gif_file = io.BytesIO()\n", "  imageio.mimsave(gif_file, frames, format='gif', fps=60)\n", "  IPython.display.display(embed_gif(gif_file.getvalue()))" ] },
{ "cell_type": "markdown", "metadata": { "id": "y-oA8VYJdFdj" }, "source": [ "### Generate a video\n", "\n", "Check the performance of the policy by generating a video." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "FpmPLXWbdG70" }, "outputs": [], "source": [ "print('global_step:')\n", "print(global_step)\n", "run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)" ] },
{ "cell_type": "markdown", "metadata": { "id": "7RPLExsxwnOm" }, "source": [ "## Setup Checkpointer and PolicySaver\n", "\n", "Now we are ready to use Checkpointer and PolicySaver." ] },
{ "cell_type": "markdown", "metadata": { "id": "g-iyQJacfQqO" }, "source": [ "### Checkpointer\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "2DzCJZ-6YYbX" }, "outputs": [], "source": [ "checkpoint_dir = os.path.join(tempdir, 'checkpoint')\n", "train_checkpointer = common.Checkpointer(\n", "    ckpt_dir=checkpoint_dir,\n", "    max_to_keep=1,\n", "    agent=agent,\n", "    policy=agent.policy,\n", "    replay_buffer=replay_buffer,\n", "    global_step=global_step\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "MKpWNZM4WE8d" }, "source": [ "### PolicySaver" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "8mDZ_YMUWEY9" }, "outputs": [], "source": [ "policy_dir = os.path.join(tempdir, 'policy')\n", "tf_policy_saver = policy_saver.PolicySaver(agent.policy)" ] },
{ "cell_type": "markdown", "metadata": { "id": "1OnANb1Idx8-" }, "source": [ "### Train one iteration" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ql_D1iq8dl0X" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "print('Training one iteration....')\n", "train_one_iteration()" ] },
{ "cell_type": "markdown", "metadata": { "id": "eSChNSQPlySb" }, "source": [ "### Save to checkpoint" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "usDm_Wpsl0bu" }, "outputs": [], "source": [ "train_checkpointer.save(global_step)" ] },
{ "cell_type": "markdown", "metadata": { "id": "gTQUrKgihuic" }, "source": [ "### Restore checkpoint\n", "\n", "To restore the checkpoint, the whole set of objects needs to be recreated the same way it was when the checkpoint was created." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "l6l3EB-Yhwmz" }, "outputs": [], "source": [ "train_checkpointer.initialize_or_restore()\n", "global_step = tf.compat.v1.train.get_global_step()" ] },
{ "cell_type": "markdown", "metadata": { "id": "Nb8_MSE2XjRp" }, "source": [ "Also save the policy and export it to a location" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "3xHz09WCWjwA" }, "outputs": [], "source": [ "tf_policy_saver.save(policy_dir)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Mz-xScbuh4Vo" }, "source": [ "The policy can be loaded without any knowledge of what agent or network was used to create it, which makes deployment of the policy much easier.\n", "\n", "Load the saved policy and check how it performs" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "J6T5KLTMh9ZB" }, "outputs": [], "source": [ "saved_policy = tf.saved_model.load(policy_dir)\n", "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)" ] },
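{ "cell_type": "markdown", "metadata": {}, "source": [ "The loaded SavedModel exposes the policy interface itself, so you can also call it directly instead of going through the video helper. The cell below is a minimal sketch added to this tutorial: `get_initial_state` and `action` are functions exported by `PolicySaver`, and the state is empty for this stateless DQN policy." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch: call the loaded policy directly on a batched TimeStep.\n", "time_step = eval_env.reset()\n", "policy_state = saved_policy.get_initial_state(batch_size=eval_env.batch_size)\n", "action_step = saved_policy.action(time_step, policy_state)\n", "print('Chosen action:', action_step.action.numpy())" ] },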
{ "cell_type": "markdown", "metadata": { "id": "MpE0KKfqjc0c" }, "source": [ "## Export and import\n", "\n", "The rest of this colab will help you export/import the checkpointer and policy directories so that you can continue training at a later point and deploy the model without having to train again.\n", "\n", "Now you can go back to 'Train one iteration' and train a few more times so that you can see the difference later on. Once you start to see slightly better results, continue below." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "fd5Cj7DVjfH4" }, "outputs": [], "source": [ "#@title Create zip file and upload zip file (double-click to see the code)\n", "def create_zip_file(dirname, base_filename):\n", "  return shutil.make_archive(base_filename, 'zip', dirname)\n", "\n", "def upload_and_unzip_file_to(dirname):\n", "  if files is None:\n", "    return\n", "  uploaded = files.upload()\n", "  for fn in uploaded.keys():\n", "    print('User uploaded file \"{name}\" with length {length} bytes'.format(\n", "        name=fn, length=len(uploaded[fn])))\n", "  shutil.rmtree(dirname)\n", "  zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')\n", "  zip_files.extractall(dirname)\n", "  zip_files.close()" ] },
{ "cell_type": "markdown", "metadata": { "id": "hgyy29doHCmL" }, "source": [ "Create a zipped file of the checkpoint directory." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "nhR8NeWzF4fe" }, "outputs": [], "source": [ "train_checkpointer.save(global_step)\n", "checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))" ] },
{ "cell_type": "markdown", "metadata": { "id": "VGEpntTocd2u" }, "source": [ "Download the zip file." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "upFxb5k8b4MC" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "if files is not None:\n", "  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469" ] },
{ "cell_type": "markdown", "metadata": { "id": "VRaZMrn5jLmE" }, "source": [ "After training for a while (10-15 iterations), download the checkpoint zip file, go to 'Runtime > Restart and run all' to reset the training, and come back to this cell. Now you can upload the downloaded zip file and continue training." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "kg-bKgMsF-H_" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "upload_and_unzip_file_to(checkpoint_dir)\n", "train_checkpointer.initialize_or_restore()\n", "global_step = tf.compat.v1.train.get_global_step()" ] },
{ "cell_type": "markdown", "metadata": { "id": "uXrNax5Zk3vF" }, "source": [ "Once you have uploaded the checkpoint directory, go back to 'Train one iteration' to continue training, or go back to 'Generate a video' to check the performance of the loaded policy." ] },
{ "cell_type": "markdown", "metadata": { "id": "OAkvVZ-NeN2j" }, "source": [ "Alternatively, you can save the policy (the model) and restore it. Unlike the checkpointer, you cannot continue training this way, but you can still deploy the model. Note that the downloaded file is much smaller than the checkpointer's." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "s7qMn6D8eiIA" }, "outputs": [], "source": [ "tf_policy_saver.save(policy_dir)\n", "policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "rrGvCEXwerJj" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "if files is not None:\n", "  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469" ] },
{ "cell_type": "markdown", "metadata": { "id": "DyC_O_gsgSi5" }, "source": [ "Upload the downloaded policy directory (exported_policy.zip) and check how the saved policy performs." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "bgWLimRlXy5z" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "upload_and_unzip_file_to(policy_dir)\n", "saved_policy = tf.saved_model.load(policy_dir)\n", "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "HSehXThTm4af" }, "source": [ "## SavedModelPyTFEagerPolicy\n", "\n", "If you don't want to use the TF policy, you can also use the saved_model directly with the Python environment through `py_tf_eager_policy.SavedModelPyTFEagerPolicy`.\n", "\n", "Note that this only works when eager mode is enabled." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "iUC5XuLf1jF7" }, "outputs": [], "source": [ "eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(\n", "    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())\n", "\n", "# Note that we're passing eval_py_env not eval_env.\n", "run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)" ] },
{ "cell_type": "markdown", "metadata": { "id": "7fvWqfJg00ww" }, "source": [ "## Convert policy to TFLite\n", "\n", "See the [TensorFlow Lite converter](https://tensorflow.google.cn/lite/convert) for more details." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "z9zonVBJ0z46" }, "outputs": [], "source": [ "converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=[\"action\"])\n", "tflite_policy = converter.convert()\n", "with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:\n", "  f.write(tflite_policy)" ] },
{ "cell_type": "markdown", "metadata": { "id": "rsi3V9QdxJUu" }, "source": [ "### Run inference on the TFLite model\n", "\n", "See [TensorFlow Lite Inference](https://tensorflow.org/lite/guide/inference) for more details." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "4GeUSWyZxMlN" }, "outputs": [], "source": [ "import numpy as np\n", "interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))\n", "\n", "policy_runner = interpreter.get_signature_runner()\n", "print(policy_runner._inputs)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "eVVrdTbRxnOC" }, "outputs": [], "source": [ "policy_runner(**{\n", "    '0/discount': tf.constant(0.0),\n", "    '0/observation': tf.zeros([1,4]),\n", "    '0/reward': tf.constant(0.0),\n", "    '0/step_type': tf.constant(0)})" ] },
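{ "cell_type": "markdown", "metadata": {}, "source": [ "To close the loop, the sketch below (added to this tutorial, not part of the original) drives one CartPole episode with the TFLite policy. It assumes the flattened `0/...` input names printed above and that the signature's output tensor is named `action`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: run one episode with the TFLite policy. Assumes the input names\n", "# shown above and an output tensor named 'action'.\n", "time_step = eval_py_env.reset()\n", "total_reward = 0.0\n", "while not time_step.is_last():\n", "  output = policy_runner(**{\n", "      '0/discount': tf.constant(time_step.discount),\n", "      '0/observation': tf.expand_dims(tf.constant(time_step.observation), 0),\n", "      '0/reward': tf.constant(time_step.reward),\n", "      '0/step_type': tf.constant(time_step.step_type)})\n", "  # Flatten to a scalar action regardless of the exact output shape.\n", "  action = np.asarray(output['action']).reshape(-1)[0]\n", "  time_step = eval_py_env.step(action)\n", "  total_reward += time_step.reward\n", "print('Episode reward:', total_reward)" ] }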
], "metadata": { "colab": { "collapsed_sections": [], "name": "10_checkpointer_policysaver_tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }