{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "klGNgWREsvQv" }, "outputs": [], "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "pmDI-h7cI0tI" }, "source": [ "# 使用 TF-Agents 训练深度 Q 网络\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 设置日志级别为ERROR,以减少警告信息\n", "# 禁用 Gemini 的底层库(gRPC 和 Abseil)在初始化日志警告\n", "os.environ[\"GRPC_VERBOSITY\"] = \"ERROR\"\n", "os.environ[\"GLOG_minloglevel\"] = \"3\" # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL\n", "os.environ[\"GLOG_minloglevel\"] = \"true\"\n", "import logging\n", "import tensorflow as tf\n", "tf.get_logger().setLevel(logging.ERROR)\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", "!export TF_FORCE_GPU_ALLOW_GROWTH=true\n", "from pathlib import Path\n", "\n", "temp_dir = Path(\".temp\")\n", "temp_dir.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "lsaQlK8fFQqH" }, "source": [ "## 简介\n" ] }, { "cell_type": "markdown", "metadata": { "id": "cKOCZlhUgXVK" }, "source": [ "本示例展示了如何使用 TF-Agents 库在 Cartpole 环境下训练 [DQN(深度 Q 网络)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)代理。\n", "\n", "![Cartpole environment](https://raw.githubusercontent.com/tensorflow/agents/master/docs/tutorials/images/cartpole.png)\n", "\n", "示例将引导您逐步了解强化学习 (RL) 的训练、评估和数据收集流水线的所有组成部分。\n", "\n", "要实时运行此代码,请点击上方的“在 Google Colab 中运行”链接。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1u9QVVsShC9X" }, "source": [ "## 设置" ] }, { "cell_type": "markdown", "metadata": { "id": "kNrNXKI7bINP" }, "source": [ "如果尚未安装以下依赖项,请运行以下命令:" ] }, { "cell_type": "markdown", "metadata": { "id": "KEHR2Ui-lo8O" }, "source": [ "```bash\n", "sudo apt-get update\n", "sudo apt-get install -y xvfb ffmpeg freeglut3-dev\n", "pip install imageio\n", "pip install pyvirtualdisplay\n", "pip install tf-agents[reverb]\n", "pip install pyglet\n", "```" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "sMitx5qSgJk1" }, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'reverb'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[1], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mImage\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyvirtualdisplay\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mreverb\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtensorflow\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mtf\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtf_agents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdqn\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dqn_agent\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'reverb'" ] } ], "source": [ "import base64\n", "import imageio\n", "import IPython\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import PIL.Image\n", "import pyvirtualdisplay\n", "import reverb\n", "\n", "import tensorflow as tf\n", "\n", "from tf_agents.agents.dqn import dqn_agent\n", "from tf_agents.drivers import py_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments 
 { "cell_type": "code", "execution_count": null, "metadata": { "id": "sMitx5qSgJk1" }, "outputs": [], "source": [ "import base64\n", "import imageio\n", "import IPython\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import PIL.Image\n", "import pyvirtualdisplay\n", "import reverb\n", "\n", "import tensorflow as tf\n", "\n", "from tf_agents.agents.dqn import dqn_agent\n", "from tf_agents.drivers import py_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.eval import metric_utils\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.networks import sequential\n", "from tf_agents.policies import py_tf_eager_policy\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.replay_buffers import reverb_replay_buffer\n", "from tf_agents.replay_buffers import reverb_utils\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.specs import tensor_spec\n", "from tf_agents.utils import common" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J6HsdS5GbSjd" }, "outputs": [], "source": [ "# Set up a virtual display for rendering OpenAI gym environments.\n", "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NspmzG4nP3b9" }, "outputs": [], "source": [ "tf.version.VERSION" ] }, { "cell_type": "markdown", "metadata": { "id": "LmC0NDhdLIKY" }, "source": [ "## Hyperparameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HC1kNrOsLSIZ" }, "outputs": [], "source": [ "num_iterations = 20000 # @param {type:\"integer\"}\n", "\n", "initial_collect_steps = 100 # @param {type:\"integer\"}\n", "collect_steps_per_iteration = 1 # @param {type:\"integer\"}\n", "replay_buffer_max_length = 100000 # @param {type:\"integer\"}\n", "\n", "batch_size = 64 # @param {type:\"integer\"}\n", "learning_rate = 1e-3 # @param {type:\"number\"}\n", "log_interval = 200 # @param {type:\"integer\"}\n", "\n", "num_eval_episodes = 10 # @param {type:\"integer\"}\n", "eval_interval = 1000 # @param {type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "VMsJC3DEgI0x" }, "source": [ "## Environment\n", "\n", "In Reinforcement Learning (RL), an environment represents the task or problem to be solved. Standard environments can be created in TF-Agents using `tf_agents.environments` suites. TF-Agents has suites for loading environments from sources such as OpenAI Gym, Atari, and DM Control.\n", "\n", "Load the CartPole environment from the OpenAI Gym suite." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pYEz-S9gEv2-" }, "outputs": [], "source": [ "env_name = 'CartPole-v0'\n", "env = suite_gym.load(env_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "IIHYVBkuvPNw" }, "source": [ "You can render this environment to see how it looks: a free-swinging pole is attached to a cart, and the goal is to move the cart right or left so that the pole keeps pointing up." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RlO7WIQHu_7D" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "env.reset()\n", "PIL.Image.fromarray(env.render())" ] }, { "cell_type": "markdown", "metadata": { "id": "B9_lskPOey18" }, "source": [ "The `environment.step` method takes an `action` in the environment and returns a `TimeStep` tuple containing the next observation of the environment and the reward for the action.\n", "\n", "The `time_step_spec()` method returns the specification for the `TimeStep` tuple. Its `observation` attribute shows the shape of observations, the data types, and the ranges of allowed values. The `reward` attribute shows the same details for the reward.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "exDv57iHfwQV" }, "outputs": [], "source": [ "print('Observation Spec:')\n", "print(env.time_step_spec().observation)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UxiSyCbBUQPi" }, "outputs": [], "source": [ "print('Reward Spec:')\n", "print(env.time_step_spec().reward)" ] }, { "cell_type": "markdown", "metadata": { "id": "b_lHcIcqUaqB" }, "source": [ "The `action_spec()` method returns the shape, data types, and allowed values of valid actions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bttJ4uxZUQBr" }, "outputs": [], "source": [ "print('Action Spec:')\n", "print(env.action_spec())" ] }, { "cell_type": "markdown", "metadata": { "id": "eJCgJnx3g0yY" }, "source": [ "In the Cartpole environment:\n", "\n", "- `observation` is an array of 4
 floats:\n", "    - the position and velocity of the cart\n", "    - the angular position and velocity of the pole\n", "- `reward` is a scalar float value\n", "- `action` is a scalar integer with only two possible values:\n", "    - `0` - \"move left\"\n", "    - `1` - \"move right\"\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V2UGR5t_iZX-" }, "outputs": [], "source": [ "time_step = env.reset()\n", "print('Time step:')\n", "print(time_step)\n", "\n", "action = np.array(1, dtype=np.int32)\n", "\n", "next_time_step = env.step(action)\n", "print('Next time step:')\n", "print(next_time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "4JSc9GviWUBK" }, "source": [ "Usually two environments are instantiated: one for training and one for evaluation. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N7brXNIGWXjC" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "zuUqXAVmecTU" }, "source": [ "The Cartpole environment, like most environments, is written in pure Python. It is converted to TensorFlow using the `TFPyEnvironment` wrapper.\n", "\n", "The original environment's API uses Numpy arrays. The `TFPyEnvironment` converts these to `Tensors` to make them compatible with TensorFlow agents and policies.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xp-Y4mD6eDhF" }, "outputs": [], "source": [ "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "E9lW_OZYFR8A" }, "source": [ "## Agent\n", "\n", "The algorithm used to solve an RL problem is represented by an `Agent`. TF-Agents provides standard implementations of a variety of `Agents`, including:\n", "\n", "- [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) (used in this tutorial)\n", "- [REINFORCE](https://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)\n", "- [DDPG](https://arxiv.org/pdf/1509.02971.pdf)\n", "- [TD3](https://arxiv.org/pdf/1802.09477.pdf)\n", "- [PPO](https://arxiv.org/abs/1707.06347)\n", "- [SAC](https://arxiv.org/abs/1801.01290)\n", "\n", "The DQN agent can be used in any environment with a discrete action space.\n", "\n", "At the heart of a DQN agent is a `QNetwork`, a neural network model that learns to predict `QValues` (expected returns) for all actions, given an observation from the environment.\n", "\n", "Use `tf_agents.networks` to create a `QNetwork`. The network will consist of a sequence of `tf.keras.layers.Dense` layers, where the final layer will have 1 output for each possible action." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TgkdEPg_muzV" }, "outputs": [], "source": [ "fc_layer_params = (100, 50)\n", "action_tensor_spec = tensor_spec.from_spec(env.action_spec())\n", "num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1\n", "\n", "# Define a helper function to create Dense layers configured with the right\n", "# activation and kernel initializer.\n", "def dense_layer(num_units):\n", "  return tf.keras.layers.Dense(\n", "      num_units,\n", "      activation=tf.keras.activations.relu,\n", "      kernel_initializer=tf.keras.initializers.VarianceScaling(\n", "          scale=2.0, mode='fan_in', distribution='truncated_normal'))\n", "\n", "# QNetwork consists of a sequence of Dense layers followed by a dense layer\n", "# with `num_actions` units to generate one q_value per available action as\n", "# its output.\n", "dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]\n", "q_values_layer = tf.keras.layers.Dense(\n", "    num_actions,\n", "    activation=None,\n", "    kernel_initializer=tf.keras.initializers.RandomUniform(\n", "        minval=-0.03, maxval=0.03),\n", "    bias_initializer=tf.keras.initializers.Constant(-0.2))\n", "q_net = sequential.Sequential(dense_layers + [q_values_layer])" ] },
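 { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (an illustrative sketch added here, not part of the original tutorial), you can pass a single batched observation through the untrained network and confirm that it produces one Q-value per action:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative: a TF-Agents network returns (output, network_state).\n", "# For one batched observation, the output holds one Q-value per action.\n", "sample_obs = tf.expand_dims(tf.constant(env.reset().observation), 0)\n", "q_values, _ = q_net(sample_obs)\n", "print(q_values)  # shape [1, num_actions]" ] },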
 { "cell_type": "markdown", "metadata": { "id": "z62u55hSmviJ" }, "source": [ "Now use `tf_agents.agents.dqn.dqn_agent` to instantiate a `DqnAgent`. In addition to the `time_step_spec`, `action_spec` and the QNetwork, the agent constructor also requires an optimizer (in this case, `Adam`), a loss function, and an integer step counter." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jbY4yrjTEyc9" }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n", "\n", "train_step_counter = tf.Variable(0)\n", "\n", "agent = dqn_agent.DqnAgent(\n", "    train_env.time_step_spec(),\n", "    train_env.action_spec(),\n", "    q_network=q_net,\n", "    optimizer=optimizer,\n", "    td_errors_loss_fn=common.element_wise_squared_loss,\n", "    train_step_counter=train_step_counter)\n", "\n", "agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "I0KLrEPwkn5x" }, "source": [ "## Policies\n", "\n", "A policy defines the way an agent acts in an environment. Typically, the goal of reinforcement learning is to train the underlying model until the policy produces the desired outcome.\n", "\n", "In this tutorial:\n", "\n", "- The desired outcome is keeping the pole balanced upright over the cart.\n", "- The policy returns an action (left or right) for each `time_step` observation.\n", "\n", "Agents contain two policies:\n", "\n", "- `agent.policy` - the main policy, used for evaluation and deployment.\n", "- `agent.collect_policy` - a second policy, used for data collection.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BwY7StuMkuV4" }, "outputs": [], "source": [ "eval_policy = agent.policy\n", "collect_policy = agent.collect_policy" ] }, { "cell_type": "markdown", "metadata": { "id": "2Qs1Fl3dV0ae" }, "source": [ "Policies can be created independently of agents. For example, use `tf_agents.policies.random_tf_policy` to create a policy that randomly selects an action for each `time_step`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HE37-UCIrE69" }, "outputs": [], "source": [ "random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),\n", "                                                train_env.action_spec())" ] }, { "cell_type": "markdown", "metadata": { "id": "dOlnlRRsUbxP" }, "source": [ "To get an action from a policy, call the `policy.action(time_step)` method. The `time_step` contains the observation from the environment. This method returns a `PolicyStep`, a named tuple with three components:\n", "\n", "- `action` - the action to be taken (in this case, `0` or `1`)\n", "- `state` - used for stateful (that is, RNN-based) policies\n", "- `info` - auxiliary data, such as log probabilities of the actions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5gCcpXswVAxk" }, "outputs": [], "source": [ "example_environment = tf_py_environment.TFPyEnvironment(\n", "    suite_gym.load('CartPole-v0'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D4DHZtq3Ndis" }, "outputs": [], "source": [ "time_step = example_environment.reset()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PRFqAUzpNaAW" }, "outputs": [], "source": [ "random_policy.action(time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "94rCXQtbUbXv" }, "source": [ "## Metrics and Evaluation\n", "\n", "The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode. Several episodes are run, and their returns are averaged.\n", "\n", "The following function computes the average return of a policy, given the policy, an environment, and a number of episodes.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bitzHo5_UbXy" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "def compute_avg_return(environment, policy, num_episodes=10):\n", "\n", "  total_return = 0.0\n", "  for _ in range(num_episodes):\n", "\n", "    time_step = environment.reset()\n", "    episode_return = 0.0\n", "\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = environment.step(action_step.action)\n", "      episode_return += time_step.reward\n", "    total_return += episode_return\n", "\n", "  avg_return = total_return / num_episodes\n", "  return avg_return.numpy()[0]\n", "\n", "\n", "# See also the metrics module for standard implementations of different metrics.\n", "# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics" ] },
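 { "cell_type": "markdown", "metadata": {}, "source": [ "The metrics module linked above can compute the same quantity. The following cell is an illustrative sketch (not part of the original tutorial, and assuming the `tf_metrics.AverageReturnMetric` and `DynamicEpisodeDriver` APIs) that attaches the standard average-return metric as an observer to an episode driver:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative alternative to compute_avg_return, using the standard\n", "# metric implementations from tf_agents.metrics.\n", "from tf_agents.drivers import dynamic_episode_driver\n", "\n", "avg_return_metric = tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes)\n", "dynamic_episode_driver.DynamicEpisodeDriver(\n", "    eval_env,\n", "    random_policy,\n", "    observers=[avg_return_metric],\n", "    num_episodes=num_eval_episodes).run()\n", "avg_return_metric.result().numpy()" ] },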
{ "id": "9bgU6Q6BZ8Bp" }, "outputs": [], "source": [ "compute_avg_return(eval_env, random_policy, num_eval_episodes)" ] }, { "cell_type": "markdown", "metadata": { "id": "NLva6g2jdWgr" }, "source": [ "## 回放缓冲区\n", "\n", "为了跟踪从环境收集的数据,我们将使用 [Reverb](https://deepmind.com/research/open-source/Reverb),这是 Deepmind 打造的一款高效、可扩展且易于使用的回放系统。它会在我们收集轨迹时存储经验数据,并在训练期间使用。\n", "\n", "回放缓冲区使用描述要存储的张量的规范构造,可以使用 agent.collect_data_spec 从代理获取这些张量。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vX2zGUWJGWAl" }, "outputs": [], "source": [ "table_name = 'uniform_table'\n", "replay_buffer_signature = tensor_spec.from_spec(\n", " agent.collect_data_spec)\n", "replay_buffer_signature = tensor_spec.add_outer_dim(\n", " replay_buffer_signature)\n", "\n", "table = reverb.Table(\n", " table_name,\n", " max_size=replay_buffer_max_length,\n", " sampler=reverb.selectors.Uniform(),\n", " remover=reverb.selectors.Fifo(),\n", " rate_limiter=reverb.rate_limiters.MinSize(1),\n", " signature=replay_buffer_signature)\n", "\n", "reverb_server = reverb.Server([table])\n", "\n", "replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(\n", " agent.collect_data_spec,\n", " table_name=table_name,\n", " sequence_length=2,\n", " local_server=reverb_server)\n", "\n", "rb_observer = reverb_utils.ReverbAddTrajectoryObserver(\n", " replay_buffer.py_client,\n", " table_name,\n", " sequence_length=2)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZGNTDJpZs4NN" }, "source": [ "对于大多数代理来说,`collect_data_spec` 是一个名为 `Trajectory` 的命名元组,其中包含观测值、操作、奖励和其他项目的规范。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_IZ-3HcqgE1z" }, "outputs": [], "source": [ "agent.collect_data_spec" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sy6g1tGcfRlw" }, "outputs": [], "source": [ "agent.collect_data_spec._fields" ] }, { "cell_type": "markdown", "metadata": { "id": "rVD5nQ9ZGo8_" }, "source": [ "## 数据收集\n", "\n", "现在,在环境中将随机策略执行几个步骤,这会将数据记录在回放缓冲区中。\n", "\n", "在这里,我们使用“PyDriver”来运行经验收集循环。您可以在我们的[驱动程序教程](https://tensorflow.google.cn/agents/tutorials/4_drivers_tutorial)中详细了解 TF Agents 驱动程序。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wr1KSAEGG4h9" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "py_driver.PyDriver(\n", " env,\n", " py_tf_eager_policy.PyTFEagerPolicy(\n", " random_policy, use_tf_function=True),\n", " [rb_observer],\n", " max_steps=initial_collect_steps).run(train_py_env.reset())" ] }, { "cell_type": "markdown", "metadata": { "id": "84z5pQJdoKxo" }, "source": [ "回放缓冲区现在是一个轨迹的集合。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4wZnLu2ViO4E" }, "outputs": [], "source": [ "# For the curious:\n", "# Uncomment to peel one of these off and inspect it.\n", "# iter(replay_buffer.as_dataset()).next()" ] }, { "cell_type": "markdown", "metadata": { "id": "TujU-PMUsKjS" }, "source": [ "代理需要访问回放缓冲区。通过创建可迭代的 `tf.data.Dataset` 流水线即可实现访问,此流水线可将数据馈送给代理。\n", "\n", "回放缓冲区的每一行仅存储一个观测步骤。但是,由于 DQN 代理需要当前和下一个观测值来计算损失,因此数据集流水线将为批次中的每个项目采样两个相邻的行 (`num_steps=2`)。\n", "\n", "此数据集还通过运行并行调用和预提取数据进行了优化。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ba7bilizt_qW" }, "outputs": [], "source": [ "# Dataset generates trajectories with shape [Bx2x...]\n", "dataset = replay_buffer.as_dataset(\n", " num_parallel_calls=3,\n", " sample_batch_size=batch_size,\n", " num_steps=2).prefetch(3)\n", "\n", "dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K13AST-2ppOq" }, 
"outputs": [], "source": [ "iterator = iter(dataset)\n", "print(iterator)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Th5w5Sff0b16" }, "outputs": [], "source": [ "# For the curious:\n", "# Uncomment to see what the dataset iterator is feeding to the agent.\n", "# Compare this representation of replay data \n", "# to the collection of individual trajectories shown earlier.\n", "\n", "# iterator.next()" ] }, { "cell_type": "markdown", "metadata": { "id": "hBc9lj9VWWtZ" }, "source": [ "## 训练代理\n", "\n", "训练循环中必须包含两个步骤:\n", "\n", "- 从环境中收集数据\n", "- 使用该数据训练代理的神经网络\n", "\n", "在此示例中,还会定期评估策略并打印当前分数。\n", "\n", "运行以下示例大约需要 5 分钟。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0pTbJ3PeyF-u" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "try:\n", " %%time\n", "except:\n", " pass\n", "\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "agent.train = common.function(agent.train)\n", "\n", "# Reset the train step.\n", "agent.train_step_counter.assign(0)\n", "\n", "# Evaluate the agent's policy once before training.\n", "avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", "returns = [avg_return]\n", "\n", "# Reset the environment.\n", "time_step = train_py_env.reset()\n", "\n", "# Create a driver to collect experience.\n", "collect_driver = py_driver.PyDriver(\n", " env,\n", " py_tf_eager_policy.PyTFEagerPolicy(\n", " agent.collect_policy, use_tf_function=True),\n", " [rb_observer],\n", " max_steps=collect_steps_per_iteration)\n", "\n", "for _ in range(num_iterations):\n", "\n", " # Collect a few steps and save to the replay buffer.\n", " time_step, _ = collect_driver.run(time_step)\n", "\n", " # Sample a batch of data from the buffer and update the agent's network.\n", " experience, unused_info = next(iterator)\n", " train_loss = agent.train(experience).loss\n", "\n", " step = agent.train_step_counter.numpy()\n", "\n", " if step % log_interval == 0:\n", " print('step = {0}: loss = {1}'.format(step, train_loss))\n", "\n", " if step % eval_interval == 0:\n", " avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", " print('step = {0}: Average Return = {1}'.format(step, avg_return))\n", " returns.append(avg_return)" ] }, { "cell_type": "markdown", "metadata": { "id": "68jNcA_TiJDq" }, "source": [ "## 可视化\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aO-LWCdbbOIC" }, "source": [ "### 绘图\n", "\n", "使用 `matplotlib.pyplot` 绘制图表,展示策略在训练过程中的改进方式。\n", "\n", "`Cartpole-v0` 的一个迭代包含 200 个时间步骤。长杆保持直立的每一步,环境都会分配 `+1` 奖励,因此一个片段的最大回报为 200。图表显示,在训练期间每次评估的回报都朝着该最大值递增(递增可能稍有不稳定情况,并且并非每次均为单调递增)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NxtL1mbOYCVO" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "\n", "iterations = range(0, num_iterations + 1, eval_interval)\n", "plt.plot(iterations, returns)\n", "plt.ylabel('Average Return')\n", "plt.xlabel('Iterations')\n", "plt.ylim(top=250)" ] }, { "cell_type": "markdown", "metadata": { "id": "M7-XpPP99Cy7" }, "source": [ "### 视频" ] }, { "cell_type": "markdown", "metadata": { "id": "9pGfGxSH32gn" }, "source": [ "图表非常实用,但能够看到代理在环境中真实地执行任务将更为生动。\n", "\n", "首先,创建一个函数以在笔记本内嵌入视频。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ULaGr8pvOKbl" }, "outputs": [], "source": [ "def embed_mp4(filename):\n", " \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n", " video = open(filename,'rb').read()\n", " b64 = base64.b64encode(video)\n", " tag = '''\n", " 
'''.format(b64.decode())\n", "\n", "  return IPython.display.HTML(tag)" ] }, { "cell_type": "markdown", "metadata": { "id": "9c_PH-pX4Pr5" }, "source": [ "Now iterate through a few episodes of the Cartpole game with the agent. The underlying Python environment (the one \"inside\" the TensorFlow environment wrapper) provides a `render()` method, which outputs an image of the environment state. These images can be collected into a video." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "owOVWB158NlF" }, "outputs": [], "source": [ "def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):\n", "  filename = filename + \".mp4\"\n", "  with imageio.get_writer(filename, fps=fps) as video:\n", "    for _ in range(num_episodes):\n", "      time_step = eval_env.reset()\n", "      video.append_data(eval_py_env.render())\n", "      while not time_step.is_last():\n", "        action_step = policy.action(time_step)\n", "        time_step = eval_env.step(action_step.action)\n", "        video.append_data(eval_py_env.render())\n", "  return embed_mp4(filename)\n", "\n", "create_policy_eval_video(agent.policy, \"trained-agent\")" ] }, { "cell_type": "markdown", "metadata": { "id": "povaAOcZygLw" }, "source": [ "For fun, compare the trained agent (above) to an agent that moves randomly. (It does not do as well.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pJZIdC37yNH4" }, "outputs": [], "source": [ "create_policy_eval_video(random_policy, \"random-agent\")" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "1_dqn_tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }