{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "klGNgWREsvQv" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "oMaGpi7TciQs" }, "source": [ "# DQN C51/Rainbow\n", "\n", "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 运行 在 Github 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "ZOUOQOrFs3zn" }, "source": [ "## 简介" ] }, { "cell_type": "markdown", "metadata": { "id": "cKOCZlhUgXVK" }, "source": [ "本示例说明了如何使用 TF-Agents 库在 Cartpole 环境中训练[分类 DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) 代理。\n", "\n", "![Cartpole environment](https://github.com/tensorflow/agents/blob/master/docs/tutorials/images/cartpole.png?raw=1)\n", "\n", "确保您已事先阅读 [DQN 教程](https://github.com/tensorflow/agents/blob/master/docs/tutorials/1_dqn_tutorial.ipynb)。本教程假定您熟悉 DQN 教程,并主要关注 DQN 与 C51 之间的差异。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "lsaQlK8fFQqH" }, "source": [ "## 设置\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-NzBsZzPcyBm" }, "source": [ "如果尚未安装 TF-Agents,请运行以下命令:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KEHR2Ui-lo8O" }, "outputs": [], "source": [ "!sudo apt-get update\n", "!sudo apt-get install -y xvfb ffmpeg freeglut3-dev\n", "!pip install 'imageio==2.4.0'\n", "!pip install pyvirtualdisplay\n", "!pip install tf-agents\n", "!pip install pyglet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sMitx5qSgJk1" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import base64\n", "import imageio\n", "import IPython\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import PIL.Image\n", "import pyvirtualdisplay\n", "\n", "import tensorflow as tf\n", "\n", "from tf_agents.agents.categorical_dqn import categorical_dqn_agent\n", "from tf_agents.drivers import dynamic_step_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.eval import metric_utils\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.networks import categorical_q_network\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.utils import common\n", "\n", "# Set up a virtual display for rendering OpenAI gym environments.\n", "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()" ] }, { "cell_type": "markdown", "metadata": { "id": "LmC0NDhdLIKY" }, "source": [ "## 超参数" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HC1kNrOsLSIZ" }, "outputs": [], "source": [ "env_name = \"CartPole-v1\" # @param {type:\"string\"}\n", "num_iterations = 15000 # @param {type:\"integer\"}\n", "\n", "initial_collect_steps = 1000 # @param {type:\"integer\"} \n", "collect_steps_per_iteration = 1 # @param {type:\"integer\"}\n", "replay_buffer_capacity = 100000 # @param {type:\"integer\"}\n", "\n", "fc_layer_params = (100,)\n", "\n", "batch_size = 64 # @param {type:\"integer\"}\n", "learning_rate = 1e-3 # @param {type:\"number\"}\n", "gamma = 0.99\n", "log_interval = 200 # @param {type:\"integer\"}\n", "\n", "num_atoms = 51 # @param {type:\"integer\"}\n", "min_q_value = -20 # @param {type:\"integer\"}\n", "max_q_value = 20 # @param {type:\"integer\"}\n", "n_step_update = 2 # @param {type:\"integer\"}\n", "\n", "num_eval_episodes = 10 # @param {type:\"integer\"}\n", "eval_interval = 1000 # @param {type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "VMsJC3DEgI0x" }, "source": [ "## 环境\n", "\n", "像以前一样加载环境,其中一个用于训练,另一个用于评估。在这里,我们使用 CartPole-v1(DQN 教程中则为 CartPole-v0),它的最大奖励是 500,而不是 200。" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "id": "Xp-Y4mD6eDhF" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)\n", "\n", "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "E9lW_OZYFR8A" }, "source": [ "## 代理\n", "\n", "C51 是一种基于 DQN 的 Q-learning 算法。与 DQN 一样,它可以在具有离散操作空间的任何环境中使用。\n", "\n", "C51 与 DQN 之间的主要区别在于,C51 不仅可以简单地预测每个状态-操作对的 Q 值,还能预测表示 Q 值概率分布的直方图模型:\n", "\n", "![Example C51 Distribution](images/c51_distribution.png)\n", "\n", "通过学习分布而不是简单的期望值,此算法能够在训练过程中保持更稳定的状态,从而提高最终性能。这种算法尤其适用于具有双峰甚至多峰值分布的情况,此时单个平均值无法提供准确的概览。\n", "\n", "为了基于概率分布而不是值来训练,C51 必须执行一些复杂的分布计算才能计算其损失函数。但不用担心,我们已在 TF-Agents 中为您处理好一切!\n", "\n", "要创建 C51 代理,我们首先需要创建一个 `CategoricalQNetwork`。除了有一个附加参数 `num_atoms` 外,`CategoricalQNetwork` 的 API 与 `QNetwork` 的 API 相同。这表示我们的概率分布估算中的支撑点数。(上面的图像包括 10 个支撑点,每个支撑点都由垂直的蓝色条表示。)您可以从名称中看出,默认原子数为 51。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TgkdEPg_muzV" }, "outputs": [], "source": [ "categorical_q_net = categorical_q_network.CategoricalQNetwork(\n", " train_env.observation_spec(),\n", " train_env.action_spec(),\n", " num_atoms=num_atoms,\n", " fc_layer_params=fc_layer_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "z62u55hSmviJ" }, "source": [ "我们还需要一个 `optimizer` 来训练刚刚创建的网络,以及一个 `train_step_counter` 变量来跟踪网络更新的次数。\n", "\n", "请注意,与普通 `DqnAgent` 的另一个重要区别在于,我们现在需要指定 `min_q_value` 和 `max_q_value` 作为参数。这两个参数指定了支撑点的最极端值(换句话说,任何一侧有全部 51 个原子)。确保为您的特定环境适当地选择这些值。在这里,我们使用 -20 和 20。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jbY4yrjTEyc9" }, "outputs": [], "source": [ "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n", "\n", "train_step_counter = tf.Variable(0)\n", "\n", "agent = categorical_dqn_agent.CategoricalDqnAgent(\n", " train_env.time_step_spec(),\n", " train_env.action_spec(),\n", " categorical_q_network=categorical_q_net,\n", " optimizer=optimizer,\n", " min_q_value=min_q_value,\n", " max_q_value=max_q_value,\n", " n_step_update=n_step_update,\n", " td_errors_loss_fn=common.element_wise_squared_loss,\n", " gamma=gamma,\n", " train_step_counter=train_step_counter)\n", "agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "L7O7F_HqiQ1G" }, "source": [ "最后要注意的一点是,我们还添加了一个参数来使用 $n$ = 2 的 n 步更新。在单步 Q-learning ($n$ = 1) 中,我们仅使用单步回报(基于贝尔曼最优性方程)计算当前时间步骤和下一时间步骤的 Q 值之间的误差。单步回报定义为:\n", "\n", "$G_t = R_{t + 1} + \\gamma V(s_{t + 1})$\n", "\n", "其中,我们定义 $V(s) = \\max_a{Q(s, a)}$。\n", "\n", "N 步更新涉及将标准单步回报函数扩展 $n$ 倍:\n", "\n", "$G_t^n = R_{t + 1} + \\gamma R_{t + 2} + \\gamma^2 R_{t + 3} + \\dots + \\gamma^n V(s_{t + n})$\n", "\n", "N 步更新使代理可以在将来进一步自助抽样,而在 $n$ 值正确的情况下,这通常可以加快学习速度。\n", "\n", "尽管 C51 和 n 步更新通常与优先回放相结合构成 [Rainbow 代理](https://arxiv.org/pdf/1710.02298.pdf)的核心,但我们发现,实现优先回放并未带来可衡量的改进。此外,我们还发现,仅将 C51 代理与 n 步更新结合使用时,在我们测试过的 Atari 环境样本中,我们的代理在性能上与其他 Rainbow 代理一样出色。" ] }, { "cell_type": "markdown", "metadata": { "id": "94rCXQtbUbXv" }, "source": [ "## 指标和评估\n", "\n", "用于评估策略的最常用指标是平均回报。回报是针对某个片段在环境中运行策略时获得的奖励总和,我们通常会评估多个片段的平均值。计算平均回报指标的代码如下。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bitzHo5_UbXy" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "def compute_avg_return(environment, policy, num_episodes=10):\n", "\n", " total_return = 0.0\n", " for _ in range(num_episodes):\n", "\n", " time_step = environment.reset()\n", " 
{ "cell_type": "markdown", "metadata": { "id": "94rCXQtbUbXv" }, "source": [ "## Metrics and Evaluation\n", "\n", "The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode, and we usually average this over a few episodes. The code below computes the average return metric.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "bitzHo5_UbXy" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "def compute_avg_return(environment, policy, num_episodes=10):\n", "\n", "  total_return = 0.0\n", "  for _ in range(num_episodes):\n", "\n", "    time_step = environment.reset()\n", "    episode_return = 0.0\n", "\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = environment.step(action_step.action)\n", "      episode_return += time_step.reward\n", "    total_return += episode_return\n", "\n", "  avg_return = total_return / num_episodes\n", "  return avg_return.numpy()[0]\n", "\n", "\n", "random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),\n", "                                                train_env.action_spec())\n", "\n", "compute_avg_return(eval_env, random_policy, num_eval_episodes)\n", "\n", "# Please also see the metrics module for standard implementations of different\n", "# metrics." ] },
{ "cell_type": "markdown", "metadata": { "id": "NLva6g2jdWgr" }, "source": [ "## Data Collection\n", "\n", "As in the DQN tutorial, set up the replay buffer and the initial data collection with the random policy." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "wr1KSAEGG4h9" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", "    data_spec=agent.collect_data_spec,\n", "    batch_size=train_env.batch_size,\n", "    max_length=replay_buffer_capacity)\n", "\n", "def collect_step(environment, policy):\n", "  time_step = environment.current_time_step()\n", "  action_step = policy.action(time_step)\n", "  next_time_step = environment.step(action_step.action)\n", "  traj = trajectory.from_transition(time_step, action_step, next_time_step)\n", "\n", "  # Add trajectory to the replay buffer\n", "  replay_buffer.add_batch(traj)\n", "\n", "for _ in range(initial_collect_steps):\n", "  collect_step(train_env, random_policy)\n", "\n", "# This loop is so common in RL, that we provide standard implementations of\n", "# these. For more details see the drivers module.\n", "\n", "# Dataset generates trajectories with shape [BxTx...] where\n", "# T = n_step_update + 1.\n", "dataset = replay_buffer.as_dataset(\n", "    num_parallel_calls=3, sample_batch_size=batch_size,\n", "    num_steps=n_step_update + 1).prefetch(3)\n", "\n", "iterator = iter(dataset)" ] },
{ "cell_type": "markdown", "metadata": { "id": "hBc9lj9VWWtZ" }, "source": [ "## Training the agent\n", "\n", "The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we will occasionally evaluate the agent's policy to see how we are doing.\n", "\n", "The following will take about 7 minutes to run." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "0pTbJ3PeyF-u" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "try:\n", "  %%time\n", "except:\n", "  pass\n", "\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "agent.train = common.function(agent.train)\n", "\n", "# Reset the train step\n", "agent.train_step_counter.assign(0)\n", "\n", "# Evaluate the agent's policy once before training.\n", "avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", "returns = [avg_return]\n", "\n", "for _ in range(num_iterations):\n", "\n", "  # Collect a few steps using collect_policy and save to the replay buffer.\n", "  for _ in range(collect_steps_per_iteration):\n", "    collect_step(train_env, agent.collect_policy)\n", "\n", "  # Sample a batch of data from the buffer and update the agent's network.\n", "  experience, unused_info = next(iterator)\n", "  train_loss = agent.train(experience)\n", "\n", "  step = agent.train_step_counter.numpy()\n", "\n", "  if step % log_interval == 0:\n", "    print('step = {0}: loss = {1}'.format(step, train_loss.loss))\n", "\n", "  if step % eval_interval == 0:\n", "    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", "    print('step = {0}: Average Return = {1:.2f}'.format(step, avg_return))\n", "    returns.append(avg_return)" ] },
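{ "cell_type": "markdown", "metadata": {}, "source": [ "Since the whole point of C51 is to model a distribution over returns rather than a single Q-value, it can be instructive to look at what the network has learned for one state. The cell below is an optional, illustrative sketch (not part of the original tutorial); it assumes that `categorical_q_net` returns per-action logits over the `num_atoms` support points, which we convert to probabilities with a softmax and plot against the support defined by `min_q_value` and `max_q_value`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional, illustrative sketch: plot the predicted return distribution for\n", "# the greedy action in the initial state. Assumes categorical_q_net outputs\n", "# logits of shape [batch_size, num_actions, num_atoms].\n", "time_step = eval_env.reset()\n", "q_logits, _ = categorical_q_net(time_step.observation, time_step.step_type)\n", "q_probabilities = tf.nn.softmax(q_logits, axis=-1)\n", "\n", "greedy_action = int(agent.policy.action(time_step).action.numpy()[0])\n", "support = tf.linspace(float(min_q_value), float(max_q_value), num_atoms)\n", "\n", "plt.bar(support.numpy(), q_probabilities[0, greedy_action].numpy(), width=0.6)\n", "plt.xlabel('Return')\n", "plt.ylabel('Probability')\n", "plt.title('Predicted return distribution (greedy action)')" ] },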
{ "cell_type": "markdown", "metadata": { "id": "68jNcA_TiJDq" }, "source": [ "## Visualization\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "aO-LWCdbbOIC" }, "source": [ "### Plots\n", "\n", "We can plot return versus global steps to see the performance of our agent. In `CartPole-v1`, the environment gives a reward of +1 for every time step the pole stays up, and since the maximum number of steps is 500, the maximum possible return is also 500." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "NxtL1mbOYCVO" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "\n", "steps = range(0, num_iterations + 1, eval_interval)\n", "plt.plot(steps, returns)\n", "plt.ylabel('Average Return')\n", "plt.xlabel('Step')\n", "plt.ylim(top=550)" ] },
{ "cell_type": "markdown", "metadata": { "id": "M7-XpPP99Cy7" }, "source": [ "### Videos" ] },
{ "cell_type": "markdown", "metadata": { "id": "9pGfGxSH32gn" }, "source": [ "It is helpful to visualize the performance of an agent by rendering the environment at each step. Before we do that, let us first create a function to embed videos in this colab." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ULaGr8pvOKbl" }, "outputs": [], "source": [ "def embed_mp4(filename):\n", "  \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n", "  video = open(filename,'rb').read()\n", "  b64 = base64.b64encode(video)\n", "  tag = '''\n", "  <video width=\"640\" height=\"480\" controls>\n", "    <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\">\n", "  Your browser does not support the video tag.\n", "  </video>'''.format(b64.decode())\n", "\n", "  return IPython.display.HTML(tag)" ] },
{ "cell_type": "markdown", "metadata": { "id": "9c_PH-pX4Pr5" }, "source": [ "The following code visualizes the agent's policy for a few episodes:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "owOVWB158NlF" }, "outputs": [], "source": [ "num_episodes = 3\n", "video_filename = 'imageio.mp4'\n", "with imageio.get_writer(video_filename, fps=60) as video:\n", "  for _ in range(num_episodes):\n", "    time_step = eval_env.reset()\n", "    video.append_data(eval_py_env.render())\n", "    while not time_step.is_last():\n", "      action_step = agent.policy.action(time_step)\n", "      time_step = eval_env.step(action_step.action)\n", "      video.append_data(eval_py_env.render())\n", "\n", "embed_mp4(video_filename)" ] },
{ "cell_type": "markdown", "metadata": { "id": "exziB27hY8ia" }, "source": [ "C51 tends to do slightly better than DQN on CartPole-v1, but the difference between the two agents becomes more and more significant in increasingly complex environments. For example, on the full Atari 2600 benchmark, C51 demonstrates a mean score improvement of 126% over DQN after normalizing with respect to a random agent. Additional improvements can be gained by including n-step updates.\n", "\n", "For a deeper dive into the C51 algorithm, see [A Distributional Perspective on Reinforcement Learning (2017)](https://arxiv.org/pdf/1707.06887.pdf)." ] }
], "metadata": { "colab": { "collapsed_sections": [], "name": "9_c51_tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }