Drivers#


Introduction#

A common pattern in reinforcement learning is to execute a policy in an environment for a specified number of steps or episodes. This happens, for example, during data collection, evaluation, and generating videos of the agent.

While this is relatively simple to write in Python, it is much more complex to write and debug in TensorFlow because it involves tf.while loops, tf.cond and tf.control_dependencies. We therefore abstract this notion of a run loop into a class called driver, and provide well-tested implementations in both Python and TensorFlow.

Additionally, the data encountered by the driver at each step is saved in a named tuple called Trajectory and broadcast to a set of observers such as replay buffers and metrics. This data includes the observation from the environment, the action recommended by the policy, the reward obtained, the types of the current and the next step, and so on.

Setup#

If you haven't installed TF-Agents or Gym yet, run:

!pip install tf-agents
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_py_policy
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver

Python Drivers#

The PyDriver class takes a Python environment, a Python policy, and a list of observers to update at each time step. The main method is run(), which steps the environment using actions from the policy until at least one of the following termination criteria is met: the number of steps reaches max_steps or the number of episodes reaches max_episodes.

The implementation looks roughly as follows:

import numpy as np

from tf_agents.trajectories import trajectory


class PyDriver(object):

  def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):
    self._env = env
    self._policy = policy
    self._observers = observers or []
    self._max_steps = max_steps or np.inf
    self._max_episodes = max_episodes or np.inf

  def run(self, time_step, policy_state=()):
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:

      # Compute an action using the policy for the given time_step
      action_step = self._policy.action(time_step, policy_state)

      # Apply the action to the environment and get the next step
      next_time_step = self._env.step(action_step.action)

      # Package information into a trajectory
      traj = trajectory.Trajectory(
         time_step.step_type,
         time_step.observation,
         action_step.action,
         action_step.info,
         next_time_step.step_type,
         next_time_step.reward,
         next_time_step.discount)

      for observer in self._observers:
        observer(traj)

      # Update statistics to check termination
      num_episodes += np.sum(traj.is_last())
      num_steps += np.sum(~traj.is_boundary())

      time_step = next_time_step
      policy_state = action_step.state

    return time_step, policy_state

The following example runs a random policy in the CartPole environment, saving the results to a replay buffer and computing some metrics.

env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(), 
                                         action_spec=env.action_spec())
replay_buffer = []
metric = py_metrics.AverageReturnMetric()
observers = [replay_buffer.append, metric]
driver = py_driver.PyDriver(
    env, policy, observers, max_steps=20, max_episodes=1)

initial_time_step = env.reset()
final_time_step, _ = driver.run(initial_time_step)

print('Replay Buffer:')
for traj in replay_buffer:
  print(traj)

print('Average Return: ', metric.result())
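
An observer is simply a callable that takes a Trajectory, so custom observers are easy to write. The sketch below adds a hypothetical reward_observer (not part of TF-Agents) alongside the replay buffer and metric defined above:

# A hypothetical custom observer: any callable that takes a Trajectory works.
# This one simply records the reward seen at every step the driver takes.
rewards = []

def reward_observer(traj):
  rewards.append(traj.reward)

driver = py_driver.PyDriver(
    env, policy, [replay_buffer.append, metric, reward_observer],
    max_steps=20, max_episodes=1)
final_time_step, _ = driver.run(env.reset())

print('Collected rewards: ', rewards)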

TensorFlow Drivers#

There are also drivers in TensorFlow that work similarly to the Python drivers, but use TF environments, TF policies, TF observers, and so on. We currently have two TensorFlow drivers: DynamicStepDriver, which terminates after a given number of (valid) environment steps, and DynamicEpisodeDriver, which terminates after a given number of episodes. Let us look at DynamicEpisodeDriver in action; a minimal sketch of DynamicStepDriver follows after this example.

env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)

tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
                                            time_step_spec=tf_env.time_step_spec())


num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, tf_policy, observers, num_episodes=2)

# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()

print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)

print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
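
For comparison, here is a minimal sketch of the DynamicStepDriver mentioned above, which terminates after a given number of valid environment steps rather than after a given number of episodes. It reuses tf_env and tf_policy from above and assumes the dynamic_step_driver module is imported from tf_agents.drivers; the num_steps value chosen here is arbitrary.

from tf_agents.drivers import dynamic_step_driver

# Fresh metrics so the step driver's counts are not mixed with the
# episode driver's results above.
env_steps = tf_metrics.EnvironmentSteps()
num_episodes = tf_metrics.NumberOfEpisodes()

step_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env, tf_policy, [num_episodes, env_steps], num_steps=10)

# As with the episode driver, the initial run() resets the environment,
# initializes the policy, and returns the final time step and policy state.
final_time_step, policy_state = step_driver.run()

print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())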