##### Copyright 2021 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
迁移检查点保存#
在 TensorFlow.org 上查看 | 在 Google Colab 运行 | 在 Github 上查看源代码 | 下载笔记本 |
持续保存“最佳”模型或模型权重/参数有许多好处,包括能够跟踪训练进度并从不同的保存状态加载保存的模型。
在 TensorFlow 1 中,要使用 tf.estimator.Estimator
API 在训练/验证期间配置检查点保存,可以在 tf.estimator.RunConfig
中指定计划或使用 tf.estimator.CheckpointSaverHook
。本指南演示了如何从该工作流迁移到 TensorFlow 2 Keras API。
在 TensorFlow 2 中,可以通过多种方式配置 tf.keras.callbacks.ModelCheckpoint
:
根据使用
save_best_only=True
参数监视的指标保存“最佳”版本,其中monitor
可以是'loss'
、'val_loss'
、'accuracy'
或'val_accuracy'
。以特定频率持续保存(使用
save_freq
参数)。通过将
save_weights_only
设置为True
,仅保存权重/参数而不是整个模型。
有关详情,请参阅 tensorflow.keras.callbacks.ModelCheckpoint
API 文档和保存和加载模型教程中的训练期间保存检查点部分。在保存和加载 Keras 模型指南中的 TF 检查点格式部分中详细了解检查点格式。另外,要添加容错,可以使用 tf.keras.callbacks.BackupAndRestore
或 tf.train.Checkpoint
手动设置检查点。在容错迁移指南中了解详情。
Keras 回调是在内置 Keras Model.fit
/Model.evaluate
/Model.predict
API 中的训练/评估/预测期间的不同点调用的对象。请在指南末尾的后续步骤部分中了解详情。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 设置日志级别为ERROR,以减少警告信息
# 禁用 Gemini 的底层库(gRPC 和 Abseil)在初始化日志警告
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "3" # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL
os.environ["GLOG_minloglevel"] = "true"
import logging
import tensorflow as tf
tf.get_logger().setLevel(logging.ERROR)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
!export TF_FORCE_GPU_ALLOW_GROWTH=true
from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(parents=True, exist_ok=True)
安装#
从导入和用于演示目的的简单数据集开始:
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1:使用 tf.estimator API 保存检查点#
此 TensorFlow 1 示例展示了如何配置 tf.estimator.RunConfig
以在使用 tf.estimator.Estimator
API 进行训练/评估期间的每一步保存检查点:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.c.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp(dir=temp_dir)
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
test_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_test},
y=y_test.astype(np.int32),
num_epochs=10,
shuffle=False
)
train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
steps=10,
throttle_secs=0)
tf1.estimator.train_and_evaluate(estimator=classifier,
train_spec=train_spec,
eval_spec=eval_spec)
%ls {classifier.model_dir}
TensorFlow 2:使用 Model.fit 的 Keras 回调保存检查点#
在 TensorFlow 2 中,使用内置 Keras Model.fit
(或 Model.evaluate
)进行训练/评估时,可以配置 tf.keras.callbacks.ModelCheckpoint
,然后将其传递给 Model.fit
(或 Model.evaluate
)的 callbacks
参数。(请在 API 文档和使用内置方法进行训练和评估指南中的使用回调部分中了解详情。)
在下面的示例中,您将使用 tf.keras.callbacks.ModelCheckpoint
回调将检查点存储在临时目录中:
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp(dir=temp_dir)
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=f"{log_dir}/test.keras")
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[model_checkpoint_callback])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[15], line 20
15 log_dir = tempfile.mkdtemp(dir=temp_dir)
17 model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
18 filepath=f"{log_dir}/test.keras")
---> 20 model.fit(x=x_train,
21 y=y_train,
22 epochs=10,
23 validation_data=(x_test, y_test),
24 callbacks=[model_checkpoint_callback])
File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/context.py:657, in Context.ensure_initialized(self)
654 pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True)
655 pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite(
656 opts, self._jit_compile_rewrite)
--> 657 context_handle = pywrap_tfe.TFE_NewContext(opts)
658 finally:
659 pywrap_tfe.TFE_DeleteContextOptions(opts)
RuntimeError: Bad StatusOr access: INTERNAL: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 25428426752
%ls {model_checkpoint_callback.filepath}
后续步骤#
在以下资源中详细了解检查点:
API 文档:
tf.keras.callbacks.ModelCheckpoint
教程:保存和加载模型(训练期间保存检查点部分)
指南:保存和加载 Keras 模型(TF 检查点格式部分)
以下资源中详细了解回调:
API 文档:
tf.keras.callbacks.Callback
指南:编写自己的回调
指南:使用内置方法进行训练和评估(使用回调部分)
此外,您可能还会发现下列与迁移相关的资源十分有用:
容错迁移指南:用于
Model.fit
的tf.keras.callbacks.BackupAndRestore
,或用于自定义训练循环的tf.train.Checkpoint
和tf.train.CheckpointManager
API提前停止迁移指南:
tf.keras.callbacks.EarlyStopping
是一个内置的提前停止回调TensorBoard 迁移指南:TensorBoard 支持跟踪和显示指标