{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "klGNgWREsvQv" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "HNtBC6Bbb1YU" }, "source": [ "# REINFORCE 代理\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "ZOUOQOrFs3zn" }, "source": [ "## 简介" ] }, { "cell_type": "markdown", "metadata": { "id": "cKOCZlhUgXVK" }, "source": [ "本例介绍如何使用 TF-Agents 库在 Cartpole 环境中训练 [REINFORCE](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) 代理,与 [DQN 教程](1_dqn_tutorial.ipynb)比较相似。\n", "\n", "![Cartpole environment](images/cartpole.png)\n", "\n", "我们会引导您完成强化学习 (RL) 流水线中关于训练、评估和数据收集的所有部分。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1u9QVVsShC9X" }, "source": [ "## 设置" ] }, { "cell_type": "markdown", "metadata": { "id": "I5PNmEzIb9t4" }, "source": [ "如果尚未安装以下依赖项,请运行以下命令:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KEHR2Ui-lo8O" }, "outputs": [], "source": [ "!sudo apt-get update\n", "!sudo apt-get install -y xvfb ffmpeg freeglut3-dev\n", "!pip install 'imageio==2.4.0'\n", "!pip install pyvirtualdisplay\n", "!pip install tf-agents[reverb]\n", "!pip install pyglet xvfbwrapper\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sMitx5qSgJk1" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import base64\n", "import imageio\n", "import IPython\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import PIL.Image\n", "import pyvirtualdisplay\n", "import reverb\n", "\n", "import tensorflow as tf\n", "\n", "from tf_agents.agents.reinforce import reinforce_agent\n", "from tf_agents.drivers import py_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.networks import actor_distribution_network\n", "from tf_agents.policies import py_tf_eager_policy\n", "from tf_agents.replay_buffers import reverb_replay_buffer\n", "from tf_agents.replay_buffers import reverb_utils\n", "from tf_agents.specs import tensor_spec\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.utils import common\n", "\n", "# Set up a virtual display for rendering OpenAI gym environments.\n", "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()" ] }, { "cell_type": "markdown", "metadata": { "id": "LmC0NDhdLIKY" }, "source": [ "## 超参数" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HC1kNrOsLSIZ" }, "outputs": [], "source": [ "env_name = \"CartPole-v0\" # @param {type:\"string\"}\n", "num_iterations = 250 # @param {type:\"integer\"}\n", "collect_episodes_per_iteration = 2 # @param {type:\"integer\"}\n", "replay_buffer_capacity = 2000 # @param {type:\"integer\"}\n", "\n", "fc_layer_params = (100,)\n", "\n", "learning_rate = 1e-3 # @param {type:\"number\"}\n", "log_interval = 25 # @param {type:\"integer\"}\n", "num_eval_episodes = 10 # @param {type:\"integer\"}\n", "eval_interval = 50 # @param {type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "VMsJC3DEgI0x" }, "source": [ "## 环境\n", "\n", "RL 环境用于描述要解决的任务或问题。在 TF-Agents 中,使用 `suites` 可以轻松创建标准环境。我们提供了不同的 `suites`,只需提供一个字符串环境名称,即可帮助您从来源加载环境,如 OpenAI Gym、Atari、DM Control 等。\n", "\n", "现在,我们试试从 OpenAI Gym 套件加载 CartPole 环境。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pYEz-S9gEv2-" }, "outputs": [], "source": [ "env = suite_gym.load(env_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "IIHYVBkuvPNw" }, "source": [ "我们可以渲染此环境以查看其形式:小车上连接一条自由摆动的长杆。目标是向右或向左移动小车,使长杆保持朝上。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"RlO7WIQHu_7D" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "env.reset()\n", "PIL.Image.fromarray(env.render())" ] }, { "cell_type": "markdown", "metadata": { "id": "B9_lskPOey18" }, "source": [ "在该环境中,`time_step = environment.step(action)` 语句用于执行 `action`。返回的 `TimeStep` 元组包含该操作在环境中的下一个观测值和奖励。环境中的 `time_step_spec()` 和 `action_spec()` 方法分别返回 `time_step` 和 `action` 的规范(类型、形状、边界)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "exDv57iHfwQV" }, "outputs": [], "source": [ "print('Observation Spec:')\n", "print(env.time_step_spec().observation)\n", "print('Action Spec:')\n", "print(env.action_spec())" ] }, { "cell_type": "markdown", "metadata": { "id": "eJCgJnx3g0yY" }, "source": [ "我们可以看到,该观测值是一个包含 4 个浮点数的数组:小车的位置和速度,长杆的角度位置和速度。由于只有两个操作(向左或向右移动),因此,`action_spec` 是一个标量,其中 0 表示“向左移动”,1 表示“向右移动”。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V2UGR5t_iZX-" }, "outputs": [], "source": [ "time_step = env.reset()\n", "print('Time step:')\n", "print(time_step)\n", "\n", "action = np.array(1, dtype=np.int32)\n", "\n", "next_time_step = env.step(action)\n", "print('Next time step:')\n", "print(next_time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "zuUqXAVmecTU" }, "source": [ "通常,我们会创建两个环境:一个用于训练,另一个用于评估。大部分环境都是使用纯 Python 语言编写的,但是使用 `TFPyEnvironment` 包装器可轻松将其转换至 TensorFlow 环境。原始环境的 API 使用 NumPy 数组,但凭借 `TFPyEnvironment`,这些数组可以与 `Tensors` 相互转换,从而更轻松地与 TensorFlow 策略和代理交互。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xp-Y4mD6eDhF" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)\n", "\n", "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "E9lW_OZYFR8A" }, "source": [ "## 代理\n", "\n", "我们用于解决 RL 问题的算法以 `Agent` 形式表示。除了 REINFORCE 代理,TF-Agents 还为各种 `Agents` 提供了标准实现,如 [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)、[DDPG](https://arxiv.org/pdf/1509.02971.pdf)、[TD3](https://arxiv.org/pdf/1802.09477.pdf)、[PPO](https://arxiv.org/abs/1707.06347) 和 [SAC](https://arxiv.org/abs/1801.01290)。\n", "\n", "要创建 REINFORCE 代理,首先需要有一个通过环境提供的观测值,学会预测操作的 `Actor Network`。\n", "\n", "使用观测值和操作的规范,我们可以轻松创建 `Actor Network`。我们也可以在网络中指定层,本例中是设置为 `ints` 元祖(表示每个隐藏层的大小)的 `fc_layer_params` 参数(请参阅上面的“超参数”部分)。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TgkdEPg_muzV" }, "outputs": [], "source": [ "actor_net = actor_distribution_network.ActorDistributionNetwork(\n", " train_env.observation_spec(),\n", " train_env.action_spec(),\n", " fc_layer_params=fc_layer_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "z62u55hSmviJ" }, "source": [ "我们还需要一个 `optimizer` 来训练刚才创建的网络,以及一个跟踪网络更新次数的 `train_step_counter` 变量。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jbY4yrjTEyc9" }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n", "\n", "train_step_counter = tf.Variable(0)\n", "\n", "tf_agent = reinforce_agent.ReinforceAgent(\n", " train_env.time_step_spec(),\n", " train_env.action_spec(),\n", " actor_network=actor_net,\n", " optimizer=optimizer,\n", " normalize_returns=True,\n", " train_step_counter=train_step_counter)\n", "tf_agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "I0KLrEPwkn5x" }, "source": [ "## 策略\n", "\n", "在 TF-Agents 中,策略是 RL 中的标准策略概念:给订 `time_step` 
{ "cell_type": "markdown", "metadata": { "id": "I0KLrEPwkn5x" }, "source": [ "## Policies\n", "\n", "In TF-Agents, policies represent the standard notion of policies in RL: given a `time_step`, produce an action or a distribution over actions. The main method is `policy_step = policy.action(time_step)`, where `policy_step` is a named tuple `PolicyStep(action, state, info)`. The `policy_step.action` is the `action` to be applied to the environment, `state` represents the state for stateful (RNN) policies, and `info` may contain auxiliary information such as log probabilities of the actions.\n", "\n", "Agents contain two policies: the main policy that is used for evaluation/deployment (agent.policy) and another policy that is used for data collection (agent.collect_policy)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "BwY7StuMkuV4" }, "outputs": [], "source": [ "eval_policy = tf_agent.policy\n", "collect_policy = tf_agent.collect_policy" ] },
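{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick, purely illustrative check (not required for training), we can ask the collect policy for an action on a `time_step` from the training environment and look at the returned `PolicyStep`. The variable names below are only for this example." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative only: query the collect policy once and inspect the result.\n", "example_time_step = train_env.reset()\n", "example_policy_step = collect_policy.action(example_time_step)\n", "# The action is a batched tensor because train_env is a batched TF environment.\n", "print(example_policy_step.action)" ] },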
{ "cell_type": "markdown", "metadata": { "id": "94rCXQtbUbXv" }, "source": [ "## Metrics and Evaluation\n", "\n", "The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode, and we usually average this over a few episodes. We can compute the average return metric as follows.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "bitzHo5_UbXy" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "def compute_avg_return(environment, policy, num_episodes=10):\n", "\n", "  total_return = 0.0\n", "  for _ in range(num_episodes):\n", "\n", "    time_step = environment.reset()\n", "    episode_return = 0.0\n", "\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = environment.step(action_step.action)\n", "      episode_return += time_step.reward\n", "    total_return += episode_return\n", "\n", "  avg_return = total_return / num_episodes\n", "  return avg_return.numpy()[0]\n", "\n", "\n", "# Please also see the metrics module for standard implementations of different\n", "# metrics." ] },
{ "cell_type": "markdown", "metadata": { "id": "NLva6g2jdWgr" }, "source": [ "## Replay Buffer\n", "\n", "In order to keep track of the data collected from the environment, we will use [Reverb](https://deepmind.com/research/open-source/Reverb), an efficient, extensible, and easy-to-use replay system by Deepmind. It stores experience data as we collect trajectories and is consumed during training.\n", "\n", "This replay buffer is constructed using specs describing the tensors that are to be stored, which can be obtained from the agent using `tf_agent.collect_data_spec`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "vX2zGUWJGWAl" }, "outputs": [], "source": [ "table_name = 'uniform_table'\n", "replay_buffer_signature = tensor_spec.from_spec(\n", "    tf_agent.collect_data_spec)\n", "replay_buffer_signature = tensor_spec.add_outer_dim(\n", "    replay_buffer_signature)\n", "table = reverb.Table(\n", "    table_name,\n", "    max_size=replay_buffer_capacity,\n", "    sampler=reverb.selectors.Uniform(),\n", "    remover=reverb.selectors.Fifo(),\n", "    rate_limiter=reverb.rate_limiters.MinSize(1),\n", "    signature=replay_buffer_signature)\n", "\n", "reverb_server = reverb.Server([table])\n", "\n", "replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(\n", "    tf_agent.collect_data_spec,\n", "    table_name=table_name,\n", "    sequence_length=None,\n", "    local_server=reverb_server)\n", "\n", "rb_observer = reverb_utils.ReverbAddEpisodeObserver(\n", "    replay_buffer.py_client,\n", "    table_name,\n", "    replay_buffer_capacity\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "ZGNTDJpZs4NN" }, "source": [ "For most agents, `collect_data_spec` is a `Trajectory` named tuple containing the observations, actions, rewards, etc." ] },
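{ "cell_type": "markdown", "metadata": {}, "source": [ "To see what this looks like for our agent, you can print the spec directly (purely illustrative):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative only: inspect the Trajectory spec the replay buffer was built from.\n", "print(tf_agent.collect_data_spec)" ] },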
{ "cell_type": "markdown", "metadata": { "id": "rVD5nQ9ZGo8_" }, "source": [ "## Data Collection\n", "\n", "As REINFORCE learns from whole episodes, we define a function to collect episodes using the given data collection policy and save the data (observations, actions, rewards, etc.) as trajectories in the replay buffer. Here we are using 'PyDriver' to run the experience collecting loop. You can learn more about TF Agents drivers in our [drivers tutorial](https://tensorflow.google.cn/agents/tutorials/4_drivers_tutorial)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "wr1KSAEGG4h9" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "\n", "def collect_episode(environment, policy, num_episodes):\n", "\n", "  driver = py_driver.PyDriver(\n", "    environment,\n", "    py_tf_eager_policy.PyTFEagerPolicy(\n", "      policy, use_tf_function=True),\n", "    [rb_observer],\n", "    max_episodes=num_episodes)\n", "  initial_time_step = environment.reset()\n", "  driver.run(initial_time_step)" ] },
{ "cell_type": "markdown", "metadata": { "id": "hBc9lj9VWWtZ" }, "source": [ "## Training the agent\n", "\n", "The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we will occasionally evaluate the agent's policy to see how we are doing.\n", "\n", "The following will take ~3 minutes to run." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "0pTbJ3PeyF-u" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "try:\n", "  %%time\n", "except:\n", "  pass\n", "\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "tf_agent.train = common.function(tf_agent.train)\n", "\n", "# Reset the train step\n", "tf_agent.train_step_counter.assign(0)\n", "\n", "# Evaluate the agent's policy once before training.\n", "avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)\n", "returns = [avg_return]\n", "\n", "for _ in range(num_iterations):\n", "\n", "  # Collect a few episodes using collect_policy and save to the replay buffer.\n", "  collect_episode(\n", "      train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)\n", "\n", "  # Use data from the buffer and update the agent's network.\n", "  iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))\n", "  trajectories, _ = next(iterator)\n", "  train_loss = tf_agent.train(experience=trajectories)\n", "\n", "  replay_buffer.clear()\n", "\n", "  step = tf_agent.train_step_counter.numpy()\n", "\n", "  if step % log_interval == 0:\n", "    print('step = {0}: loss = {1}'.format(step, train_loss.loss))\n", "\n", "  if step % eval_interval == 0:\n", "    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)\n", "    print('step = {0}: Average Return = {1}'.format(step, avg_return))\n", "    returns.append(avg_return)" ] },
{ "cell_type": "markdown", "metadata": { "id": "68jNcA_TiJDq" }, "source": [ "## Visualization\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "aO-LWCdbbOIC" }, "source": [ "### Plots\n", "\n", "We can plot return vs global steps to see the performance of our agent. In `CartPole-v0`, the environment gives a reward of +1 for every time step the pole stays up, and since the maximum number of steps is 200, the maximum possible return is also 200." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "NxtL1mbOYCVO" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "\n", "steps = range(0, num_iterations + 1, eval_interval)\n", "plt.plot(steps, returns)\n", "plt.ylabel('Average Return')\n", "plt.xlabel('Step')\n", "plt.ylim(top=250)" ] },
{ "cell_type": "markdown", "metadata": { "id": "M7-XpPP99Cy7" }, "source": [ "### Videos" ] },
{ "cell_type": "markdown", "metadata": { "id": "9pGfGxSH32gn" }, "source": [ "It is helpful to visualize the performance of an agent by rendering the environment at each step. Before we do that, let us first create a function to embed videos in this colab." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ULaGr8pvOKbl" }, "outputs": [], "source": [ "def embed_mp4(filename):\n", "  \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n", "  video = open(filename,'rb').read()\n", "  b64 = base64.b64encode(video)\n", "  tag = '''\n", "  <video width=\"640\" height=\"480\" controls>\n", "    <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\">\n", "  Your browser does not support the video tag.\n", "  </video>'''.format(b64.decode())\n", "\n", "  return IPython.display.HTML(tag)" ] },
{ "cell_type": "markdown", "metadata": { "id": "9c_PH-pX4Pr5" }, "source": [ "The following code visualizes the agent's policy for a few episodes:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "owOVWB158NlF" }, "outputs": [], "source": [ "num_episodes = 3\n", "video_filename = 'imageio.mp4'\n", "with imageio.get_writer(video_filename, fps=60) as video:\n", "  for _ in range(num_episodes):\n", "    time_step = eval_env.reset()\n", "    video.append_data(eval_py_env.render())\n", "    while not time_step.is_last():\n", "      action_step = tf_agent.policy.action(time_step)\n", "      time_step = eval_env.step(action_step.action)\n", "      video.append_data(eval_py_env.render())\n", "\n", "embed_mp4(video_filename)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "6_reinforce_tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }