Checkpointer and PolicySaver#

Introduction#

tf_agents.utils.common.Checkpointer is a utility to save/load the training state, policy state, and replay_buffer state to/from local storage.

tf_agents.policies.policy_saver.PolicySaver is a tool to save/load only the policy, and is lighter than Checkpointer. You can use PolicySaver to deploy the model as well, without any knowledge of the code that created the policy.

In this tutorial, we will use DQN to train a model, then use Checkpointer and PolicySaver to show how we can store and load the states and model in an interactive way. Note that we will use TF2.0's new saved_model tooling and format for PolicySaver.

Setup#

If you haven't installed the following dependencies, run:

#@test {"skip": true}
!sudo apt-get update
!sudo apt-get install -y xvfb ffmpeg python-opengl
!pip install pyglet
!pip install 'imageio==2.4.0'
!pip install 'xvfbwrapper==0.2.9'
!pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
#@test {"skip": true}
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

DQN agent#

We will set up a DQN agent, just like in the previous colabs. The details are hidden by default as they are not a core part of this colab, but you can click on 'SHOW CODE' to see them.

Hyperparameters#

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

Environment#

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agent#

#@title
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=global_step)
agent.initialize()

Data collection#

#@title
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=collect_steps_per_iteration)

# Initial data collection
collect_driver.run()

# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

iterator = iter(dataset)

Train the agent#

#@title
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

def train_one_iteration():

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_driver.run()

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience)

  iteration = agent.train_step_counter.numpy()
  print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
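
If you want to train for longer before checkpointing (for example, the 10-15 iterations suggested later on), you can simply call this helper in a loop. This is an optional sketch, not part of the original notebook flow:

# Optional: run several training iterations back to back.
for _ in range(10):
  train_one_iteration()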

Video generation#

#@title
def embed_gif(gif_buffer):
  """Embeds a gif file in the notebook."""
  tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
  return IPython.display.HTML(tag)

def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
  num_episodes = 3
  frames = []
  for _ in range(num_episodes):
    time_step = eval_tf_env.reset()
    frames.append(eval_py_env.render())
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = eval_tf_env.step(action_step.action)
      frames.append(eval_py_env.render())
  gif_file = io.BytesIO()
  imageio.mimsave(gif_file, frames, format='gif', fps=60)
  IPython.display.display(embed_gif(gif_file.getvalue()))

Generate a video#

Check the performance of the policy by generating a video.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)

Setup Checkpointer and PolicySaver#

Now we are ready to use Checkpointer and PolicySaver.

Checkpointer#

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

PolicySaver#

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)

Train one iteration#

#@test {"skip": true}
print('Training one iteration....')
train_one_iteration()

Save to checkpoint#

train_checkpointer.save(global_step)

Restore checkpoint#

For this to work, the whole set of objects should be recreated the same way as when the checkpoint was created.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
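
As a quick sanity check (a minimal sketch, not part of the original notebook), you can confirm that the state came from the checkpoint rather than being reinitialized:

# The restored step counter should match the value that was saved, and the
# replay buffer should still contain the previously collected frames.
print('global_step after restore:', global_step.numpy())
print('replay buffer frames:', replay_buffer.num_frames().numpy())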

Also save the policy and export it to a location.

tf_policy_saver.save(policy_dir)

The policy can be loaded without having any knowledge of what agent or network was used to create it. This makes deployment of the policy much easier.

Load the saved policy and check how it performs.

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
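
The loaded SavedModel also exposes the policy's action and get_initial_state methods directly, so you can drive the environment yourself without the video helper. A minimal sketch using the objects defined above:

# Step the loaded policy manually for one episode.
time_step = eval_env.reset()
policy_state = saved_policy.get_initial_state(batch_size=eval_env.batch_size)
while not time_step.is_last():
  policy_step = saved_policy.action(time_step, policy_state)
  policy_state = policy_step.state
  time_step = eval_env.step(policy_step.action)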

Export and import#

The rest of this colab will help you export/import the checkpointer and policy directories so that you can continue training at a later point and deploy the model without having to train again.

Now you can go back to 'Train one iteration' to train a few more times so that you can see the difference later on. Once you start to see slightly better results, continue below.

#@title Create zip file and upload zip file (double-click to see the code)
def create_zip_file(dirname, base_filename):
  return shutil.make_archive(base_filename, 'zip', dirname)

def upload_and_unzip_file_to(dirname):
  if files is None:
    return
  uploaded = files.upload()
  for fn in uploaded.keys():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(uploaded[fn])))
    shutil.rmtree(dirname)
    zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
    zip_files.extractall(dirname)
    zip_files.close()

Create a zip file from the checkpoint directory.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

Download the zip file.

#@test {"skip": true}
if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

After training for a while (10-15 iterations), download the checkpoint zip file, go to 'Runtime > Restart and run all' to reset the training, and come back to this cell. You can then upload the downloaded zip file and continue training.

#@test {"skip": true}
upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Once you have uploaded the checkpoint directory, go back to 'Train one iteration' to continue training, or go back to 'Generate a video' to check the performance of the loaded policy.

Alternatively, you can save the policy (model) and restore it. Unlike the checkpointer, you cannot continue training, but you can still deploy the model. Note that the downloaded file is much smaller than that of the checkpointer.

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
#@test {"skip": true}
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Upload the downloaded policy directory (exported_policy.zip) and check how the saved policy performs.

#@test {"skip": true}
upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

SavedModelPyTFEagerPolicy#

If you don't want to use a TF policy, you can also use the saved_model directly with a Python environment through py_tf_eager_policy.SavedModelPyTFEagerPolicy.

Note that this only works when eager mode is enabled.

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
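
As a quick check (a small sketch; tf.executing_eagerly() is a standard TF2 call), you can verify that eager mode is on and take a single action with an unbatched TimeStep from the Python environment:

# SavedModelPyTFEagerPolicy consumes plain numpy TimeSteps from the py env.
assert tf.executing_eagerly()
time_step = eval_py_env.reset()
action_step = eager_py_policy.action(time_step)
print('action:', action_step.action)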

Convert the policy to TFLite#

See the TensorFlow Lite converter for more details.

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)

Run inference with the TFLite model#

See TensorFlow Lite inference for more details.

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
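
The signature runner returns a dict of numpy outputs. The exact output names depend on the policy's output spec, so the sketch below (an addition, not part of the original notebook) prints them rather than hard-coding a key; for this DQN policy the chosen action is expected among them.

# Inspect the outputs returned by the TFLite policy signature.
result = policy_runner(**{
    '0/discount': tf.constant(0.0),
    '0/observation': tf.zeros([1, 4]),
    '0/reward': tf.constant(0.0),
    '0/step_type': tf.constant(0)})
for name, value in result.items():
  print(name, value)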