##### Copyright 2021 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

迁移检查点保存#

在 TensorFlow.org 上查看 在 Google Colab 运行 在 Github 上查看源代码 下载笔记本

持续保存“最佳”模型或模型权重/参数有许多好处,包括能够跟踪训练进度并从不同的保存状态加载保存的模型。

在 TensorFlow 1 中,要使用 tf.estimator.Estimator API 在训练/验证期间配置检查点保存,可以在 tf.estimator.RunConfig 中指定计划或使用 tf.estimator.CheckpointSaverHook。本指南演示了如何从该工作流迁移到 TensorFlow 2 Keras API。

在 TensorFlow 2 中,可以通过多种方式配置 tf.keras.callbacks.ModelCheckpoint

  • 根据使用 save_best_only=True 参数监视的指标保存“最佳”版本,其中 monitor 可以是 'loss''val_loss''accuracy''val_accuracy'

  • 以特定频率持续保存(使用 save_freq 参数)。

  • 通过将 save_weights_only 设置为 True,仅保存权重/参数而不是整个模型。

有关详情,请参阅 tensorflow.keras.callbacks.ModelCheckpoint API 文档和保存和加载模型教程中的训练期间保存检查点部分。在保存和加载 Keras 模型指南中的 TF 检查点格式部分中详细了解检查点格式。另外,要添加容错,可以使用 tf.keras.callbacks.BackupAndRestoretf.train.Checkpoint 手动设置检查点。在容错迁移指南中了解详情。

Keras 回调是在内置 Keras Model.fit/Model.evaluate/Model.predict API 中的训练/评估/预测期间的不同点调用的对象。请在指南末尾的后续步骤部分中了解详情。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 设置日志级别为ERROR,以减少警告信息
# 禁用 Gemini 的底层库(gRPC 和 Abseil)在初始化日志警告
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "3"  # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL
os.environ["GLOG_minloglevel"] = "true"
import logging
import tensorflow as tf
tf.get_logger().setLevel(logging.ERROR)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
!export TF_FORCE_GPU_ALLOW_GROWTH=true

from pathlib import Path

temp_dir = Path(".temp")
temp_dir.mkdir(parents=True, exist_ok=True)

安装#

从导入和用于演示目的的简单数据集开始:

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1:使用 tf.estimator API 保存检查点#

此 TensorFlow 1 示例展示了如何配置 tf.estimator.RunConfig 以在使用 tf.estimator.Estimator API 进行训练/评估期间的每一步保存检查点:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

config = tf1.c.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp(dir=temp_dir)

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)

tf1.estimator.train_and_evaluate(estimator=classifier,
                                train_spec=train_spec,
                                eval_spec=eval_spec)
%ls {classifier.model_dir}

TensorFlow 2:使用 Model.fit 的 Keras 回调保存检查点#

在 TensorFlow 2 中,使用内置 Keras Model.fit(或 Model.evaluate)进行训练/评估时,可以配置 tf.keras.callbacks.ModelCheckpoint,然后将其传递给 Model.fit(或 Model.evaluate)的 callbacks 参数。(请在 API 文档和使用内置方法进行训练和评估指南中的使用回调部分中了解详情。)

在下面的示例中,您将使用 tf.keras.callbacks.ModelCheckpoint 回调将检查点存储在临时目录中:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp(dir=temp_dir)

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=f"{log_dir}/test.keras")

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 20
     15 log_dir = tempfile.mkdtemp(dir=temp_dir)
     17 model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
     18     filepath=f"{log_dir}/test.keras")
---> 20 model.fit(x=x_train,
     21           y=y_train,
     22           epochs=10,
     23           validation_data=(x_test, y_test),
     24           callbacks=[model_checkpoint_callback])

File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/context.py:657, in Context.ensure_initialized(self)
    654   pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True)
    655   pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite(
    656       opts, self._jit_compile_rewrite)
--> 657   context_handle = pywrap_tfe.TFE_NewContext(opts)
    658 finally:
    659   pywrap_tfe.TFE_DeleteContextOptions(opts)

RuntimeError: Bad StatusOr access: INTERNAL: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 25428426752
%ls {model_checkpoint_callback.filepath}

后续步骤#

在以下资源中详细了解检查点:

以下资源中详细了解回调:

此外,您可能还会发现下列与迁移相关的资源十分有用: