##### Copyright 2019 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
利用 Keras 来训练多工作器(worker)#
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 Github 上查看源代码 | 下载笔记本 |
概述#
本教程演示了如何使用 tf.distribute.MultiWorkerMirroredStrategy
API 通过 Keras 模型和 Model.fit
API 执行多工作进程分布式训练。借助此策略,设计用于在单个工作进程上运行的 Keras 模型只需最少量的代码变更即可无缝地在多个工作进程上运行。
To learn how to use the MultiWorkerMirroredStrategy
with Keras and a custom training loop, refer to Custom training loop with Keras and MultiWorkerMirroredStrategy.
本教程包含一个最小多工作进程示例,出于演示目的,其中有两个工作进程。
选择正确的策略#
在开始之前,请确保 tf.distribute.MultiWorkerMirroredStrategy
是您的加速器和训练的正确选择。以下是使用数据并行分布训练的两种常见方式:
同步训练,训练步骤在工作进程和副本之间同步,例如
tf.distribute.MirroredStrategy
、tf.distribute.TPUStrategy
和tf.distribute.MultiWorkerMirroredStrategy
。所有工作进程同步训练不同的输入数据切片,并在每一步聚合梯度。异步训练,训练步骤不会严格同步,例如
tf.distribute.experimental.ParameterServerStrategy
。所有工作进程都对输入数据进行独立训练并异步更新变量。
如果您正在寻找没有 TPU 的多工作进程同步训练,那么应该选择 tf.distribute.MultiWorkerMirroredStrategy
。它会在所有工作进程的每个设备上在模型的层中创建所有变量的副本。它使用 CollectiveOps
(一种用于集合通信的 TensorFlow 运算)来聚合梯度并保持变量同步。如果您有兴趣,请查看 tf.distribute.experimental.CommunicationOptions
参数以了解集合实施选项。
有关 tf.distribute.Strategy
API 的概述,请参阅 TensorFlow 中的分布式训练。
设置#
先进行一些必要的导入:
import json
import os
import sys
在导入 TensorFlow 之前,需要对环境进行一些变更:
在现实世界的应用中,每个工作进程将在不同的机器上运行。出于本教程的目的,所有工作进程都将在这台机器上运行。因此,请停用所有 GPU 以防止因所有工作进程尝试使用同一 GPU 而导致的错误。
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
重置
TF_CONFIG
环境变量(稍后您将了解更多相关信息):
os.environ.pop('TF_CONFIG', None)
确保当前目录位于 Python 的路径上。这样,笔记本可以导入稍后由
%%writefile
写入的文件:
if '.' not in sys.path:
sys.path.insert(0, '.')
安装 tf-nightly
,因为使用 tf.keras.callbacks.BackupAndRestore
中的 save_freq
参数设置特定步骤保存检查点的频率是从 TensorFlow 2.10 引入的:
!pip install tf-nightly
最后,导入 TensorFlow:
import tensorflow as tf
数据集和模型定义#
接下来,使用简单的模型和数据集设置创建 mnist_setup.py
文件。本教程中的工作进程将使用此 Python 文件:
%%writefile mnist_setup.py
import os
import tensorflow as tf
import numpy as np
def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the [0, 255] range.
# You need to convert them to float32 with values in the [0, 1] range.
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
return train_dataset
def build_and_compile_cnn_model():
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
metrics=['accuracy'])
return model
在单个工作进程上进行模型训练#
让我们首先尝试用少量的 epoch 来训练模型,并在单个工作器(worker)中观察结果,以确保一切正常。 随着训练的迭代,您应该会看到损失(loss)下降和准确度(accuracy)接近1.0。
import mnist_setup
batch_size = 64
single_worker_dataset = mnist_setup.mnist_dataset(batch_size)
single_worker_model = mnist_setup.build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
多工作进程配置#
现在,让我们进入多工作进程训练的世界。
具有作业和任务的集群#
在 TensorFlow 中,分布式训练涉及:一个包含多个作业的 'cluster'
,每个作业可能有一个或多个 'task'
。
您将需要使用 TF_CONFIG
配置环境变量在多台计算机上进行训练,每台计算机都可能具有不同的角色。TF_CONFIG
是一个 JSON 字符串,用于在每个作为集群一部分的工作进程上指定集群配置。
TF_CONFIG
有两个可用组件:'cluster'
和 'task'
。
'cluster'
对所有工作进程都相同,并提供有关训练集群的信息,这是一个由不同类型的作业(例如'worker'
或'chief'
)组成的字典。在使用
tf.distribute.MultiWorkerMirroredStrategy
进行多工作进程训练时,除了普通的'worker'
之外,通常还有一个'worker'
承担更多责任,例如保存检查点和为 TensorBoard 编写摘要文件。此类'worker'
被称为首席工作进程(作业名称为'chief'
)。'chief'
通常是'index'
为0
的工作进程。
'task'
提供当前任务的信息,在每个工作进程上各不相同。它指定相应工作进程的'type'
和'index'
。
下面是一个示例配置:
tf_config = {
'cluster': {
'worker': ['localhost:12345', 'localhost:23456']
},
'task': {'type': 'worker', 'index': 0}
}
请注意,tf_config
只是 Python 中的局部变量。要将其用于训练配置,请将其序列化为 JSON 并将其放置在 TF_CONFIG
环境变量中。
json.dumps(tf_config)
在上面的示例配置中,您将任务 'type'
设置为 'worker'
,并将任务 'index'
设置为 0
。因此,这台计算机是第一个工作进程。它将被任命为 'chief'
工作进程。
注:其他计算机也需要设置 TF_CONFIG
环境变量,并且它应该具有相同的 'cluster'
字典,但具有不同的任务 'type'
或任务 'index'
,具体取决于这些计算机的角色。
在实践中,您将在外部 IP 地址/端口上创建多个工作进程,并相应地在每个工作进程上设置一个 TF_CONFIG
变量。出于说明目的,本教程展示了如何在 localhost
上设置带有两个工作进程的 TF_CONFIG
变量:
上面显示了第一个 (
'chief'
) 工作进程的TF_CONFIG
。对于第二个工作进程,您将设置
tf_config['task']['index']=1
笔记本中的环境变量和子进程#
子进程会从其父进程继承环境变量。因此,如果您在此 Jupyter Notebook 进程中设置环境变量:
os.environ['GREETINGS'] = 'Hello TensorFlow!'
…然后,您可以从子进程访问环境变量:
%%bash
echo ${GREETINGS}
在下一部分中,您将使用此方法将 TF_CONFIG
传递给子工作进程。在现实世界的场景中,您永远不会以这种方式启动你的作业。本教程只是为了展示如何通过一个最小的多工作进程示例来做到这一点。
训练模型#
要训练模型,首先创建一个 tf.distribute.MultiWorkerMirroredStrategy
的实例:
strategy = tf.distribute.MultiWorkerMirroredStrategy()
注:调用 MultiWorkerMirroredStrategy
时,将解析 TF_CONFIG
并启动 TensorFlow 的 GRPC 服务器,因此必须在创建tf.distribute.Strategy
实例之前设置 TF_CONFIG
环境变量。
得益于将 tf.distribute.Strategy
API 集成到 tf.keras
中,您在将训练分布到多个工作进程中需要执行的唯一变更就是将模型构建和 model.compile()
调用添加到 strategy.scope()
内。分布策略的范围决定了变量的创建方式和位置,采用 MultiWorkerMirroredStrategy
时,创建的变量将为 MirroredVariable
,它们会复制到每个工作进程上。
with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = mnist_setup.build_and_compile_cnn_model()
注:目前在 MultiWorkerMirroredStrategy
中存在一个限制,即需要在创建策略实例后再创建 TensorFlow 运算。如果您遇到 RuntimeError: Collective ops must be configured at program startup
,请尝试在程序的开头创建 MultiWorkerMirroredStrategy
的实例,并在策略实例化后加入可以创建运算的代码。
要实际使用 MultiWorkerMirroredStrategy
运行,您需要运行工作进程并向其传递 TF_CONFIG
。
与之前编写的 mnist_setup.py
文件一样,以下是每个工作进程都将运行的 main.py
:
%%writefile main.py
import os
import json
import tensorflow as tf
import mnist_setup
per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])
strategy = tf.distribute.MultiWorkerMirroredStrategy()
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)
with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = mnist_setup.build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
在上面的代码段中,请注意传递给 Dataset.batch
的 global_batch_size
设置为 per_worker_batch_size * num_workers
。这可以确保每个工作进程均处理若干批次的 per_worker_batch_size
样本,而不受工作进程数量影响。
当前目录现包含两个 Python 文件:
%%bash
ls *.py
将 TF_CONFIG
序列化为 JSON 并将其添加到环境变量中:
os.environ['TF_CONFIG'] = json.dumps(tf_config)
现在,您可以启动一个将运行 main.py
并使用 TF_CONFIG
的工作进程:
# first kill any previous runs
%killbgscripts
%%bash --bg
python main.py &> job_0.log
以上命令有几点需要注意:
它使用
%%bash
,这是一项用于运行一些 bash 命令的笔记本“魔术命令”。它使用
--bg
标志在后台运行bash
进程,因为此工作进程不会终止。它在开始之前会等待所有工作进程。
后台工作进程不会将输出打印到此笔记本,因此 &>
会将其输出重定向到一个文件,以便您稍后在日志文件中查看所发生的情况。
那么,请等待几秒钟以启动该进程:
import time
time.sleep(10)
现在,查看一下目前为止输出到工作进程日志文件的内容:
%%bash
cat job_0.log
日志文件的最后一行内容应为:Started server with target: grpc://localhost:12345
。第一个工作进程现已准备就绪,并等待所有其他工作进程继续。
随后,更新 tf_config
以供第二个工作进程取用:
tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)
启动第二个工作进程。这将开始训练,因为所有工作进程都已处于活动状态(因此无需在后台执行此进程):
%%bash
python main.py
如果您重新检查第一个工作进程编写的日志,您会看到它参与了该模型的训练:
%%bash
cat job_0.log
注:这可能要慢于本教程开头的测试运行,因为在单台计算机上运行多个工作进程只会增加开销。这里的目标并非提高训练速度,而是为了提供一个多工作进程训练的示例。
# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
深入了解多工作进程训练#
到目前为止,您已经学习了如何执行基本的多工作进程设置。本教程的其余部分详细介绍了对实际用例可能有用或重要的其他因素。
数据集分片和批(batch)大小#
在多工作器训练中,需要将数据分片为多个部分,以确保融合和性能。 但是,请注意,在上面的代码片段中,数据集直接发送到model.fit()
,而无需分片; 这是因为tf.distribute.Strategy
API在多工作器训练中会自动处理数据集分片。
前一部分中的示例依赖于 tf.distribute.Strategy
API 提供的默认自动分片功能。您可以通过设置 tf.data.experimental.DistributeOptions
的 tf.data.experimental.AutoShardPolicy
控制分片。
要详细了解自动分片,请参阅分布式输入指南。
以下是如何关闭自动分片的快速示例,以便使每个副本都会处理每个样本(不推荐):
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
global_batch_size = 64
multi_worker_dataset = mnist_setup.mnist_dataset(batch_size=64)
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
评估#
如果您还将 validation_data
传递给 Model.fit
,它将在每个周期的训练和评估之间交替。评估作业分布在同一组工作进程中,其结果会被聚合并对所有工作进程可用。
与训练类似,验证数据集在文件级别自动分片。您需要在验证数据集中设置全局批次大小并设置 validation_steps
。
建议使用重复数据集(通过调用 tf.data.Dataset.repeat
)进行评估。
或者,您也可以创建另一个任务来定期读取检查点并运行评估。这就是 Estimator 的工作。但这并不是执行评估的推荐方式,因此不做赘述。
性能#
要调整多工作进程训练的性能,您可以尝试以下操作:
tf.distribute.MultiWorkerMirroredStrategy
提供了多种集合通信实现:RING
使用 gRPC 作为跨主机通信层实现基于环的集合。NCCL
使用 NVIDIA Collective Communication Library 实现集合。AUTO
将选择推迟到运行时。
集合实现的最佳选择取决于 GPU 的数量、类型和集群中的网络互连。要重写自动选择,请指定
MultiWorkerMirroredStrategy
的构造函数的communication_options
参数。例如:communication_options=tf.distribute.experimental.CommunicationOptions(implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
如果可能的话,将变量强制转换为
tf.float
。官方 ResNet 模型包括如何完成此操作的示例。
容错#
在同步训练中,如果其中一个工作进程出现故障并且不存在故障恢复机制,则集群将失败。
在工作进程退出或不稳定的情况下,将 Keras 与 tf.distribute.Strategy
结合使用会具有容错的优势。您可以通过在您选择的分布式文件系统中保留训练状态来做到这一点,以便在重新启动先前失败或被抢占的实例后,将恢复训练状态。
当一个工作进程不可用时,其他工作进程将失败(可能先发生超时)。在这种情况下,需要重新启动不可用的工作进程以及其他失败的工作进程。
注:之前,ModelCheckpoint
回调提供了一种在从多工作进程训练的作业失败重启时恢复训练状态的机制。TensorFlow 团队引入了一个新的 BackupAndRestore
回调,这也添加了对单个工作进程训练的支持以获得一致的体验,并从现有 ModelCheckpoint
回调中移除了容错功能。从现在开始,依赖于此行为的应用应迁移到新 BackupAndRestore
回调。
ModelCheckpoint
回调#
ModelCheckpoint
回调不再提供容错功能,请改用 BackupAndRestore
回调。
ModelCheckpoint
回调仍可用于保存检查点。但使用此回调时,当训练中断或成功完成后,如果要继续从检查点进行训练,用户必须手动加载模型。
另外,用户也可以选择在 ModelCheckpoint
回调之外保存和恢复模型/权重。
模型保存和加载#
要使用 model.save
或 tf.saved_model.save
保存模型,每个工作进程需具有不同的保存目标。
对于非首席工作进程,您需要将模型保存到临时目录。
对于首席工作进程,您需要保存到提供的模型目录。
工作进程上的临时目录必须唯一,以防止多个工作进程尝试写入同一位置而导致错误。
所有目录中保存的模型都是相同的,通常只需要引用首席工作进程保存的模型进行恢复或应用。
您应该有一些清理逻辑,可以在训练完成后删除工作进程创建的临时目录。
需要同时在首席工作进程和其他工作进程上保存的原因是您可能会在读取检查点时聚合变量,这需要首席工作进程和其他工作进程都参与 AllReduce 通信协议。另一方面,让首席工作进程和其他工作进程保存到同一个模型目录中会因为争用路径而导致错误。
通过使用 MultiWorkerMirroredStrategy
,程序会在每个工作进程上运行,它利用了具有 task_type
和 task_id
特性的集群解析器对象来确定当前的工作进程是否为首席工作进程:
task_type
会告诉您当前的作业是什么(例如'worker'
)。task_id
会告诉您工作进程的标识符。task_id == 0
的工作进程会被指定为首席工作进程。
在下面的代码段中,write_filepath
函数提供了要写入的文件路径,这取决于工作进程的 task_id
:
对于首席工作进程 (
task_id == 0
),它会写入原始文件路径。对于其他工作进程,它会创建一个临时目录
temp_dir
,并在目录路径中使用task_id
来写入:
model_path = '/tmp/keras-model'
def _is_chief(task_type, task_id):
# Note: there are two possible `TF_CONFIG` configurations.
# 1) In addition to `worker` tasks, a `chief` task type is use;
# in this case, this function should be modified to
# `return task_type == 'chief'`.
# 2) Only `worker` task type is used; in this case, worker 0 is
# regarded as the chief. The implementation demonstrated here
# is for this case.
# For the purpose of this Colab section, the `task_type` is `None` case
# is added because it is effectively run with only a single worker.
return (task_type == 'worker' and task_id == 0) or task_type is None
def _get_temp_dir(dirpath, task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join(dirpath, base_dirpath)
tf.io.gfile.makedirs(temp_dir)
return temp_dir
def write_filepath(filepath, task_type, task_id):
dirpath = os.path.dirname(filepath)
base = os.path.basename(filepath)
if not _is_chief(task_type, task_id):
dirpath = _get_temp_dir(dirpath, task_id)
return os.path.join(dirpath, base)
task_type, task_id = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)
随后,您就可以保存了:
已弃用:对于 Keras 对象,建议使用新的高级 .keras
格式和 tf.keras.Model.export
,如此处的指南所示。对于现有代码,继续支持低级 SavedModel 格式。
multi_worker_model.save(write_model_path)
如上所述,稍后应仅从保存首席工作进程的文件路径中加载模型。因此,请移除保存非首席工作进程的临时路径:
if not _is_chief(task_type, task_id):
tf.io.gfile.rmtree(os.path.dirname(write_model_path))
接下来,当需要加载时,使用方便的 tf.keras.models.load_model
API,并继续进一步的工作。
在这里,假设仅使用单个工作进程加载并继续训练,在这种情况下,您不会在另一个 strategy.scope()
中调用 tf.keras.models.load_model
(请注意,如前面所定义,strategy = tf.distribute.MultiWorkerMirroredStrategy()
):
loaded_model = tf.keras.models.load_model(model_path)
# Now that the model is restored, and can continue with the training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)
检查点保存和恢复#
另一方面,您可以使用检查点保存并恢复模型的权重,而无需保存整个模型。
在这里,您将创建一个由 tf.train.CheckpointManager
管理的跟踪模型的 tf.train.Checkpoint
,以便仅保留最新的检查点:
checkpoint_dir = '/tmp/ckpt'
checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
CheckpointManager
设置完成后,您就可以保存并移除非首席工作进程保存的检查点了:
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
tf.io.gfile.rmtree(write_checkpoint_dir)
现在,当需要恢复模型时,您可以使用方便的 tf.train.latest_checkpoint
函数找到保存的最新检查点。恢复检查点后,您可以继续进行训练。
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
BackupAndRestore
回调#
tf.keras.callbacks.BackupAndRestore
回调可通过在 BackupAndRestore
的 backup_dir
参数下的临时检查点文件中备份模型和当前训练状态来提供容错功能。
注:在 TensorFlow 2.9 中,当前模型和训练状态在周期边界处备份。在 tf-nightly
版和 TensorFlow 2.10 中,BackupAndRestore
回调可以在周期或步骤边界备份模型和训练状态。BackupAndRestore
接受可选的 save_freq
参数。save_freq
接受 'epoch'
或 int
值。如果 save_freq
设置为 'epoch'
,则模型会在每个周期后备份。如果 save_freq
设置为大于 0
的整数值,则在每 save_freq
个批次后备份模型。
作业中断并重新启动后,BackupAndRestore
回调将恢复上一个检查点,您可以从周期的开始和上次保存训练状态的步骤继续训练。
要使用该回调,请在 Model.fit
调用中提供 tf.keras.callbacks.BackupAndRestore
的实例。
使用 MultiWorkerMirroredStrategy
时,如果一个工作进程被中断,则整个集群都会暂停,直到被中断的工作进程重新启动为止。其他工作进程也会重新启动,且中断的工作进程将重新加入集群。然后,每个工作进程都会读取先前保存的检查点文件并获取其以前的状态,从而使集群恢复同步。然后即可继续训练。分布式数据集迭代器状态将重新初始化,而不会恢复。
The BackupAndRestore
callback uses the CheckpointManager
to save and restore the training state, which generates a file called checkpoint that tracks existing checkpoints together with the latest one. For this reason, backup_dir
should not be re-used to store other checkpoints in order to avoid name collision.
目前,BackupAndRestore
回调支持无策略 (MirroredStrategy
) 的单工作进程训练和采用 MultiWorkerMirroredStrategy
的多工作进程训练。
下面是两个多工作进程训练和单工作进程训练的示例:
# Multi-worker training with `MultiWorkerMirroredStrategy`
# and the `BackupAndRestore` callback. The training state
# is backed up at epoch boundaries by default.
callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
multi_worker_model = mnist_setup.build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
epochs=3,
steps_per_epoch=70,
callbacks=callbacks)
如果 BackupAndRestore
回调中的 save_freq
参数设置为 'epoch'
,则模型会在每个周期后备份。
# The training state is backed up at epoch boundaries because `save_freq` is
# set to `epoch`.
callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
multi_worker_model = mnist_setup.build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
epochs=3,
steps_per_epoch=70,
callbacks=callbacks)
注:下一个代码块使用了仅在 Tensorflow 2.10 发布后才能在 tf-nightly
中可用的功能。
如果 BackupAndRestore
回调中的 save_freq
参数设置为大于 0
的整数值,则在每 save_freq
个批次后备份模型。
# The training state is backed up at every 30 steps because `save_freq` is set
# to an integer value of `30`.
callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup', save_freq=30)]
with strategy.scope():
multi_worker_model = mnist_setup.build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
epochs=3,
steps_per_epoch=70,
callbacks=callbacks)
检查您在 BackupAndRestore
中指定的 backup_dir
目录时,您可能会注意到一些临时生成的检查点文件。在恢复之前丢失的实例时需要用到这些文件,而在成功退出训练后,它们将在 Model.fit
结束时被库移除。
注:目前 BackupAndRestore
回调仅支持 Eager 模式。在图形模式下,考虑将 Model.save
/tf.saved_model.save
和 tf.keras.models.load_model
分别用于保存和恢复模型,如上面的模型保存和加载部分所述,并在训练期间在 Model.fit
中提供 initial_epoch
。
其他资源#
TensorFlow 中的分布式训练指南概述了可用的分布式策略。
使用 Keras 和 MultiWorkerMirroredStrategy 的自定义训练循环教程展示了如何将
MultiWorkerMirroredStrategy
与 Keras 和自定义训练循环一起使用。查看官方模型,其中许多模型可以配置为运行多个分布式策略。
使用 tf.function 提升性能指南提供了有关其他策略和工具的信息,例如可用于优化 TensorFlow 模型性能的 TensorFlow Profiler。