##### Copyright 2021 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
验证正确性和数值等价性#
将 TensorFlow 代码从 TF1.x 迁移到 TF2 时,最好确保迁移的代码在 TF2 中的行为方式与在 TF1.x 中的行为方式相同。
本指南涵盖了将 tf.compat.v1.keras.utils.track_tf1_style_variables
建模填充码应用于 tf.keras.layers.Layer
方法的迁移代码示例。阅读模型映射指南来详细了解 TF2 建模填充码。
本指南详细介绍了可以使用的方式:
使用迁移的代码验证从训练模型中获得的结果的正确性
验证您的代码在不同 TensorFlow 版本中的数值等价性
安装#
!pip uninstall -y -q tensorflow
# Install tf-nightly as the DeterministicRandomTestTool is available only in
# Tensorflow 2.8
!pip install -q tf-nightly
!pip install -q tf_slim
import tensorflow as tf
import tensorflow.compat.v1 as v1
import numpy as np
import tf_slim as slim
import sys
from contextlib import contextmanager
!git clone --depth=1 https://github.com/tensorflow/models.git
import models.research.slim.nets.inception_resnet_v2 as inception
当您将一个非常重要的前向传递代码块放入填充码时,您希望知道它的行为方式与在 TF1.x 中的行为方式是否相同。例如,考虑尝试将整个 TF-Slim Inception-Resnet-v2 模型放入填充码,如下所示:
# TF1 Inception resnet v2 forward pass based on slim layers
def inception_resnet_v2(inputs, num_classes, is_training):
with slim.arg_scope(
inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
return inception.inception_resnet_v2(inputs, num_classes, is_training=is_training)
class InceptionResnetV2(tf.keras.layers.Layer):
"""Slim InceptionResnetV2 forward pass as a Keras layer"""
def __init__(self, num_classes, **kwargs):
super().__init__(**kwargs)
self.num_classes = num_classes
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs, training=None):
is_training = training or False
# Slim does not accept `None` as a value for is_training,
# Keras will still pass `None` to layers to construct functional models
# without forcing the layer to always be in training or in inference.
# However, `None` is generally considered to run layers in inference.
with slim.arg_scope(
inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
return inception.inception_resnet_v2(
inputs, self.num_classes, is_training=is_training)
碰巧的是,此层实际上完美实现了开箱即用(完成了准确的正则化损失跟踪)。
但是,这不是您想认为理所当然的事情。按照下面的步骤验证它的行为是否与在 TF1.x 中的行为一样,直至观察到完美的数值等价。这些步骤还可以帮助您定位前向传递的哪一部分导致与 TF1.x 间的散度(确定散度是否出现在模型前向传递中,而不是模型的不同部分)。
第 1 步:验证变量是否只创建一次#
应当验证的第一件事是您已经以在每次调用中重用变量的方式正确地构建了模型,而不是每次都意外地创建和使用新变量。例如,如果模型创建一个新的 Keras 层或在每个前向传递调用中调用 tf.Variable
,那么它很可能无法捕获变量并每次都创建新变量。
下面是两个上下文管理器范围,可以使用它们来检测模型何时创建新变量并调试模型的哪个部分正在执行它。
@contextmanager
def assert_no_variable_creations():
"""Assert no variables are created in this context manager scope."""
def invalid_variable_creator(next_creator, **kwargs):
raise ValueError("Attempted to create a new variable instead of reusing an existing one. Args: {}".format(kwargs))
with tf.variable_creator_scope(invalid_variable_creator):
yield
@contextmanager
def catch_and_raise_created_variables():
"""Raise all variables created within this context manager scope (if any)."""
created_vars = []
def variable_catcher(next_creator, **kwargs):
var = next_creator(**kwargs)
created_vars.append(var)
return var
with tf.variable_creator_scope(variable_catcher):
yield
if created_vars:
raise ValueError("Created vars:", created_vars)
一旦您尝试在范围内创建变量,第一个范围 (assert_no_variable_creations()
) 将立即引发错误。这样,您就可以检查堆栈跟踪(并使用交互式调试)来准确确定哪些代码行创建了变量,而不是重用现有变量。
如果最终创建了任何变量,则第二个范围 (catch_and_raise_created_variables()
) 将在范围结束时引发异常。此异常将包括在范围内创建的所有变量的列表。假如您可以发现一般模式,那么这对于查明您的模型正在创建的所有权重集是什么非常有用。但是,它对于确定创建这些变量的确切代码行的用处不大。
使用下面的两个范围来验证基于填充码的 InceptionResnetV2 层在第一次调用后是否会创建任何新变量(可能是重用它们)。
model = InceptionResnetV2(1000)
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
# Create all weights on the first call
model(inputs)
# Verify that no new weights are created in followup calls
with assert_no_variable_creations():
model(inputs)
with catch_and_raise_created_variables():
model(inputs)
在下面的示例中,观察这些装饰器如何在每次错误地创建新权重而不是重用现有权重的层上工作。
class BrokenScalingLayer(tf.keras.layers.Layer):
"""Scaling layer that incorrectly creates new weights each time:"""
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs):
var = tf.Variable(initial_value=2.0)
bias = tf.Variable(initial_value=2.0, name='bias')
return inputs * var + bias
model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
try:
with assert_no_variable_creations():
model(inputs)
except ValueError as err:
import traceback
traceback.print_exc()
model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
try:
with catch_and_raise_created_variables():
model(inputs)
except ValueError as err:
print(err)
可以通过确保它只创建一次权重然后每次重用来修正该层。
class FixedScalingLayer(tf.keras.layers.Layer):
"""Scaling layer that incorrectly creates new weights each time:"""
def __init__(self):
super().__init__()
self.var = None
self.bias = None
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs):
if self.var is None:
self.var = tf.Variable(initial_value=2.0)
self.bias = tf.Variable(initial_value=2.0, name='bias')
return inputs * self.var + self.bias
model = FixedScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
with assert_no_variable_creations():
model(inputs)
with catch_and_raise_created_variables():
model(inputs)
问题排查#
下面是模型可能会意外创建新权重而不是重用现有权重的一些常见原因:
它使用显式
tf.Variable
调用而不是重用已经创建的tf.Variables
。先检查是否尚未创建,然后重用现有选项可以修正此问题。它每次都直接在前向传递中创建一个 Keras 层或模型(相对于
tf.compat.v1.layers
)。先检查是否尚未创建,然后重用现有选项可以修正此问题。它在
tf.compat.v1.layers
之上构建,但未能为所有compat.v1.layers
分配显式名称或将compat.v1.layer
用法包装在一个命名的variable_scope
内,导致自动生成的层名称在每个模型调用中递增。将命名的tf.compat.v1.variable_scope
置于包装所有tf.compat.v1.layers
用法的填充码装饰方法内可以修正此问题。
第 2 步:检查变量计数、名称和形状是否匹配#
第二步是确保在 TF2 中运行的层创建具有相同形状的相同数量权重,就像 TF1.x 中的相应代码一样。
您可以混合手动检查它们以查看其是否匹配,并在单元测试中以编程方式执行检查,代码如下所示。
# Build the forward pass inside a TF1.x graph, and
# get the counts, shapes, and names of the variables
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
tf1_variable_names_and_shapes = {
var.name: (var.trainable, var.shape) for var in tf.compat.v1.global_variables()}
num_tf1_variables = len(tf.compat.v1.global_variables())
接下来,对 TF2 中的填充码包装层执行相同的过程。请注意,在获取权重之前,模型也会被多次调用。这样做是为了高效地测试变量重用。
height, width = 299, 299
num_classes = 1000
model = InceptionResnetV2(num_classes)
# The weights will not be created until you call the model
inputs = tf.ones( (1, height, width, 3))
# Call the model multiple times before checking the weights, to verify variables
# get reused rather than accidentally creating additional variables
out, endpoints = model(inputs, training=False)
out, endpoints = model(inputs, training=False)
# Grab the name: shape mapping and the total number of variables separately,
# because in TF2 variables can be created with the same name
num_tf2_variables = len(model.variables)
tf2_variable_names_and_shapes = {
var.name: (var.trainable, var.shape) for var in model.variables}
# Verify that the variable counts, names, and shapes all match:
assert num_tf1_variables == num_tf2_variables
assert tf1_variable_names_and_shapes == tf2_variable_names_and_shapes
基于填充码的 InceptionResnetV2 层通过了此测试。但是,在它们不匹配的情况下,可以通过 diff(文本或其他)运行它来查看有何差异。
这样可以提供关于模型哪些部分的行为不符合预期的线索。通过 Eager Execution,可以使用 pdb、交互式调试和断点来挖掘模型中看起来可疑的部分,并更深入地调试出现的问题。
问题排查#
密切注意由显式
tf.Variable
调用和 Keras 层/模型直接创建的任何变量的名称,因为它们的变量名称生成语义在 TF1.x 计算图和 TF2 功能(例如 Eager Execution 和tf.function
)之间可能略有不同,即使其他一切正常。如果是这种情况,请调整您的测试以适应任何稍有不同的命名语义。您有时可能会发现,在训练循环的前向传递中创建的
tf.Variable
、tf.keras.layers.Layer
或tf.keras.Model
在 TF2 变量列表中缺失,即使它们是由 TF1.x 中的变量集合捕获的。将前向传递创建的变量/层/模型分配给模型中的实例特性可以修正此问题。请参阅此处了解详情。
第 3 步:重置所有变量,在停用所有随机性的情况下检查数值等价性#
下一步是在您修正模型时验证实际输出和正则化损失跟踪的数值等价性,这样便不涉及随机数生成(例如在推断期间)。
完成此操作的确切方式可能取决于您的特定模型,但在大多数模型(例如此模型)中,可以通过下列方式实现此目标:
将权重初始化为没有随机性的相同值。这可以通过在创建后将它们重置为固定值来完成。
在推断模式下运行模型以避免触发任何可能成为随机性来源的随机失活层。
以下代码演示了如何以这种方式比较 TF1.x 和 TF2 结果。
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
# Rather than running the global variable initializers,
# reset all variables to a constant value
var_reset = tf.group([var.assign(tf.ones_like(var) * 0.001) for var in tf.compat.v1.global_variables()])
sess.run(var_reset)
# Grab the outputs & regularization loss
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
tf1_output[0][:5]
获取 TF2 结果。
height, width = 299, 299
num_classes = 1000
model = InceptionResnetV2(num_classes)
inputs = tf.ones((1, height, width, 3))
# Call the model once to create the weights
out, endpoints = model(inputs, training=False)
# Reset all variables to the same fixed value as above, with no randomness
for var in model.variables:
var.assign(tf.ones_like(var) * 0.001)
tf2_output, endpoints = model(inputs, training=False)
# Get the regularization loss
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
tf2_output[0][:5]
# Create a dict of tolerance values
tol_dict={'rtol':1e-06, 'atol':1e-05}
# Verify that the regularization loss and output both match
# when we fix the weights and avoid randomness by running inference:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
当您移除随机性来源时,TF1.x 与 TF2 之间的数值匹配,并且与 TF2 兼容的 InceptionResnetV2
层通过了测试。
如果您正在观察与自己的模型存在分歧的结果,可以使用打印或 pdb 和交互式调试来确定结果开始出现分歧的位置和原因。Eager Execution 可以让此过程变得更容易。此外,还可以使用消融方式仅对固定的中间输入运行模型的一小部分,并隔离出现散度的位置。
方便的是,许多填充码网络(和其他模型)也会显露您可以探测的中间端点。
第 4 步:对齐随机数生成,检查训练和推断中的数值等价性#
最后一步是验证 TF2 模型在数值上是否与 TF1.x 模型匹配,即使在变量初始化和前向传递本身(例如前向传递期间的随机失活层)中考虑随机数生成时,也是如此。
可以通过使用下面的测试工具来使随机数生成语义在 TF1.x 计算图/会话与 Eager Execution 之间匹配。
TF1 旧版计算图/会话和 TF2 Eager Execution 使用不同的有状态随机数生成语义。
在 tf.compat.v1.Session
中,如果没有指定种子,则随机数的生成取决于添加随机运算时计算图中有多少运算,以及该计算图运行了多少次。在 Eager Execution 中,有状态随机数生成取决于全局种子、运算随机种子以及带有给定随机种子的运算的运行次数。有关详情,请参阅 tf.random.set_seed
。
以下 v1.keras.utils.DeterministicRandomTestTool
类提供了一个上下文管理器 scope()
,它可以使有状态随机运算在 TF1 计算图/会话和 Eager Execution 中使用相同的种子。
此工具提供了两种测试模式:
constant
,无论被调用过多少次,都会为每个单一运算使用相同的种子,以及num_random_ops
,使用先前观测到的有状态随机运算的数量作为运算种子。
这既适用于用于创建和初始化变量的有状态随机运算,也适用于计算中使用的有状态随机运算(例如用于随机失活层)。
生成三个随机张量来展示如何使用此工具在会话和 Eager Execution 之间进行有状态随机数生成匹配。
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
graph_a, graph_b, graph_c = sess.run([a, b, c])
graph_a, graph_b, graph_c
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
a, b, c
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict)
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)
但请注意,在 constant
模式下,由于 b
和 c
是使用相同的种子生成并且具有相同的形状,它们将具有完全相同的值。
np.testing.assert_allclose(b.numpy(), c.numpy(), **tol_dict)
跟踪顺序#
如果您担心在 constant
模式下匹配的某些随机数会降低对数值等价性测试的信心(例如,如果多个权重采用相同的初始化),则可以使用 num_random_ops
模式来避免这种情况。在 num_random_ops
模式下,生成的随机数将取决于程序中随机运算的顺序。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
graph_a, graph_b, graph_c = sess.run([a, b, c])
graph_a, graph_b, graph_c
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
a, b, c
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict )
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)
# Demonstrate that with the 'num_random_ops' mode,
# b & c took on different values even though
# their generated shape was the same
assert not np.allclose(b.numpy(), c.numpy(), **tol_dict)
但请注意,在这种模式下,随机生成对程序顺序非常敏感,因此下面生成的随机数不匹配。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
b_prime = tf.random.uniform(shape=(3,3))
b_prime = b_prime * 3
a_prime = tf.random.uniform(shape=(3,1))
a_prime = a_prime * 3
assert not np.allclose(a.numpy(), a_prime.numpy())
assert not np.allclose(b.numpy(), b_prime.numpy())
为了允许调试由于跟踪顺序而导致的变化,num_random_ops
模式下的 DeterministicRandomTestTool
允许查看使用 operation_seed
属性跟踪了多少随机运算。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
print(random_tool.operation_seed)
a = tf.random.uniform(shape=(3,1))
a = a * 3
print(random_tool.operation_seed)
b = tf.random.uniform(shape=(3,3))
b = b * 3
print(random_tool.operation_seed)
如果需要在测试中考虑不同的跟踪顺序,甚至可以显式设置自动递增 operation_seed
。例如,可以使用它来使随机数生成在两个不同的程序顺序之间匹配。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
print(random_tool.operation_seed)
a = tf.random.uniform(shape=(3,1))
a = a * 3
print(random_tool.operation_seed)
b = tf.random.uniform(shape=(3,3))
b = b * 3
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
random_tool.operation_seed = 1
b_prime = tf.random.uniform(shape=(3,3))
b_prime = b_prime * 3
random_tool.operation_seed = 0
a_prime = tf.random.uniform(shape=(3,1))
a_prime = a_prime * 3
np.testing.assert_allclose(a.numpy(), a_prime.numpy(), **tol_dict)
np.testing.assert_allclose(b.numpy(), b_prime.numpy(), **tol_dict)
但是,DeterministicRandomTestTool
不允许重用已使用的运算种子,因此请确保自动递增的序列不能重叠。这是因为 Eager Execution 会为相同运算种子的后续使用生成不同的数值,而 TF1 计算图和会话则不会,因此引发错误有助于保持会话和 Eager 有状态随机数生成一致。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
random_tool.operation_seed = 1
b_prime = tf.random.uniform(shape=(3,3))
b_prime = b_prime * 3
random_tool.operation_seed = 0
a_prime = tf.random.uniform(shape=(3,1))
a_prime = a_prime * 3
try:
c = tf.random.uniform(shape=(3,1))
raise RuntimeError("An exception should have been raised before this, " +
"because the auto-incremented operation seed will " +
"overlap an already-used value")
except ValueError as err:
print(err)
验证推断#
现在,您可以使用 DeterministicRandomTestTool
来确保 InceptionResnetV2
模型在推断中匹配,即使在使用随机权重初始化时也是如此。对于因匹配程序顺序而获得的更强测试条件,请使用 num_random_ops
模式。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
# Initialize the variables
sess.run(tf.compat.v1.global_variables_initializer())
# Grab the outputs & regularization loss
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
model = InceptionResnetV2(num_classes)
inputs = tf.ones((1, height, width, 3))
tf2_output, endpoints = model(inputs, training=False)
# Grab the regularization loss as well
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
验证训练#
由于 DeterministicRandomTestTool
适用于所有有状态随机运算(包括权重初始化和诸如随机失活层之类的计算),您也可以使用它来验证模型在训练模式下是否匹配。可以再次使用 num_random_ops
模式,因为有状态随机运算的程序顺序匹配。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)
# Initialize the variables
sess.run(tf.compat.v1.global_variables_initializer())
# Grab the outputs & regularization loss
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
model = InceptionResnetV2(num_classes)
inputs = tf.ones((1, height, width, 3))
tf2_output, endpoints = model(inputs, training=True)
# Grab the regularization loss as well
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
现在,您已经验证在 tf.keras.layers.Layer
周围使用装饰器以 Eager 方式运行的 InceptionResnetV2
模型在数值上与在 TF1 计算图和会话中运行的填充码网络匹配。
注:在 num_random_ops
模式下使用 DeterministicRandomTestTool
时,建议在测试数值等价性时直接使用和调用 tf.keras.layers.Layer
方法装饰器。将其嵌入到 Keras 函数模型或其他 Keras 模型中可能会在有状态随机运算跟踪顺序中产生差异,这会导致在比较 TF1.x 计算图/会话和 Eager Execution 时难以做出推断或精确匹配。
例如,使用 training=True
直接调用 InceptionResnetV2
层会根据网络创建顺序将变量初始化与随机失活顺序交错。
另一方面,首先将 tf.keras.layers.Layer
装饰器放入 Keras 函数模型中,随后使用 training=True
调用模型,这相当于先初始化所有变量然后使用随机失活层。这会产生不同的跟踪顺序和一组不同的随机数。
但是,默认的 mode='constant'
对跟踪顺序的这些差异不敏感,即使将层嵌入到 Keras 函数模型中,也无需额外工作即可传递。
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)
# Initialize the variables
sess.run(tf.compat.v1.global_variables_initializer())
# Get the outputs & regularization losses
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
keras_input = tf.keras.Input(shape=(height, width, 3))
layer = InceptionResnetV2(num_classes)
model = tf.keras.Model(inputs=keras_input, outputs=layer(keras_input))
inputs = tf.ones((1, height, width, 3))
tf2_output, endpoints = model(inputs, training=True)
# Get the regularization loss
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
第 3b 或 4b 步(可选):使用既有检查点进行测试#
在上面的第 3 步或第 4 步之后,如果您有一些基于名称的既有检查点,不妨运行您的数值等价性测试。这可以测试旧检查点加载是否正常执行以及模型本身是否正常工作。重用 TF1.x 检查点指南介绍了如何重用既有 TF1.x 检查点并将它们转移到 TF2 检查点。
附加测试和问题排查#
当您添加更多数值等价性测试时,还可以选择添加一个验证梯度计算(甚至优化器更新)是否匹配的测试。
反向传播和梯度计算比模型前向传递更容易出现浮点数值不稳定性。这意味着,当等价性测试涵盖训练中更多非孤立的部分时,您可能会开始看到完全以 Eager 方式运行与 TF1 计算图之间存在非常重要的数值差异。这可能是由 TensorFlow 的计算图优化引起的,这种优化执行诸如用较少的数学运算替换图计算图中的子表达式之类的事情。
为了判断是否可能是这种情况,可以将 TF1 代码与 tf.function
内部发生的 TF2 计算进行比较(它像您的 TF1 计算图一样应用计算图优化传递),而不是纯粹的 Eager 计算。或者,也可以尝试在 TF1 计算之前使用 tf.config.optimizer.set_experimental_options
停用优化传递(例如 "arithmetic_optimization"
)以查看结果是否在数值上更接近 TF2 计算结果。在实际训练运行中,出于性能原因,建议使用启用优化传递的 tf.function
,但您可能会发现,在数值等价性单元测试中停用它们更有用。
同样,您可能还会发现 tf.compat.v1.train
优化器与 TF2 优化器的浮点数值属性略有不同,即使它们表示的数学公式相同。这在训练运行中不太可能成为问题,但在等价性单元测试中可能需要更高的数值容差。