{ "cells": [ { "cell_type": "markdown", "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "source": [ "````{admonition} Copyright 2019 The TensorFlow Authors.\n", "```\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "```\n", "````" ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Estimator" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "
![]() | \n",
" ![]() | \n",
" ![]() | \n",
" ![]() | \n",
"
tf.estimator.Estimator
类的类。\n",
"\n",
"有关简单示例,请查看 [Estimator 教程](../tutorials/estimator/linear.ipynb)。有关 API 设计概述,请参阅[白皮书](https://arxiv.org/abs/1708.02637)。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KLdnqg4G2bmz"
},
"source": [
"## 设置"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cXRQ6mRM5gk0"
},
"outputs": [],
"source": [
"!pip install -U tensorflow_datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J_-C9ty22dkD"
},
"outputs": [],
"source": [
"import tempfile\n",
"import os\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wg5zbBliQvNL"
},
"source": [
"## 优势\n",
"\n",
"与 `tf.keras.Model` 类似,`estimator` 是模型级别的抽象。`tf.estimator` 提供了一些目前仍在为 `tf.keras` 开发中的功能。包括:\n",
"\n",
"- 基于参数服务器的训练\n",
"- 完整的 [TFX](http://tensorflow.org/tfx) 集成"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yQ8fQYt_VD5E"
},
"source": [
"## Estimator 功能\n",
"\n",
"Estimator 提供了以下优势:\n",
"\n",
"- 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您还可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型,而无需重新编码模型。\n",
"- Estimator 提供了安全的分布式训练循环,可控制如何以及何时进行以下操作:\n",
" - 加载数据\n",
" - 处理异常\n",
" - 创建检查点文件并从故障中恢复\n",
" - 保存 TensorBoard 摘要\n",
"\n",
"在用 Estimator 编写应用时,您必须将数据输入流水线与模型分离。这种分离简化了使用不同数据集进行的实验。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jQ2PsufpgIpM"
},
"source": [
"## 预制 Estimator 程序结构\n",
"\n",
"使用预制 Estimator,您能够在比基础 TensorFlow API 高很多的概念层面上工作。您无需再担心创建计算图或会话,因为 Estimator 会替您完成所有“基础工作”。此外,使用预制 Estimator,您只需改动较少代码就能试验不同的模型架构。例如,`tf.estimator.DNNClassifier` 是一个预制 Estimator 类,可基于密集的前馈神经网络对分类模型进行训练。\n",
"\n",
"依赖于预制 Estimator 的 TensorFlow 程序通常包括以下四个步骤:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIJPPe26gQpF"
},
"source": [
"### 1. 编写一个或多个数据集导入函数。\n",
"\n",
"例如,您可以创建一个函数来导入训练集,创建另一个函数来导入测试集。每个数据集导入函数必须返回以下两个对象:\n",
"\n",
"- 字典,其中键是特征名称,值是包含相应特征数据的张量(或 SparseTensor)\n",
"- 包含一个或多个标签的张量\n",
"\n",
"`input_fn` 应当返回一个 `tf.data.Dataset` 以产生该格式的对。\n",
"\n",
"例如,以下代码展示了输入函数的基本框架:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7fl_C5d6hEl3"
},
"outputs": [],
"source": [
"def train_input_fn():\n",
" titanic_file = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n",
" titanic = tf.data.experimental.make_csv_dataset(\n",
" titanic_file, batch_size=32,\n",
" label_name=\"survived\")\n",
" titanic_batches = (\n",
" titanic.cache().repeat().shuffle(500)\n",
" .prefetch(tf.data.AUTOTUNE))\n",
" return titanic_batches"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CjyrQGb3mCcp"
},
"source": [
"`input_fn` 在 `tf.Graph` 中执行,也可以直接返回包含计算图张量的 `(features_dics, labels)` 对,但这在返回常量等简单情况之外很容易出错。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yJYjWUMxgTnq"
},
"source": [
"### 2. 定义特征列。\n",
"\n",
"每个 `tf.feature_column` 标识了特征名称、特征类型,以及任何输入预处理。例如,以下代码段创建了三个包含整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个会被程序调用以缩放原始数据的 lambda:\n",
"\n",
"例如,以下代码段会创建三个特征列。\n",
"\n",
"- 第一个直接使用 `age` 特征作为浮点输入。\n",
"- 第二个使用 `class` 特征作为分类输入。\n",
"- 第三个使用 `embark_town` 作为分类输入,但使用 `hashing trick` 来避免枚举选项并设置选项数量的需要。\n",
"\n",
"有关详细信息,请参阅[特征列教程](https://tensorflow.google.cn/tutorials/keras/feature_columns)。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lFd8Dnrmhjhr"
},
"outputs": [],
"source": [
"# Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column( 'median_education', normalizer_fn=lambda x: x - global_education_mean)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UIjqAozjgXdr"
},
"source": [
"### 3. 实例化相关预制 Estimator。\n",
"\n",
"例如,下面是对名为 `LinearClassifier` 的预制 Estimator 进行实例化的示例:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CDOx6lZVoVB8"
},
"outputs": [],
"source": [
"# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier( feature_columns=[population, crime_rate, median_education])\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QGl9oYuFoYj6"
},
"source": [
"有关详细信息,请参阅[线性分类器教程](https://tensorflow.google.cn/tutorials/estimator/linear)。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sXNBeY-oVxGQ"
},
"source": [
"### 4. 调用训练、评估或推断方法。\n",
"\n",
"所有 Estimator 都提供 `train`、 `evaluate` 和 `predict` 方法。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iGaJKkmVBgo2"
},
"outputs": [],
"source": [
"# `input_fn` is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CXkivCNq0vfH"
},
"outputs": [],
"source": [
"result = model.evaluate(train_input_fn, steps=10)\n",
"\n",
"for key, value in result.items():\n",
" print(key, \":\", value)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPLD8n4CLVi_"
},
"outputs": [],
"source": [
"您可以在下面看到与此相关的示例。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cbmrm9pFg5vo"
},
"source": [
"### 预制 Estimator 的优势\n",
"\n",
"预制 Estimator 对最佳做法进行了编码,具有以下优势:\n",
"\n",
"- 确定计算图不同部分的运行位置,以及在单台机器或集群上实施策略的最佳做法。\n",
"- 事件(摘要)编写和通用摘要的最佳做法。\n",
"\n",
"如果不使用预制 Estimator,则您必须自己实现上述功能。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oIaPjYgnZdn6"
},
"source": [
"## 自定义 Estimator\n",
"\n",
"每个 Estimator(无论预制还是自定义)的核心是其*模型函数*,这是一种为训练、评估和预测构建计算图的方法。当您使用预制 Estimator 时,已经有人为您实现了模型函数。当使用自定义 Estimator 时,您必须自己编写模型函数。\n",
"\n",
"> 注:自定义 `model_fn` 仍将在 1.x 样式的计算图模式下运行。这意味着没有 Eager Execution,也没有自动控制依赖项。您应当计划使用自定义 `model_fn` 从 `tf.estimator` 迁移。替代 API 是 `tf.keras` 和 `tf.distribute`。如果您的训练的某个部分仍需要 `Estimator`,则可以使用 `tf.keras.estimator.model_to_estimator` 转换器从 `keras.Model` 创建 `Estimator`。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P7aPNnXUbN4j"
},
"source": [
"## 从 Keras 模型创建 Estimator\n",
"\n",
"您可以使用 `tf.keras.estimator.model_to_estimator` 将现有的 Keras 模型转换为 Estimator。这样一来,您的 Keras 模型就可以利用 Estimator 的优势,例如分布式训练。\n",
"\n",
"实例化 Keras MobileNet V2 模型并用训练中使用的优化器、损失和指标来编译模型:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XE6NMcuGeDOP"
},
"outputs": [],
"source": [
"keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(\n",
" input_shape=(160, 160, 3), include_top=False)\n",
"keras_mobilenet_v2.trainable = False\n",
"\n",
"estimator_model = tf.keras.Sequential([\n",
" keras_mobilenet_v2,\n",
" tf.keras.layers.GlobalAveragePooling2D(),\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"\n",
"# Compile the model\n",
"estimator_model.compile(\n",
" optimizer='adam',\n",
" loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A3hcxzcEfYfX"
},
"source": [
"从已编译的 Keras 模型创建 `Estimator`。Keras 模型的初始模型状态会保留在已创建的 `Estimator`中:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UCSSifirfyHk"
},
"outputs": [],
"source": [
"est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8jRNRVb_fzGT"
},
"source": [
"您可以像对待任何其他 `Estimator` 一样对待派生的 `Estimator`。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rv9xJk51e1fB"
},
"outputs": [],
"source": [
"IMG_SIZE = 160 # All images will be resized to 160x160\n",
"\n",
"def preprocess(image, label):\n",
" image = tf.cast(image, tf.float32)\n",
" image = (image/127.5) - 1\n",
" image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))\n",
" return image, label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fw8OjwujVBkc"
},
"outputs": [],
"source": [
"def train_input_fn(batch_size):\n",
" data = tfds.load('cats_vs_dogs', as_supervised=True)\n",
" train_data = data['train']\n",
" train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)\n",
" return train_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JMb0cuy0gbTi"
},
"source": [
"要进行训练,可调用 Estimator 的训练函数:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4JsvMp8Jge80"
},
"outputs": [],
"source": [
"est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvr_rAzngY9v"
},
"source": [
"同样,要进行评估,可调用 Estimator 的评估函数:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kVNPqysQgYR2"
},
"outputs": [],
"source": [
"est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5HeTOvCYbjZb"
},
"source": [
"有关详细信息,请参阅 `tf.keras.estimator.model_to_estimator` 文档。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zGG1tOM0L6iM"
},
"source": [
"## 从 Keras 模型创建 Estimator\n",
"\n",
"默认情况下,Estimator 使用变量名而不是[检查点指南](checkpoint.ipynb)中介绍的对象计算图来保存检查点。tf.train.Checkpoint
将读取基于名称的检查点,但是在将模型的一部分移到 Estimator 的 `model_fn` 外部时,变量名称可能会更改。对于前向兼容性,保存基于对象的检查点可以更轻松地在 Estimator 内训练模型,然后在外部使用。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-8AMJeueNyoM"
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W5JbCEUGY-Xo"
},
"outputs": [],
"source": [
"import tensorflow_datasets as tfds\n",
"tfds.disable_progress_bar()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gTZbsIRCZnCU"
},
"outputs": [],
"source": [
"class Net(tf.keras.Model):\n",
" \"\"\"A simple linear model.\"\"\"\n",
"\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.l1 = tf.keras.layers.Dense(5)\n",
"\n",
" def call(self, x):\n",
" return self.l1(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T6fQsBzJQN2y"
},
"outputs": [],
"source": [
"def model_fn(features, labels, mode):\n",
" net = Net()\n",
" opt = tf.keras.optimizers.Adam(0.1)\n",
" ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),\n",
" optimizer=opt, net=net)\n",
" with tf.GradientTape() as tape:\n",
" output = net(features['x'])\n",
" loss = tf.reduce_mean(tf.abs(output - features['y']))\n",
" variables = net.trainable_variables\n",
" gradients = tape.gradient(loss, variables)\n",
" return tf.estimator.EstimatorSpec(\n",
" mode,\n",
" loss=loss,\n",
" train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),\n",
" ckpt.step.assign_add(1)),\n",
" # Tell the Estimator to save \"ckpt\" in an object-based format.\n",
" scaffold=tf_compat.train.Scaffold(saver=ckpt))\n",
"\n",
"tf.keras.backend.clear_session()\n",
"est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')\n",
"est.train(toy_dataset, steps=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tObYHnrrb_mL"
},
"source": [
"随后,`tf.train.Checkpoint` 可以从其 `model_dir` 加载 Estimator 的检查点。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Q6IP3Y_wb-fs"
},
"outputs": [],
"source": [
"opt = tf.keras.optimizers.Adam(0.1)\n",
"net = Net()\n",
"ckpt = tf.train.Checkpoint(\n",
" step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)\n",
"ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))\n",
"ckpt.step.numpy() # From est.train(..., steps=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dk5wWyuMpuHx"
},
"source": [
"## Estimator 中的 SavedModel\n",
"\n",
"Estimator 通过 `tf.Estimator.export_saved_model` 导出 SavedModel。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B9KQq5qzpzbK"
},
"outputs": [],
"source": [
"input_column = tf.feature_column.numeric_column(\"x\")\n",
"\n",
"estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])\n",
"\n",
"def input_fn():\n",
" return tf.data.Dataset.from_tensor_slices(\n",
" ({\"x\": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)\n",
"estimator.train(input_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9qCa6J6FVS5"
},
"source": [
"要保存 `Estimator`,您需要创建 `serving_input_receiver`。此函数构建 `tf.Graph` 的一部分,用于解析 SavedModel 接收到的原始数据。\n",
"\n",
"`tf.estimator.export` 模块包含帮助构建这些 `receivers` 的函数。\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJ4PJ-Cl4060"
},
"source": [
"下面的代码基于 `feature_columns` 构建一个接收器,它接受通常与 [tf-serving](https://tensorflow.org/serving) 一起使用的序列化 `tf.Example` 协议缓冲区。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lnmsmGOQFPED"
},
"outputs": [],
"source": [
"tmpdir = tempfile.mkdtemp()\n",
"\n",
"serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
" tf.feature_column.make_parse_example_spec([input_column]))\n",
"\n",
"estimator_base_path = os.path.join(tmpdir, 'from_estimator')\n",
"estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q7XtbLMDaie2"
},
"source": [
"您还可以从 Python 加载和运行该模型:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c_BUBBNB1UH9"
},
"outputs": [],
"source": [
"imported = tf.saved_model.load(estimator_path)\n",
"\n",
"def predict(x):\n",
" example = tf.train.Example()\n",
" example.features.feature[\"x\"].float_list.value.extend([x])\n",
" return imported.signatures[\"predict\"](\n",
" examples=tf.constant([example.SerializeToString()]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "C1ylWZCQ1ahG"
},
"outputs": [],
"source": [
"print(predict(1.5))\n",
"print(predict(3.5))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_IrCCm0-isqA"
},
"source": [
"通过 `tf.estimator.export.build_raw_serving_input_receiver_fn` 可以创建输入函数,这些函数使用原始张量,而不是 `tf.train.Example`。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nO0hmFCRoIll"
},
"source": [
"## 在 Estimator 中使用 `tf.distribute.Strategy`(有限支持)\n",
"\n",
"`tf.estimator` 是分布式训练 TensorFlow API,最初支持异步参数服务器方法。`tf.estimator` 现在支持 `tf.distribute.Strategy`。如果您正在使用 `tf.estimator`,那么您只需改动少量代码即可轻松转换为分布式训练。借助此功能,Estimator 用户现在可以在多个 GPU 和多个工作进程以及 TPU 上进行同步分布式训练。但是,Estimator 的这种支持是有限的。有关详细信息,请参阅下文[目前支持的策略](#estimator_support)部分。\n",
"\n",
"在 Estimator 中使用 `tf.distribute.Strategy` 的方法与在 Keras 中略有不同。现在我们不使用 `strategy.scope`,而是将策略对象传递到 Estimator 的 RunConfig 中。\n",
"\n",
"要了解更多信息,请参阅[分布式训练指南](distributed_training.ipynb)。\n",
"\n",
"以下代码段使用预制 Estimator `LinearRegressor` 和 `MirroredStrategy` 展示了这种情况:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oGFY5nW_B3YU"
},
"outputs": [],
"source": [
"mirrored_strategy = tf.distribute.MirroredStrategy()\n",
"config = tf.estimator.RunConfig(\n",
" train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)\n",
"regressor = tf.estimator.LinearRegressor(\n",
" feature_columns=[tf.feature_column.numeric_column('feats')],\n",
" optimizer='SGD',\n",
" config=config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n6eSfLN5RGY8"
},
"source": [
"我们在这里使用了预制 Estimator,但同样的代码也适用于自定义 Estimator。`train_distribute` 决定训练如何分布,`eval_distribute` 决定评估如何分布。这是与 Keras 的另一个区别,在 Keras 中,我们会对训练和评估使用相同的策略。\n",
"\n",
"现在,我们可以使用输入函数来训练和评估这个 Estimator:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2ky2ve2PB3YP"
},
"outputs": [],
"source": [
"def input_fn(dataset): ... # manipulate dataset, extracting the feature dict and the label return feature_dict, label\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hgaU9xQSSk2x"
},
"source": [
"需要在这里强调的 Estimator 和 Keras 的另一个区别是输入处理。在 Keras 中,数据集的每个批次都会在多个副本之间自动拆分。但在 Estimator 中,批次不会自动拆分,也不会在不同的工作进程之间自动对数据进行分片处理。您可以完全控制数据在工作进程和设备之间的分布方式,而且您必须提供 `input_fn` 来指定数据的分布方式。\n",
"\n",
"每个工作进程都会调用一次 `input_fn`,从而为每个工作进程提供一个数据集。然后数据集中的一个批次会被馈送到此工作进程上的一个副本,因此,1 个工作进程上的 N 个副本要使用 N 个批次。换句话说,`input_fn` 返回的数据集应提供大小为 `PER_REPLICA_BATCH_SIZE` 的批次。步骤的全局批次大小可通过 `PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync` 获得。\n",
"\n",
"在进行多工作进程训练时,您应该将数据拆分至各个工作进程,或者在每个工作进程上重排随机种子。您可以在[使用 Estimator 进行多工作进程训练](../tutorials/distribute/multi_worker_with_estimator.ipynb)教程中查看有关此操作的示例。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G3ieQKfWZhhL"
},
"source": [
"同样,您也可以使用多工作进程和参数服务器策略。代码保持不变,但需要使用 `tf.estimator.train_and_evaluate`,并为集群中运行的每个二进制文件设置 `TF_CONFIG` 环境变量。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A_lvUsSLZzVg"
},
"source": [
"\n",
"\n",
"### 目前支持的策略\n",
"\n",
"除 `TPUStrategy` 外,所有策略都对使用 Estimator 的训练提供有限支持。基本训练和评估应该可以正常运行,但如 `v1.train.Scaffold` 之类的许多高级功能尚不可用。此集成中可能还存在许多错误。目前,我们不打算主动改进此支持,而是专注于对 Keras 和自定义训练循环的支持。如果可能,您应该会更喜欢在这些 API 中使用 `tf.distribute`。\n",
"\n",
"训练 API | MirroredStrategy | TPUStrategy | MultiWorkerMirroredStrategy | CentralStorageStrategy | ParameterServerStrategy\n",
":-- | :-- | :-- | :-- | :-- | :--\n",
"Estimator API | 有限支持 | 不支持 | 有限支持 | 有限支持 | 有限支持\n",
"\n",
"### 示例和教程\n",
"\n",
"如果可能,您可以通过构建自己的自定义 Estimator 进一步改进模型。\n",
"\n",
"1. [使用 Estimator 进行多工作进程训练教程](../tutorials/distribute/multi_worker_with_estimator.ipynb)展示了如何在 MNIST 数据集上使用 `MultiWorkerMirroredStrategy` 在多个工作进程上一起训练。\n",
"2. 使用 Kubernetes 模板在 `tensorflow/ecosystem` 中[使用分布策略运行多工作进程训练](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy)的端到端示例。它从 Keras 模型开始,然后使用 `tf.keras.estimator.model_to_estimator` API 将其转换为 Estimator。\n",
"3. 如果有其他合适的预制 Estimator,可通过运行实验确定哪个预制 Estimator 能够生成最佳结果。"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"Tce3stUlHN0L",
"KLdnqg4G2bmz",
"Wg5zbBliQvNL",
"yQ8fQYt_VD5E",
"jQ2PsufpgIpM",
"mIJPPe26gQpF",
"yJYjWUMxgTnq",
"UIjqAozjgXdr",
"sXNBeY-oVxGQ",
"cbmrm9pFg5vo",
"oIaPjYgnZdn6",
"P7aPNnXUbN4j",
"zGG1tOM0L6iM",
"Dk5wWyuMpuHx",
"nO0hmFCRoIll",
"A_lvUsSLZzVg"
],
"name": "estimator.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}