{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "Tce3stUlHN0L" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "# Distributed training with Keras" ] }, { "cell_type": "markdown", "metadata": { "id": "r6P32iYYV27b" }, "source": [
"A custom callback called `PrintLR` displays the learning rate in the notebook.\n",
"\n",
"**Note:** Use the `BackupAndRestore` callback instead of `ModelCheckpoint` as the main mechanism for restoring the training state when restarting after a job failure. Because `BackupAndRestore` only supports eager mode, consider using `ModelCheckpoint` in graph mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A9bwLCcXzSgy",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Define the checkpoint directory to store the checkpoints.\n",
"checkpoint_dir = './training_checkpoints'\n",
"# Define the name of the checkpoint files.\n",
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")"
]
},
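{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustration of the `{epoch}` placeholder in `checkpoint_prefix`, the sketch below fills it in with `str.format`. It assumes that `ModelCheckpoint` substitutes a 1-based epoch number in the same way; `epoch_number` is a name introduced here for illustration only.\n",
"\n",
"```python\n",
"import os\n",
"\n",
"checkpoint_dir = './training_checkpoints'\n",
"checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')\n",
"\n",
"# Illustration only: fill the placeholder with str.format,\n",
"# assuming 1-based epoch numbers.\n",
"for epoch_number in (1, 2, 3):\n",
"  print(checkpoint_prefix.format(epoch=epoch_number))\n",
"```\n",
"\n",
"Each epoch therefore gets its own checkpoint prefix rather than overwriting the previous one."
]
},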
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wpU-BEdzJDbK",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Define a function for decaying the learning rate.\n",
"# You can define any decay function you need.\n",
"def decay(epoch):\n",
"  if epoch < 3:\n",
"    return 1e-3\n",
"  elif epoch < 7:\n",
"    return 1e-4\n",
"  else:\n",
"    return 1e-5"
]
},
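{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the `decay` function above produces, this sketch tabulates the learning rate for each epoch index. It assumes that `LearningRateScheduler` passes 0-based epoch indices and uses `EPOCHS = 12`, the value set later in this tutorial; `schedule` is an illustrative name.\n",
"\n",
"```python\n",
"def decay(epoch):\n",
"  # Same piecewise-constant schedule as in the cell above.\n",
"  if epoch < 3:\n",
"    return 1e-3\n",
"  elif epoch < 7:\n",
"    return 1e-4\n",
"  else:\n",
"    return 1e-5\n",
"\n",
"EPOCHS = 12  # Assumed to match the value used for training.\n",
"schedule = [decay(epoch) for epoch in range(EPOCHS)]\n",
"\n",
"for epoch, lr in enumerate(schedule):\n",
"  print('epoch index', epoch, '-> learning rate', lr)\n",
"```\n",
"\n",
"Under these assumptions, the first three epochs use 1e-3, the next four use 1e-4, and the remaining five use 1e-5."
]
},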
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jKhiMgXtKq2w",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Define a callback for printing the learning rate at the end of each epoch.\n",
"class PrintLR(tf.keras.callbacks.Callback):\n",
"  def on_epoch_end(self, epoch, logs=None):\n",
"    # Use the model attached to this callback rather than a global variable.\n",
"    print('\\nLearning rate for epoch {} is {}'.format(\n",
"        epoch + 1, self.model.optimizer.learning_rate.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YVqAbR6YyNQh",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Put all the callbacks together.\n",
"callbacks = [\n",
"    tf.keras.callbacks.TensorBoard(log_dir='./logs'),\n",
"    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,\n",
"                                       save_weights_only=True),\n",
"    tf.keras.callbacks.LearningRateScheduler(decay),\n",
"    PrintLR()\n",
"]"
]
},
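{
"cell_type": "markdown",
"metadata": {},
"source": [
"The note earlier in this section recommends `BackupAndRestore` for recovering from job failures. A minimal sketch of creating that callback follows, assuming a TensorFlow version that provides `tf.keras.callbacks.BackupAndRestore`; the `./backup` directory and the `backup_callback` name are hypothetical and not part of this tutorial.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# Hypothetical directory for temporary backup state.\n",
"backup_dir = './backup'\n",
"\n",
"backup_callback = tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir)\n",
"\n",
"# The callback would be passed to Model.fit with the others, for example:\n",
"# model.fit(train_dataset, epochs=EPOCHS,\n",
"#           callbacks=callbacks + [backup_callback])\n",
"```\n",
"\n",
"The `Model.fit` call is left as a comment because it depends on the model and dataset defined elsewhere in this tutorial."
]
},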
{
"cell_type": "markdown",
"metadata": {
"id": "70HXgDQmK46q"
},
"source": [
"## Train and evaluate"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6EophnOAB3YD"
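},
"source": [
"Now, train the model in the usual way by calling Keras `Model.fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether or not you distribute the training."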
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7MVw_6CqB3YD",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"EPOCHS = 12\n",
"\n",
"model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NUcWAUUupIvG"
},
"source": [
"Check for saved checkpoints:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JQ4zeSTxKEhB",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Check the checkpoint directory.\n",
"!ls {checkpoint_dir}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qor53h7FpMke"
},
"source": [
"To check how well the model performs, load the latest checkpoint and call `Model.evaluate` on the test data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JtEwxiTgpQoP",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))\n",
"\n",
"eval_loss, eval_acc = model.evaluate(eval_dataset)\n",
"\n",
"print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))"
]
},
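{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tf.train.latest_checkpoint` returns `None` when the directory contains no checkpoint, in which case there is nothing for `Model.load_weights` to load. The sketch below shows one way to guard against that case, using a fresh temporary directory (`empty_dir`, a name introduced here) as a stand-in for an empty checkpoint directory.\n",
"\n",
"```python\n",
"import tempfile\n",
"\n",
"import tensorflow as tf\n",
"\n",
"# A new temporary directory in which nothing has been saved yet.\n",
"empty_dir = tempfile.mkdtemp()\n",
"\n",
"latest = tf.train.latest_checkpoint(empty_dir)\n",
"if latest is None:\n",
"  print('No checkpoint found in', empty_dir)\n",
"else:\n",
"  print('Latest checkpoint prefix:', latest)\n",
"```\n",
"\n",
"In the tutorial itself, the same check could be applied to `checkpoint_dir` before loading the weights."
]
},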
{
"cell_type": "markdown",
"metadata": {
"id": "IIeF2RWfYu4N"
},
"source": [
"To visualize the output, launch TensorBoard and view the logs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vtyAZO0DoKu_",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"%tensorboard --logdir=logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LnyscOkvKKBR",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"!ls -sh ./logs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kBLlogrDvMgg"
},
"source": [
"## Save the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xa87y_A0vRma"
},
"source": [
"Save the model to a `.keras` zip archive using `Model.save`. After the model is saved, you can load it with or without `Strategy.scope`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "h8Q4MKOLwG7K",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"path = 'my_model.keras'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4HvcDmVsvQoa",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model.save(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vKJT4w5JwVPI"
},
"source": [
"Now, load the model without `Strategy.scope`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T_gT0RbRvQ3o",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"unreplicated_model = tf.keras.models.load_model(path)\n",
"\n",
"unreplicated_model.compile(\n",
"    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"    optimizer=tf.keras.optimizers.Adam(),\n",
"    metrics=['accuracy'])\n",
"\n",
"eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)\n",
"\n",
"print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YBLzcRF0wbDe"
},
"source": [
"Load the model with `Strategy.scope`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BBVo3WGGwd9a",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"with strategy.scope():\n",
"  replicated_model = tf.keras.models.load_model(path)\n",
"  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"                           optimizer=tf.keras.optimizers.Adam(),\n",
"                           metrics=['accuracy'])\n",
"\n",
"  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)\n",
"  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MUZwaz4AKjtD"
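},
"source": [
"### Additional resources\n",
"\n",
"More examples that use different distribution strategies with the Keras `Model.fit` API:\n",
"\n",
"1. The [Solve GLUE tasks using BERT on TPU](https://tensorflow.google.cn/text/tutorials/bert_glue) tutorial uses `tf.distribute.MirroredStrategy` for training on GPUs and `tf.distribute.TPUStrategy` for training on TPUs.\n",
"2. The [Save and load a model using a distribution strategy](save_and_load.ipynb) tutorial demonstrates how to use the SavedModel APIs with `tf.distribute.Strategy`.\n",
"3. The [official TensorFlow models](https://github.com/tensorflow/models/tree/master/official) can be configured to run multiple distribution strategies.\n",
"\n",
"To learn more about TensorFlow distribution strategies, see the following resources:\n",
"\n",
"1. The [Custom training with tf.distribute.Strategy](custom_training.ipynb) tutorial shows how to use `tf.distribute.MirroredStrategy` for single-worker training with a custom training loop.\n",
"2. The [Multi-worker training with Keras](multi_worker_with_keras.ipynb) tutorial shows how to use `MultiWorkerMirroredStrategy` with `Model.fit`.\n",
"3. The [Custom training loop with Keras and MultiWorkerMirroredStrategy](multi_worker_with_ctl.ipynb) tutorial shows how to use `MultiWorkerMirroredStrategy` with Keras and a custom training loop.\n",
"4. The [Distributed training in TensorFlow](https://tensorflow.google.cn/guide/distributed_training) guide provides an overview of the available distribution strategies.\n",
"5. The [Better performance with tf.function](../../guide/function.ipynb) guide provides information about other strategies and tools, such as the [TensorFlow Profiler](../../guide/profiler.md), that you can use to optimize the performance of your TensorFlow models.\n",
"\n",
"Note: `tf.distribute.Strategy` is under active development, and TensorFlow will add more examples and tutorials in the near future. Please give it a try. Feedback is welcome via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new)."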
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "keras.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}