{ "cells": [ { "cell_type": "markdown", "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "source": [ "````{admonition} Copyright 2019 The TensorFlow Authors.\n", "```\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "```\n", "````" ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Estimator" ] },
{ "cell_type": "markdown", "metadata": { "id": "rILQuAiiRlI7" }, "source": [ "> Warning: Estimators are not recommended for new code. Estimators run `v1.Session`-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our [compatibility guarantees](https://tensorflow.org/guide/versions), but they will receive no fixes other than for security vulnerabilities. See the [migration guide](https://tensorflow.org/guide/migrate) for details." ] },
{ "cell_type": "markdown", "metadata": { "id": "oEinLJt2Uowq" }, "source": [ "This document introduces `tf.estimator`, a high-level TensorFlow API. Estimators encapsulate the following actions:\n", "\n", "- Training\n", "- Evaluation\n", "- Prediction\n", "- Export for serving\n", "\n", "You can either use the pre-made Estimators we provide or write your own custom Estimators. All Estimators, whether pre-made or custom, are classes based on the `tf.estimator.Estimator` class.\n", "\n", "For a quick example, try the [Estimator tutorials](../tutorials/estimator/linear.ipynb). For an overview of the API design, see the [white paper](https://arxiv.org/abs/1708.02637)." ] },
{ "cell_type": "markdown", "metadata": { "id": "KLdnqg4G2bmz" }, "source": [ "## Setup" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from set_env import temp_dir" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "cXRQ6mRM5gk0" }, "outputs": [], "source": [ "!pip install -U tensorflow_datasets" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "J_-C9ty22dkD" }, "outputs": [], "source": [ "import tempfile\n", "import os\n", "\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] },
{ "cell_type": "markdown", "metadata": { "id": "Wg5zbBliQvNL" }, "source": [ "## Advantages\n", "\n", "Similar to a `tf.keras.Model`, an `estimator` is a model-level abstraction. `tf.estimator` provides some capabilities that are still under development for `tf.keras`, including:\n", "\n", "- Parameter-server-based training\n", "- Full [TFX](http://tensorflow.org/tfx) integration" ] },
{ "cell_type": "markdown", "metadata": { "id": "yQ8fQYt_VD5E" }, "source": [ "## Estimator capabilities\n", "\n", "Estimators provide the following benefits:\n", "\n", "- You can run Estimator-based models on a local host or in a distributed multi-server environment without changing your model. You can also run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.\n", "- Estimators provide a safe distributed training loop that controls how and when to:\n", "  - Load data\n", "  - Handle exceptions\n", "  - Create checkpoint files and recover from failures\n", "  - Save TensorBoard summaries\n", 
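"\n", "Much of this training-loop behavior is configured through `tf.estimator.RunConfig` rather than inside the model. The following is a minimal sketch, assuming a hypothetical model directory `/tmp/estimator_demo`. The step counts are illustrative values, not recommendations:\n", "\n", "```python\n", "import tensorflow as tf\n", "\n", "# Checkpointing and summary writing are set on RunConfig, not in the model.\n", "config = tf.estimator.RunConfig(\n", "    model_dir='/tmp/estimator_demo',  # hypothetical output directory\n", "    save_checkpoints_steps=100,  # write a checkpoint every 100 steps\n", "    keep_checkpoint_max=5,  # keep only the 5 most recent checkpoints\n", "    save_summary_steps=50)  # write TensorBoard summaries every 50 steps\n", "```\n", "\n", "You would pass this object to an Estimator through its `config` argument. Because checkpoints are written to `model_dir`, a later `train` call on an Estimator with the same `model_dir` resumes from the latest checkpoint instead of starting over.\n", 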
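"\n", "When writing an application with Estimators, you must separate the data input pipeline from the model. This separation makes it easier to experiment with different datasets." ] },
{ "cell_type": "markdown", "metadata": { "id": "jQ2PsufpgIpM" }, "source": [ "## Pre-made Estimator program structure\n", "\n", "Pre-made Estimators let you work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions, because Estimators handle all the \"plumbing\" for you. Pre-made Estimators also let you experiment with different model architectures with only minimal code changes. For example, `tf.estimator.DNNClassifier` is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.\n", "\n", "A TensorFlow program that relies on a pre-made Estimator typically consists of the following four steps:" ] },
{ "cell_type": "markdown", "metadata": { "id": "mIJPPe26gQpF" }, "source": [ "### 1. Write one or more dataset-importing functions.\n", "\n", "For example, you might create one function to import the training set and another to import the test set. Each dataset-importing function must provide its data as a pair of objects:\n", "\n", "- A dictionary in which the keys are feature names and the values are tensors (or `SparseTensor`s) containing the corresponding feature data\n", "- A tensor containing one or more labels\n", "\n", "The `input_fn` should return a `tf.data.Dataset` that yields pairs in that format.\n", "\n", "For example, the following input function builds such a dataset from the Titanic `train.csv` file:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "7fl_C5d6hEl3" }, "outputs": [], "source": [ "def train_input_fn():\n", "  titanic_file = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n", "  # Each element is a (features_dict, labels) pair; labels come from the \"survived\" column.\n", "  titanic = tf.data.experimental.make_csv_dataset(\n", "      titanic_file, batch_size=32,\n", "      label_name=\"survived\")\n", "  titanic_batches = (\n", "      titanic.cache().repeat().shuffle(500)\n", "      .prefetch(tf.data.AUTOTUNE))\n", "  return titanic_batches" ] },
{ "cell_type": "markdown", "metadata": { "id": "CjyrQGb3mCcp" }, "source": [ "The `input_fn` runs in a `tf.Graph`. It can also return a `(features_dict, labels)` pair of graph tensors directly, but this is error-prone outside of simple cases such as returning constants." ] },
{ "cell_type": "markdown", "metadata": { "id": "yJYjWUMxgTnq" }, "source": [ "### 2. Define the feature columns.\n", "\n", "Each `tf.feature_column` identifies a feature name, its type, and any input preprocessing.\n", "\n", "For example, the following snippet creates three feature columns:\n", "\n", "- The first uses the `age` feature directly as a floating-point input.\n", "- The second uses the `class` feature as a categorical input.\n", "- The third uses `embark_town` as a categorical input, but applies the *hashing trick*, so you can fix the number of categories (hash buckets) without enumerating the possible values.\n", "\n", "For more information, see the [feature columns tutorial](https://tensorflow.google.cn/tutorials/keras/feature_columns)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "lFd8Dnrmhjhr" }, "outputs": [], "source": [ "# Use `age` directly as a floating-point input.\n", "age = tf.feature_column.numeric_column('age')\n", "# Treat `class` as a categorical input with a fixed vocabulary.\n", "cls = tf.feature_column.categorical_column_with_vocabulary_list(\n", "    'class', ['First', 'Second', 'Third'])\n", "# Hash `embark_town` into 32 buckets instead of enumerating its values.\n", "embark = tf.feature_column.categorical_column_with_hash_bucket(\n", "    'embark_town', 32)" ] },
{ "cell_type": "markdown", "metadata": { "id": "UIjqAozjgXdr" }, "source": [ "### 3. Instantiate the relevant pre-made Estimator.\n", "\n", "For example, here is a sample instantiation of a pre-made Estimator named `LinearClassifier`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "CDOx6lZVoVB8" }, "outputs": [], "source": [ "# Instantiate an Estimator, passing it the feature columns.\n", "model_dir = tempfile.mkdtemp()\n", "model = tf.estimator.LinearClassifier(\n", "    model_dir=model_dir,\n", "    feature_columns=[embark, cls, age],\n", "    n_classes=2)" ] },
{ "cell_type": "markdown", "metadata": { "id": "QGl9oYuFoYj6" }, "source": [ "For more information, see the [linear classifier tutorial](https://tensorflow.google.cn/tutorials/estimator/linear)." ] },
{ "cell_type": "markdown", "metadata": { "id": "sXNBeY-oVxGQ" }, "source": [ "### 4. Call a training, evaluation, or inference method.\n", "\n", "All Estimators provide `train`, `evaluate`, and `predict` methods.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "iGaJKkmVBgo2" }, "outputs": [], "source": [ "# `train_input_fn` is the input function defined in step 1.\n", "model = model.train(input_fn=train_input_fn, steps=2000)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "CXkivCNq0vfH" }, "outputs": [], "source": [ "result = model.evaluate(train_input_fn, steps=10)\n", "\n", "for key, value in result.items():\n", "  print(key, \":\", value)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "CPLD8n4CLVi_" }, "outputs": [], "source": [ "# `predict` yields one dict of outputs per example; print only the first.\n", "for pred in model.predict(train_input_fn):\n", "  for key, value in pred.items():\n", "    print(key, \":\", value)\n", "  break" ] },
{ "cell_type": "markdown", "metadata": { "id": "cbmrm9pFg5vo" }, "source": [ "### Benefits of pre-made Estimators\n", "\n", "Pre-made Estimators encode best practices, providing the following benefits:\n", "\n", "- Best practices for determining where different parts of the computational graph should run, and for implementing strategies on a single machine or on a cluster.\n", "- Best practices for event (summary) writing and universally useful summaries.\n", "\n", "If you don't use pre-made Estimators, you must implement these features yourself." ] },
{ "cell_type": "markdown", "metadata": { "id": "oIaPjYgnZdn6" }, "source": [ "## Custom Estimators\n", "\n", "The heart of every Estimator, whether pre-made or custom, is its *model function*, a method that builds graphs for training, evaluation, and prediction. When you use a pre-made Estimator, someone else has already implemented the model function. When you use a custom Estimator, you must write the model function yourself.\n", "\n", "> Note: A custom `model_fn` still runs in 1.x-style graph mode. This means there is no eager execution and no automatic control dependencies. You should plan to migrate away from `tf.estimator` with a custom `model_fn`. The alternative APIs are `tf.keras` and `tf.distribute`. If you still need an `Estimator` for some part of your training, you can use the `tf.keras.estimator.model_to_estimator` converter to create an `Estimator` from a `keras.Model`." ] },
{ "cell_type": "markdown", "metadata": { "id": "P7aPNnXUbN4j" }, "source": [ "## Create an Estimator from a Keras model\n", "\n", "You can convert an existing Keras model to an Estimator with `tf.keras.estimator.model_to_estimator`. This lets your Keras model take advantage of Estimator features such as distributed training.\n", "\n", "Instantiate a Keras MobileNet V2 model and compile it with the optimizer, loss, and metrics to train with:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "XE6NMcuGeDOP" }, "outputs": [], "source": [ "keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(\n", "    input_shape=(160, 160, 3), include_top=False)\n",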
"keras_mobilenet_v2.trainable = False\n", "\n", "estimator_model = tf.keras.Sequential([\n", "    keras_mobilenet_v2,\n", "    tf.keras.layers.GlobalAveragePooling2D(),\n", "    tf.keras.layers.Dense(1)\n", "])\n", "\n", "# Compile the model\n", "estimator_model.compile(\n", "    optimizer='adam',\n", "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", "    metrics=['accuracy'])" ] },
{ "cell_type": "markdown", "metadata": { "id": "A3hcxzcEfYfX" }, "source": [ "Create an `Estimator` from the compiled Keras model. The initial model state of the Keras model is preserved in the created `Estimator`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "UCSSifirfyHk" }, "outputs": [], "source": [ "est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)" ] },
{ "cell_type": "markdown", "metadata": { "id": "8jRNRVb_fzGT" }, "source": [ "Treat the derived `Estimator` as you would any other `Estimator`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Rv9xJk51e1fB" }, "outputs": [], "source": [ "IMG_SIZE = 160  # All images will be resized to 160x160\n", "\n", "def preprocess(image, label):\n", "  image = tf.cast(image, tf.float32)\n", "  # Scale pixel values from [0, 255] to [-1, 1], the range MobileNetV2 expects.\n", "  image = (image/127.5) - 1\n", "  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))\n", "  return image, label" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Fw8OjwujVBkc" }, "outputs": [], "source": [ "def train_input_fn(batch_size):\n", "  data = tfds.load('cats_vs_dogs', as_supervised=True)\n", "  train_data = data['train']\n", "  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)\n", "  return train_data" ] },
{ "cell_type": "markdown", "metadata": { "id": "JMb0cuy0gbTi" }, "source": [ "To train, call the Estimator's `train` function:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "4JsvMp8Jge80" }, "outputs": [], "source": [ "est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)" ] },
{ "cell_type": "markdown", "metadata": { "id": "jvr_rAzngY9v" }, "source": [ "Similarly, to evaluate, call the Estimator's `evaluate` function:"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kVNPqysQgYR2" }, "outputs": [], "source": [ "est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)" ] },
{ "cell_type": "markdown", "metadata": { "id": "5HeTOvCYbjZb" }, "source": [ "For more details, see the documentation for `tf.keras.estimator.model_to_estimator`." ] },
{ "cell_type": "markdown", "metadata": { "id": "zGG1tOM0L6iM" }, "source": [ "## Save object-based checkpoints with Estimator\n", "\n", "By default, Estimators save checkpoints using variable names rather than the object graph described in the [checkpoint guide](checkpoint.ipynb). `tf.train.Checkpoint` can read name-based checkpoints, but variable names may change when you move parts of a model outside of the Estimator's `model_fn`. For forward compatibility, saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "-8AMJeueNyoM" }, "outputs": [], "source": [ "# TF1 compatibility module; `model_fn` uses it for `get_global_step` and `Scaffold`.\n", "import tensorflow.compat.v1 as tf_compat" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "W5JbCEUGY-Xo" }, "outputs": [], "source": [ "def toy_dataset():\n", "  # Ten scalar inputs `x`; each target row `y` is 5 * x + [0, 1, 2, 3, 4].\n", "  inputs = tf.range(10.)[:, None]\n", "  labels = inputs * 5. + tf.range(5.)[None, :]\n", "  return tf.data.Dataset.from_tensor_slices(\n", "      dict(x=inputs, y=labels)).repeat().batch(2)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "gTZbsIRCZnCU" }, "outputs": [], "source": [ "class Net(tf.keras.Model):\n", "  \"\"\"A simple linear model.\"\"\"\n", "\n", "  def __init__(self):\n", "    super(Net, self).__init__()\n", "    self.l1 = tf.keras.layers.Dense(5)\n", "\n", "  def call(self, x):\n", "    return self.l1(x)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "T6fQsBzJQN2y" }, "outputs": [], "source": [ "def model_fn(features, labels, mode):\n", "  net = Net()\n", "  opt = tf.keras.optimizers.Adam(0.1)\n", "  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),\n", "                             optimizer=opt, net=net)\n", "  with tf.GradientTape() as tape:\n", "    output = net(features['x'])\n", "    loss = tf.reduce_mean(tf.abs(output - features['y']))\n", "  variables = net.trainable_variables\n", "  gradients = tape.gradient(loss, variables)\n", "  return tf.estimator.EstimatorSpec(\n", "      mode,\n", "      loss=loss,\n",
"      train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),\n", "                        ckpt.step.assign_add(1)),\n", "      # Tell the Estimator to save \"ckpt\" in an object-based format.\n", "      scaffold=tf_compat.train.Scaffold(saver=ckpt))\n", "\n", "tf.keras.backend.clear_session()\n", "est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')\n", "est.train(toy_dataset, steps=10)" ] },
{ "cell_type": "markdown", "metadata": { "id": "tObYHnrrb_mL" }, "source": [ "`tf.train.Checkpoint` can then load the Estimator's checkpoints from its `model_dir`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Q6IP3Y_wb-fs" }, "outputs": [], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "net = Net()\n", "ckpt = tf.train.Checkpoint(\n", "    step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)\n", "ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))\n", "ckpt.step.numpy()  # From est.train(..., steps=10)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Dk5wWyuMpuHx" }, "source": [ "## SavedModels from Estimators\n", "\n", "Estimators export SavedModels through `tf.estimator.Estimator.export_saved_model`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "B9KQq5qzpzbK" }, "outputs": [], "source": [ "input_column = tf.feature_column.numeric_column(\"x\")\n", "\n", "estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])\n", "\n", "def input_fn():\n", "  return tf.data.Dataset.from_tensor_slices(\n", "      ({\"x\": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)\n", "\n", "estimator.train(input_fn)" ] },
{ "cell_type": "markdown", "metadata": { "id": "y9qCa6J6FVS5" }, "source": [ "To save an `Estimator`, you need to create a `serving_input_receiver`. This function builds the part of a `tf.Graph` that parses the raw data received by the SavedModel.\n", "\n", "The `tf.estimator.export` module contains functions that help build these `receivers`.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "XJ4PJ-Cl4060" }, "source": [ "The following code builds a receiver, based on the `feature_columns`, that accepts serialized `tf.Example` protocol buffers, which are commonly used with [tf-serving](https://tensorflow.org/serving)." ]
}, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lnmsmGOQFPED" }, "outputs": [], "source": [ "tmpdir = tempfile.mkdtemp()\n", "\n", "serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(\n", "    tf.feature_column.make_parse_example_spec([input_column]))\n", "\n", "estimator_base_path = os.path.join(tmpdir, 'from_estimator')\n", "estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Q7XtbLMDaie2" }, "source": [ "You can also load and run that model from Python:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "c_BUBBNB1UH9" }, "outputs": [], "source": [ "imported = tf.saved_model.load(estimator_path)\n", "\n", "def predict(x):\n", "  example = tf.train.Example()\n", "  example.features.feature[\"x\"].float_list.value.extend([x])\n", "  return imported.signatures[\"predict\"](\n", "      examples=tf.constant([example.SerializeToString()]))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "C1ylWZCQ1ahG" }, "outputs": [], "source": [ "print(predict(1.5))\n", "print(predict(3.5))" ] },
{ "cell_type": "markdown", "metadata": { "id": "_IrCCm0-isqA" }, "source": [ "`tf.estimator.export.build_raw_serving_input_receiver_fn` lets you create input functions that take raw tensors rather than `tf.train.Example`s." ] },
{ "cell_type": "markdown", "metadata": { "id": "nO0hmFCRoIll" }, "source": [ "## Using `tf.distribute.Strategy` with Estimator (limited support)\n", "\n", "`tf.estimator` is a distributed-training TensorFlow API that originally supported the asynchronous parameter server approach. `tf.estimator` now supports `tf.distribute.Strategy`. If you are using `tf.estimator`, you can switch to distributed training with very few changes to your code. With this, Estimator users can do synchronous distributed training on multiple GPUs and multiple workers, as well as on TPUs. This support in Estimator is, however, limited. See the [Currently supported strategies](#estimator_support) section below for details.\n", "\n", "Using `tf.distribute.Strategy` with Estimator is slightly different from using it with Keras. Instead of using `strategy.scope`, you pass the strategy object into the Estimator's `RunConfig`.\n", "\n", "For more information, see the [distributed training guide](distributed_training.ipynb).\n", "\n", "The following snippet shows this with the pre-made Estimator `LinearRegressor` and `MirroredStrategy`:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "oGFY5nW_B3YU" }, "outputs": [], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "config = tf.estimator.RunConfig(\n", "    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)\n", "regressor = tf.estimator.LinearRegressor(\n", "    feature_columns=[tf.feature_column.numeric_column('feats')],\n", "    optimizer='SGD',\n", "    config=config)" ] },
{ "cell_type": "markdown", "metadata": { "id": "n6eSfLN5RGY8" }, "source": [ "Here you use a pre-made Estimator, but the same code also works with a custom Estimator. `train_distribute` determines how training is distributed, and `eval_distribute` determines how evaluation is distributed. This is another difference from Keras, which uses the same strategy for both training and evaluation.\n", "\n", "You can now train and evaluate this Estimator with an input function:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "2ky2ve2PB3YP" }, "outputs": [], "source": [ "def input_fn():\n", "  # Yield (features_dict, label) pairs; 'feats' matches the feature column above.\n", "  dataset = tf.data.Dataset.from_tensors(({\"feats\": [1.]}, [1.]))\n", "  return dataset.repeat(1000).batch(10)\n", "\n", "regressor.train(input_fn=input_fn, steps=10)\n", "regressor.evaluate(input_fn=input_fn, steps=10)" ] },
{ "cell_type": "markdown", "metadata": { "id": "hgaU9xQSSk2x" }, "source": [ "Another difference between Estimator and Keras worth highlighting here is input handling. In Keras, each batch of the dataset is automatically split across the replicas. In Estimator, however, batches are not split automatically, and data is not automatically sharded across workers. You have full control over how your data is distributed across workers and devices, and you must provide an `input_fn` that specifies how to distribute it.\n", "\n", "Your `input_fn` is called once per worker, which gives one dataset per worker. One batch from that dataset is then fed to one replica on that worker, so N replicas on one worker consume N batches. In other words, the dataset returned by `input_fn` should provide batches of size `PER_REPLICA_BATCH_SIZE`, and the global batch size for a step is `PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync`. For example, with a `PER_REPLICA_BATCH_SIZE` of 64 and 4 replicas in sync, each step processes a global batch of 256 examples.\n", "\n", "When doing multi-worker training, you should either split your data across the workers or shuffle it with a different random seed on each worker. For an example of how to do this, see the [Multi-worker training with Estimator](../tutorials/distribute/multi_worker_with_estimator.ipynb) tutorial." ] },
{ "cell_type": "markdown", "metadata": { "id": "G3ieQKfWZhhL" }, "source": [ "Similarly, you can use multi-worker and parameter server strategies. The code stays the same, but you need to use `tf.estimator.train_and_evaluate` and set the `TF_CONFIG` environment variable for each binary running in your cluster." ] },
{ "cell_type": "markdown", "metadata": { "id": "A_lvUsSLZzVg" }, "source": [ "<a name=\"estimator_support\"></a>\n", "\n", "### Currently supported strategies\n", "\n", "All strategies except `TPUStrategy` have limited support for training with Estimator. Basic training and evaluation should work, but a number of advanced features, such as `v1.train.Scaffold`, are not yet available. There may also be a number of bugs in this integration. There are currently no plans to actively improve this support; the focus is on Keras and custom training loops instead. If at all possible, prefer using `tf.distribute` with those APIs.\n", "\n", "Training API | MirroredStrategy | TPUStrategy | MultiWorkerMirroredStrategy | CentralStorageStrategy | ParameterServerStrategy\n", ":-- | :-- | :-- | :-- | :-- | :--\n", "Estimator API | Limited support | Not supported | Limited support | Limited support | Limited support\n", "\n", "### Examples and tutorials\n", "\n", "Where possible, you can improve your model further by building your own custom Estimator. The following examples and tips can help:\n", "\n", "1. The [Multi-worker training with Estimator tutorial](../tutorials/distribute/multi_worker_with_estimator.ipynb) shows how to train with multiple workers using `MultiWorkerMirroredStrategy` on the MNIST dataset.\n", "2. An end-to-end example of [running multi-worker training with distribution strategies](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy) in `tensorflow/ecosystem` using Kubernetes templates. It starts with a Keras model and converts it to an Estimator using the `tf.keras.estimator.model_to_estimator` API.\n", "3. If several pre-made Estimators are suitable, run experiments to determine which one produces the best results." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "Tce3stUlHN0L", "KLdnqg4G2bmz", "Wg5zbBliQvNL", "yQ8fQYt_VD5E", "jQ2PsufpgIpM", "mIJPPe26gQpF", "yJYjWUMxgTnq", "UIjqAozjgXdr", "sXNBeY-oVxGQ", "cbmrm9pFg5vo", "oIaPjYgnZdn6", "P7aPNnXUbN4j", "zGG1tOM0L6iM", "Dk5wWyuMpuHx", "nO0hmFCRoIll", "A_lvUsSLZzVg" ], "name": "estimator.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }