{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "klGNgWREsvQv" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "HqslkUeyEJFg" }, "source": [ "# TF-Agents 中的多臂老虎机教程" ] }, { "cell_type": "markdown", "metadata": { "id": "MimUC9NrYFaS" }, "source": [ "### 开始\n", "\n", "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 运行 在 Github 上查看源代码 下载笔记本
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1u9QVVsShC9X" }, "source": [ "### 安装" ] }, { "cell_type": "markdown", "metadata": { "id": "kNrNXKI7bINP" }, "source": [ "如果尚未安装以下依赖项,请运行:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KEHR2Ui-lo8O" }, "outputs": [], "source": [ "!pip install tf-agents" ] }, { "cell_type": "markdown", "metadata": { "id": "O7gLdUS6b2EG" }, "source": [ "### 导入" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3oCS94Z83Jo2" }, "outputs": [], "source": [ "import abc\n", "import numpy as np\n", "import tensorflow as tf\n", "\n", "from tf_agents.agents import tf_agent\n", "from tf_agents.drivers import driver\n", "from tf_agents.environments import py_environment\n", "from tf_agents.environments import tf_environment\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.policies import tf_policy\n", "from tf_agents.specs import array_spec\n", "from tf_agents.specs import tensor_spec\n", "from tf_agents.trajectories import time_step as ts\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.trajectories import policy_step\n", "\n", "nest = tf.nest" ] }, { "cell_type": "markdown", "metadata": { "id": "CcIob6rYqien" }, "source": [ "# 简介\n" ] }, { "cell_type": "markdown", "metadata": { "id": "JdnTJrzaeft3" }, "source": [ "多臂老虎机问题 (MAB) 是强化学习的一项特例:代理会通过在观察到环境的某些状态后采取一些动作来收集环境中的奖励。一般的强化学习与 MAB 的主要区别在于,在 MAB 中,我们假定代理采取的动作不会影响环境的下一个状态。因此,代理不会对状态转换进行建模,将奖励归因于过去的动作,或者以获得高奖励状态为目的进行“提前计划”。\n", "\n", "与其他强化学习领域一样,MAB *代理*的目标也是找出一种*策略*来收集尽可能多奖励。然而,总是试图利用预示最高奖励的动作是不对的,因为如果我们所做的探索不够充分,就有可能会错过更好的动作。这是 MAB 中要解决的主要问题,通常称为*探索-利用困境*。\n", "\n", "MAB 的老虎机环境、策略和代理可以在 [tf_agents/bandits](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits) 的子目录中找到。" ] }, { "cell_type": "markdown", "metadata": { "id": "iPzsBCTperx3" }, "source": [ "# 环境" ] }, { "cell_type": "markdown", "metadata": { "id": "1LOXW8i320Cp" }, "source": [ "在 TF-Agents 中,环境类的作用是提供有关当前状态的信息(称为**观测值**或**上下文**)、接收动作作为输入、执行状态转换以及输出奖励。此类还负责在片段结束时进行重置,以便可以开始新的片段。这是通过在状态被标记为片段的“最后”状态时调用 `reset` 函数来实现的。\n", "\n", "有关详情,请参阅 [TF-Agents 环境教程](https://github.com/tensorflow/agents/blob/master/docs/tutorials/2_environments_tutorial.ipynb)。\n", "\n", "如上所述,MAB 与一般强化学习的不同之处在于动作不会影响下一次观测。另一个区别是,老虎机中没有“片段”:每个时间步都始于新的观测,与之前的时间步无关。\n", "\n", "为了确保观测独立且不涉及强化学习片段的概念,我们引入了 `PyEnvironment` 和 `TFEnvironment` 的子类:[BanditPyEnvironment](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/bandit_py_environment.py) 和 [BanditTFEnvironment](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/bandit_tf_environment.py)。这些类会公开两个私有成员函数,这些函数仍然由用户实现:\n", "\n", "```python\n", "@abc.abstractmethod\n", "def _observe(self):\n", "```\n", "\n", "和\n", "\n", "```python\n", "@abc.abstractmethod\n", "def _apply_action(self, action):\n", "```\n", "\n", "`_observe` 函数会返回一个观测值。然后,策略会根据此观测值来选择一个动作。`_apply_action` 会接收该动作作为输入,并返回相应的奖励。这些私有成员函数分别由 `reset` 和 `step` 函数调用。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TTaG2ZapQvHX" }, "outputs": [], "source": [ "class BanditPyEnvironment(py_environment.PyEnvironment):\n", "\n", " def __init__(self, observation_spec, action_spec):\n", " self._observation_spec = observation_spec\n", " self._action_spec = action_spec\n", " super(BanditPyEnvironment, self).__init__()\n", "\n", " # Helper functions.\n", " def action_spec(self):\n", " return self._action_spec\n", "\n", " def observation_spec(self):\n", " return self._observation_spec\n", "\n", " def _empty_observation(self):\n", " return tf.nest.map_structure(lambda x: np.zeros(x.shape, x.dtype),\n", " self.observation_spec())\n", "\n", " # These two functions below should not be overridden by subclasses.\n", " def _reset(self):\n", " \"\"\"Returns a time step containing an observation.\"\"\"\n", " return ts.restart(self._observe(), batch_size=self.batch_size)\n", "\n", " def _step(self, action):\n", " \"\"\"Returns a time step containing the reward for the action taken.\"\"\"\n", " reward = self._apply_action(action)\n", " return ts.termination(self._observe(), reward)\n", "\n", " # These two functions below are to be implemented in subclasses.\n", " @abc.abstractmethod\n", " def _observe(self):\n", " \"\"\"Returns an observation.\"\"\"\n", "\n", " @abc.abstractmethod\n", " def _apply_action(self, action):\n", " \"\"\"Applies `action` to the Environment and returns the corresponding reward.\n", " \"\"\"" ] }, { "cell_type": "markdown", "metadata": { "id": "ZVtLk28xVo0j" }, "source": [ "上述临时抽象类会实现 `PyEnvironment` 的 `_reset` 和 `_step` 函数,并公开 `_observe` 和 `_apply_action` 抽象函数以由子类实现。" ] }, { "cell_type": "markdown", "metadata": { "id": "xQbI-6PdtSJn" }, "source": [ "## 环境类的简单示例" ] }, { "cell_type": "markdown", "metadata": { "id": "8qspwAx0tS6l" }, "source": [ "以下类提供了一个非常简单的环境,观测值是一个介于 -2 和 2 之间的随机整数,有 3 种可能的动作(0、1、2),奖励为动作和观测值的乘积。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YV6DhsSi227-" }, "outputs": [], "source": [ "class SimplePyEnvironment(BanditPyEnvironment):\n", "\n", " def __init__(self):\n", " action_spec = array_spec.BoundedArraySpec(\n", " shape=(), dtype=np.int32, minimum=0, maximum=2, name='action')\n", " observation_spec = array_spec.BoundedArraySpec(\n", " shape=(1,), dtype=np.int32, minimum=-2, maximum=2, name='observation')\n", " super(SimplePyEnvironment, self).__init__(observation_spec, action_spec)\n", "\n", " def _observe(self):\n", " self._observation = np.random.randint(-2, 3, (1,), dtype='int32')\n", " return self._observation\n", "\n", " def _apply_action(self, action):\n", " return action * self._observation" ] }, { "cell_type": "markdown", "metadata": { "id": "ipEQgYDIf55t" }, "source": [ "现在,我们可以使用此环境来获得观测值,并为我们的动作获得奖励。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Eo_uwSz2gAKX" }, "outputs": [], "source": [ "environment = SimplePyEnvironment()\n", "observation = environment.reset().observation\n", "print(\"observation: %d\" % observation)\n", "\n", "action = 2 #@param\n", "\n", "print(\"action: %d\" % action)\n", "reward = environment.step(action).reward\n", "print(\"reward: %f\" % reward)" ] }, { "cell_type": "markdown", "metadata": { "id": "GuVYHI8aDgCx" }, "source": [ "## TF 环境" ] }, { "cell_type": "markdown", "metadata": { "id": "dP46VwLTDnOR" }, "source": [ "可以通过子类化 `BanditTFEnvironment` 来定义老虎机环境,或者也可以与强化学习环境类似,定义 `BanditPyEnvironment` 并使用 `TFPyEnvironment` 对其进行包装。为简单起见,我们在本教程中使用后一选项。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IPPpwSi3EtWz" }, "outputs": [], "source": [ "tf_environment = tf_py_environment.TFPyEnvironment(environment)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S9fhxF9GUaT" }, "source": [ "# 策略" ] }, { "cell_type": "markdown", "metadata": { "id": "NbTt5jnuGlYj" }, "source": [ "老虎机问题中的*策略*与强化学习问题中的策略工作方式相同:给定一个观测值作为输入,它会提供一个动作(或动作分布)。\n", "\n", "有关详情,请参阅 [TF-Agents 策略教程](https://github.com/tensorflow/agents/blob/master/docs/tutorials/3_policies_tutorial.ipynb)。\n", "\n", "与环境一样,可以通过两种方式构建策略:可以创建 `PyPolicy` 并使用 `TFPyPolicy` 包装,或者也可以直接创建 `TFPolicy`。在此,我们选择使用直接方式。\n", "\n", "由于本例非常简单,我们可以手动定义最优策略。动作仅取决于观测值的正负,0 表示负数,2 表示正数。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VpMZlplNK5ND" }, "outputs": [], "source": [ "class SignPolicy(tf_policy.TFPolicy):\n", " def __init__(self):\n", " observation_spec = tensor_spec.BoundedTensorSpec(\n", " shape=(1,), dtype=tf.int32, minimum=-2, maximum=2)\n", " time_step_spec = ts.time_step_spec(observation_spec)\n", "\n", " action_spec = tensor_spec.BoundedTensorSpec(\n", " shape=(), dtype=tf.int32, minimum=0, maximum=2)\n", "\n", " super(SignPolicy, self).__init__(time_step_spec=time_step_spec,\n", " action_spec=action_spec)\n", " def _distribution(self, time_step):\n", " pass\n", "\n", " def _variables(self):\n", " return ()\n", "\n", " def _action(self, time_step, policy_state, seed):\n", " observation_sign = tf.cast(tf.sign(time_step.observation[0]), dtype=tf.int32)\n", " action = observation_sign + 1\n", " return policy_step.PolicyStep(action, policy_state)" ] }, { "cell_type": "markdown", "metadata": { "id": "GAM7hb4LVQ70" }, "source": [ "现在,我们可以从环境请求观测值,调用策略以选择动作,然后环境将输出奖励:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z0_5vMDCVZWT" }, "outputs": [], "source": [ "sign_policy = SignPolicy()\n", "\n", "current_time_step = tf_environment.reset()\n", "print('Observation:')\n", "print (current_time_step.observation)\n", "action = sign_policy.action(current_time_step).action\n", "print('Action:')\n", "print (action)\n", "reward = tf_environment.step(action).reward\n", "print('Reward:')\n", "print(reward)" ] }, { "cell_type": "markdown", "metadata": { "id": "AExuQ7u0-PF6" }, "source": [ "老虎机环境的实现方式可确保我们每完成一步,不仅会因所采取的动作而获得奖励,还会获得下一个观测值。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CiB935of-wVv" }, "outputs": [], "source": [ "step = tf_environment.reset()\n", "action = 1\n", "next_step = tf_environment.step(action)\n", "reward = next_step.reward\n", "next_observation = next_step.observation\n", "print(\"Reward: \")\n", "print(reward)\n", "print(\"Next observation:\")\n", "print(next_observation)" ] }, { "cell_type": "markdown", "metadata": { "id": "zFnqVHfeANZP" }, "source": [ "# 代理" ] }, { "cell_type": "markdown", "metadata": { "id": "1pDK_faXAPSA" }, "source": [ "现在,我们已经有了老虎机环境和老虎机策略,是时候定义老虎机代理了,它会负责基于训练样本来改变策略。\n", "\n", "老虎机代理的 API 与强化学习代理的 API 没有区别:代理只需实现 `_initialize` 和 `_train` 方法,并定义 `policy` 和 `collect_policy`。" ] }, { "cell_type": "markdown", "metadata": { "id": "TVCb-vPJOayG" }, "source": [ "## 更加复杂的环境" ] }, { "cell_type": "markdown", "metadata": { "id": "9Ksv7i7zPGSa" }, "source": [ "编写老虎机代理之前,我们需要准备一个稍加复杂的环境。为了增添一点趣味性,下一个环境要么总是给出 `reward = observation * action`,要么总是给出 `reward = -observation * action`。这将在环境初始化时决定。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fte7-Mr8O0QR" }, "outputs": [], "source": [ "class TwoWayPyEnvironment(BanditPyEnvironment):\n", "\n", " def __init__(self):\n", " action_spec = array_spec.BoundedArraySpec(\n", " shape=(), dtype=np.int32, minimum=0, maximum=2, name='action')\n", " observation_spec = array_spec.BoundedArraySpec(\n", " shape=(1,), dtype=np.int32, minimum=-2, maximum=2, name='observation')\n", "\n", " # Flipping the sign with probability 1/2.\n", " self._reward_sign = 2 * np.random.randint(2) - 1\n", " print(\"reward sign:\")\n", " print(self._reward_sign)\n", "\n", " super(TwoWayPyEnvironment, self).__init__(observation_spec, action_spec)\n", "\n", " def _observe(self):\n", " self._observation = np.random.randint(-2, 3, (1,), dtype='int32')\n", " return self._observation\n", "\n", " def _apply_action(self, action):\n", " return self._reward_sign * action * self._observation[0]\n", "\n", "two_way_tf_environment = tf_py_environment.TFPyEnvironment(TwoWayPyEnvironment())" ] }, { "cell_type": "markdown", "metadata": { "id": "7Zb4jWpQUA75" }, "source": [ "## 更加复杂的策略" ] }, { "cell_type": "markdown", "metadata": { "id": "Dz2rEEA1USJu" }, "source": [ "更加复杂的环境需要更加复杂的策略。我们需要一种能够检测底层环境行为的策略。该策略需要处理以下三种情况:\n", "\n", "1. 代理未检测,尚不知道哪个版本的环境正在运行。\n", "2. 代理检测到原始版本的环境正在运行。\n", "3. 代理检测到翻转版本的环境正在运行。\n", "\n", "我们定义了一个名为 `_situation` 的 `tf_variable` 来将此信息编码为 `[0, 2]` 区间内的值,然后使策略相应运行。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Srm2jsGHVM8N" }, "outputs": [], "source": [ "class TwoWaySignPolicy(tf_policy.TFPolicy):\n", " def __init__(self, situation):\n", " observation_spec = tensor_spec.BoundedTensorSpec(\n", " shape=(1,), dtype=tf.int32, minimum=-2, maximum=2)\n", " action_spec = tensor_spec.BoundedTensorSpec(\n", " shape=(), dtype=tf.int32, minimum=0, maximum=2)\n", " time_step_spec = ts.time_step_spec(observation_spec)\n", " self._situation = situation\n", " super(TwoWaySignPolicy, self).__init__(time_step_spec=time_step_spec,\n", " action_spec=action_spec)\n", " def _distribution(self, time_step):\n", " pass\n", "\n", " def _variables(self):\n", " return [self._situation]\n", "\n", " def _action(self, time_step, policy_state, seed):\n", " sign = tf.cast(tf.sign(time_step.observation[0, 0]), dtype=tf.int32)\n", " def case_unknown_fn():\n", " # Choose 1 so that we get information on the sign.\n", " return tf.constant(1, shape=(1,))\n", "\n", " # Choose 0 or 2, depending on the situation and the sign of the observation.\n", " def case_normal_fn():\n", " return tf.constant(sign + 1, shape=(1,))\n", " def case_flipped_fn():\n", " return tf.constant(1 - sign, shape=(1,))\n", "\n", " cases = [(tf.equal(self._situation, 0), case_unknown_fn),\n", " (tf.equal(self._situation, 1), case_normal_fn),\n", " (tf.equal(self._situation, 2), case_flipped_fn)]\n", " action = tf.case(cases, exclusive=True)\n", " return policy_step.PolicyStep(action, policy_state)" ] }, { "cell_type": "markdown", "metadata": { "id": "r6PPdRQQbE3Q" }, "source": [ "## 代理" ] }, { "cell_type": "markdown", "metadata": { "id": "pO8HpL0tUP32" }, "source": [ "现在,可以定义检测环境正负并适当设置策略的代理了。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7f-0W0cMbS_z" }, "outputs": [], "source": [ "class SignAgent(tf_agent.TFAgent):\n", " def __init__(self):\n", " self._situation = tf.Variable(0, dtype=tf.int32)\n", " policy = TwoWaySignPolicy(self._situation)\n", " time_step_spec = policy.time_step_spec\n", " action_spec = policy.action_spec\n", " super(SignAgent, self).__init__(time_step_spec=time_step_spec,\n", " action_spec=action_spec,\n", " policy=policy,\n", " collect_policy=policy,\n", " train_sequence_length=None)\n", "\n", " def _initialize(self):\n", " return tf.compat.v1.variables_initializer(self.variables)\n", "\n", " def _train(self, experience, weights=None):\n", " observation = experience.observation\n", " action = experience.action\n", " reward = experience.reward\n", "\n", " # We only need to change the value of the situation variable if it is\n", " # unknown (0) right now, and we can infer the situation only if the\n", " # observation is not 0.\n", " needs_action = tf.logical_and(tf.equal(self._situation, 0),\n", " tf.not_equal(reward, 0))\n", "\n", "\n", " def new_situation_fn():\n", " \"\"\"This returns either 1 or 2, depending on the signs.\"\"\"\n", " return (3 - tf.sign(tf.cast(observation[0, 0, 0], dtype=tf.int32) *\n", " tf.cast(action[0, 0], dtype=tf.int32) *\n", " tf.cast(reward[0, 0], dtype=tf.int32))) / 2\n", "\n", " new_situation = tf.cond(needs_action,\n", " new_situation_fn,\n", " lambda: self._situation)\n", " new_situation = tf.cast(new_situation, tf.int32)\n", " tf.compat.v1.assign(self._situation, new_situation)\n", " return tf_agent.LossInfo((), ())\n", "\n", "sign_agent = SignAgent()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "oyclF0ZZpW-f" }, "source": [ "在上面的代码中,代理定义了策略,变量 `situation` 由代理和策略共享。\n", "\n", "另外,`_train` 函数的 `experience` 参数是一条轨迹:" ] }, { "cell_type": "markdown", "metadata": { "id": "3NlF228LGoiR" }, "source": [ "# 轨迹" ] }, { "cell_type": "markdown", "metadata": { "id": "2GbBDi1iGsnN" }, "source": [ "在 TF-Agents 中,`trajectories` 是包含来自先前步骤的样本的命名元组。然后代理会使用这些样本来训练和更新策略。在强化学习中,轨迹必须包含有关当前状态、下一个状态以及当前片段是否结束的信息。我们在老虎机问题中不需要这些信息,因此我们设置了一个辅助函数来创建轨迹:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gdSG1nv-HUJq" }, "outputs": [], "source": [ "# We need to add another dimension here because the agent expects the\n", "# trajectory of shape [batch_size, time, ...], but in this tutorial we assume\n", "# that both batch size and time are 1. Hence all the expand_dims.\n", "\n", "def trajectory_for_bandit(initial_step, action_step, final_step):\n", " return trajectory.Trajectory(observation=tf.expand_dims(initial_step.observation, 0),\n", " action=tf.expand_dims(action_step.action, 0),\n", " policy_info=action_step.info,\n", " reward=tf.expand_dims(final_step.reward, 0),\n", " discount=tf.expand_dims(final_step.discount, 0),\n", " step_type=tf.expand_dims(initial_step.step_type, 0),\n", " next_step_type=tf.expand_dims(final_step.step_type, 0))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "zFEJ8kbI_e6Q" }, "source": [ "# 训练代理" ] }, { "cell_type": "markdown", "metadata": { "id": "0Gh-41og_hDB" }, "source": [ "现在,各个部分均已准备就绪,可以训练我们的老虎机代理了。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LPx43dZgoyKg" }, "outputs": [], "source": [ "step = two_way_tf_environment.reset()\n", "for _ in range(10):\n", " action_step = sign_agent.collect_policy.action(step)\n", " next_step = two_way_tf_environment.step(action_step.action)\n", " experience = trajectory_for_bandit(step, action_step, next_step)\n", " print(experience)\n", " sign_agent.train(experience)\n", " step = next_step\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4iVSNiYdy4U4" }, "source": [ "从输出可以看出,在第二步之后(除非在第一步中观测值为 0),策略将以正确的方式选择动作,因此收集的奖励始终为非负值。" ] }, { "cell_type": "markdown", "metadata": { "id": "RCKyKEjOlOPE" }, "source": [ "# 真实上下文老虎机示例" ] }, { "cell_type": "markdown", "metadata": { "id": "ecnQwUpmllar" }, "source": [ "在本教程的剩余部分中,我们将使用 TF-Agents Bandits 库的预实现[环境](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/)和[代理](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/agents/)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oEnXUwd-nZKl" }, "outputs": [], "source": [ "# Imports for example.\n", "from tf_agents.bandits.agents import lin_ucb_agent\n", "from tf_agents.bandits.environments import stationary_stochastic_py_environment as sspe\n", "from tf_agents.bandits.metrics import tf_metrics\n", "from tf_agents.drivers import dynamic_step_driver\n", "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "37oy70dUmmie" }, "source": [ "## 采用线性收益函数的平稳随机环境" ] }, { "cell_type": "markdown", "metadata": { "id": "euPPd8x1m7iG" }, "source": [ "此示例中使用的环境为 [StationaryStochasticPyEnvironment](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/stationary_stochastic_py_environment.py)。此环境会将(通常含噪声)函数作为参数来提供观测值(上下文),并且会针对每个老虎机臂采用(也含噪声)函数来基于给定的观测值计算奖励。在我们的示例中,我们从 d 维立方体中均匀地采样上下文,奖励函数为上下文的线性函数,加上一些高斯噪声。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gVa0hmQrpe6w" }, "outputs": [], "source": [ "batch_size = 2 # @param\n", "arm0_param = [-3, 0, 1, -2] # @param\n", "arm1_param = [1, -2, 3, 0] # @param\n", "arm2_param = [0, 0, 1, 1] # @param\n", "def context_sampling_fn(batch_size):\n", " \"\"\"Contexts from [-10, 10]^4.\"\"\"\n", " def _context_sampling_fn():\n", " return np.random.randint(-10, 10, [batch_size, 4]).astype(np.float32)\n", " return _context_sampling_fn\n", "\n", "class LinearNormalReward(object):\n", " \"\"\"A class that acts as linear reward function when called.\"\"\"\n", " def __init__(self, theta, sigma):\n", " self.theta = theta\n", " self.sigma = sigma\n", " def __call__(self, x):\n", " mu = np.dot(x, self.theta)\n", " return np.random.normal(mu, self.sigma)\n", "\n", "arm0_reward_fn = LinearNormalReward(arm0_param, 1)\n", "arm1_reward_fn = LinearNormalReward(arm1_param, 1)\n", "arm2_reward_fn = LinearNormalReward(arm2_param, 1)\n", "\n", "environment = tf_py_environment.TFPyEnvironment(\n", " sspe.StationaryStochasticPyEnvironment(\n", " context_sampling_fn(batch_size),\n", " [arm0_reward_fn, arm1_reward_fn, arm2_reward_fn],\n", " batch_size=batch_size))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "haID-SPgsLyY" }, "source": [ "## LinUCB 代理" ] }, { "cell_type": "markdown", "metadata": { "id": "298-1Q0bsQmR" }, "source": [ "下面的代理实现了 [LinUCB](http://rob.schapire.net/papers/www10.pdf) 算法。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p4XmGgIusj-K" }, "outputs": [], "source": [ "observation_spec = tensor_spec.TensorSpec([4], tf.float32)\n", "time_step_spec = ts.time_step_spec(observation_spec)\n", "action_spec = tensor_spec.BoundedTensorSpec(\n", " dtype=tf.int32, shape=(), minimum=0, maximum=2)\n", "\n", "agent = lin_ucb_agent.LinearUCBAgent(time_step_spec=time_step_spec,\n", " action_spec=action_spec)" ] }, { "cell_type": "markdown", "metadata": { "id": "Eua_aC7Rt78G" }, "source": [ "## 后悔值指标" ] }, { "cell_type": "markdown", "metadata": { "id": "FBJDiJvEt-xC" }, "source": [ "老虎机最重要的指标就是*后悔值*,计算方式是求代理收集的奖励与可以访问环境奖励函数的先知策略的预期奖励之差。因此,[RegretMetric](https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/metrics/tf_metrics.py) 需要 *baseline_reward_fn* 函数来计算给定观测值的最佳可实现预期奖励。对于我们的示例,我们需要取我们已经为环境定义的奖励函数的无噪声等效函数的最大值。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cX7MiFhNu3_L" }, "outputs": [], "source": [ "def compute_optimal_reward(observation):\n", " expected_reward_for_arms = [\n", " tf.linalg.matvec(observation, tf.cast(arm0_param, dtype=tf.float32)),\n", " tf.linalg.matvec(observation, tf.cast(arm1_param, dtype=tf.float32)),\n", " tf.linalg.matvec(observation, tf.cast(arm2_param, dtype=tf.float32))]\n", " optimal_action_reward = tf.reduce_max(expected_reward_for_arms, axis=0)\n", " return optimal_action_reward\n", "\n", "regret_metric = tf_metrics.RegretMetric(compute_optimal_reward)" ] }, { "cell_type": "markdown", "metadata": { "id": "YRWz-Qeb13JC" }, "source": [ "## 训练" ] }, { "cell_type": "markdown", "metadata": { "id": "khdKjTs516Pg" }, "source": [ "现在,我们将上面介绍的所有组件组合到一起:环境、策略和代理。我们借助*驱动器*在环境上运行策略并输出训练数据,并基于这些数据训练代理。\n", "\n", "请注意,有两个参数共同指定所采取的步数。`num_iterations` 将指定我们运行训练器循环的次数,而驱动器将在每次迭代中执行 `steps_per_loop` 步。保留这两项参数的主要原因是,有些运算是在每次迭代中完成的,而有些运算是由驱动器在每一步中完成的。例如,代理的 `train` 函数仅在每次迭代中调用一次。这里需要权衡之处在于,如果我们以更高的频率进行训练,那么我们的策略会“更加新鲜”;另一方面,以更大批次进行训练可能会更具时间效率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4Ggn45g62DWx" }, "outputs": [], "source": [ "num_iterations = 90 # @param\n", "steps_per_loop = 1 # @param\n", "\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", " data_spec=agent.policy.trajectory_spec,\n", " batch_size=batch_size,\n", " max_length=steps_per_loop)\n", "\n", "observers = [replay_buffer.add_batch, regret_metric]\n", "\n", "driver = dynamic_step_driver.DynamicStepDriver(\n", " env=environment,\n", " policy=agent.collect_policy,\n", " num_steps=steps_per_loop * batch_size,\n", " observers=observers)\n", "\n", "regret_values = []\n", "\n", "for _ in range(num_iterations):\n", " driver.run()\n", " loss_info = agent.train(replay_buffer.gather_all())\n", " replay_buffer.clear()\n", " regret_values.append(regret_metric.result())\n", "\n", "plt.plot(regret_values)\n", "plt.ylabel('Average Regret')\n", "plt.xlabel('Number of Iterations')" ] }, { "cell_type": "markdown", "metadata": { "id": "J2diHS5IzLuo" }, "source": [ "运行最后一个代码段后,生成的统计图(有望)显示,在给定的观测值下,平均后悔值会随着代理的训练而逐渐下降并且策略会逐渐更加善于确定正确的动作。" ] }, { "cell_type": "markdown", "metadata": { "id": "2qLMnOL00-2V" }, "source": [ "# 后续步骤" ] }, { "cell_type": "markdown", "metadata": { "id": "FOiRWZbf1Drs" }, "source": [ "要查看更多工作示例,请参阅 [bandits/agents/examples](https://github.com/tensorflow/agents/tree/master/tf_agents/bandits/agents/examples/v2) 目录,其中包含针对不同代理和环境的随时可运行的示例。\n", "\n", "TF-Agents 库还能够处理具有每臂特征的多臂老虎机。为此,我们建议读者阅读每臂老虎机[教程](https://github.com/tensorflow/agents/tree/master/docs/tutorials/per_arm_bandits_tutorial.ipynb)。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "bandits_tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }