##### Copyright 2020 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
%cd ..
from set_env import temp_dir
/media/pc/data/lxw/ai/d2py/doc/libs/tf-chaos/guide

使用 tf.function 时提升性能#

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

在 TensorFlow 2 中,Eager Execution 默认处于启用状态。界面非常灵活直观(执行一次性运算要简单快速得多),不过,这可能对性能和可部署性造成一定影响。

您可以使用 tf.function 将程序转换为计算图。这是一个转换工具,用于从 Python 代码创建独立于 Python 的数据流图。它可以帮助您创建高效且可移植的模型,并且如果要使用 SavedModel,则必须使用此工具。

本指南介绍 tf.function 的底层工作原理,让您形成概念化理解,从而有效地加以利用。

要点和建议包括:

  • 先在 Eager 模式下调试,然后使用 @tf.function 进行装饰。

  • 不依赖 Python 副作用,如对象变异或列表追加。

  • tf.function 最适合处理 TensorFlow 运算;NumPy 和 Python 调用会转换为常量。

安装#

import tensorflow as tf

定义一个辅助函数来演示可能遇到的错误类型:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  """Context manager asserting that its body raises ``error_class``.

  When the expected exception is raised, a short notice is printed to stdout,
  a truncated traceback is printed to stderr, and the exception is swallowed
  so the notebook keeps running. Any other exception propagates unchanged
  (with its original traceback).

  Args:
    error_class: Exception class (or tuple of classes) the body must raise.

  Raises:
    Exception: if the body completes without raising anything.
  """
  try:
    yield
  except error_class:
    print('Caught expected exception \n  {}:'.format(error_class))
    # limit=2 keeps only the innermost frames of the demo cell.
    traceback.print_exc(limit=2)
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

基础知识#

用法#

您定义的 Function(例如,通过应用 @tf.function 装饰器)就像核心 TensorFlow 运算:您可以在 Eager 模式下执行它,可以计算梯度,等等。

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  """Return `a + b`; the resulting `Function` is called like a TensorFlow op."""
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
# A `Function` is differentiable: the tape records the ops run inside `add`.
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)  # d(v + 1)/dv, i.e. 1.0
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Function 中可以嵌套其他 Function

@tf.function
def dense_layer(x, w, b):
  """Compute `matmul(x, w) + b`, calling the `add` Function from inside a Function."""
  return add(tf.matmul(x, w), b)

# (3, 2) ones @ (2, 2) ones gives 2s; adding a bias of ones gives 3s.
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function 的执行速度比 Eager 代码快,尤其是对于包含很多简单运算的计算图。但是,对于包含一些复杂运算(如卷积)的计算图,速度提升不会太明显。

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  """Apply `conv_layer` to `image` inside a traced graph."""
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up: run both paths once so one-time setup work is not counted
# in the timings below.
conv_layer(image)
conv_fn(image)
eager_seconds = timeit.timeit(lambda: conv_layer(image), number=10)
print("Eager conv:", eager_seconds)
function_seconds = timeit.timeit(lambda: conv_fn(image), number=10)
print("Function conv:", function_seconds)
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.013796081067994237
Function conv: 0.007262098835781217
Note how there's not much difference in performance for convolutions
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1729855710.179346 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.206193 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.219029 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.219816 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.254197 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.256597 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.269877 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.275167 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.275956 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.276898 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.277820 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.280412 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.283399 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.347206 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.350613 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.352031 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.355781 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.356716 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.364884 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.366005 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.370070 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.373444 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.377956 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.503679 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.504849 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.507080 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.508031 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.509096 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.509957 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.510725 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.511536 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.512716 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.514187 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.516705 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.518798 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.521376 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.522275 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.523279 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.524118 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced
W0000 00:00:1729855710.528098 4126223 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced

跟踪#

本部分介绍了 Function 的幕后运作方式,包括未来可能会发生变化的实现细节。但是,当您了解跟踪的原因和时间后,就能够更轻松高效地使用 tf.function

什么是“跟踪”?#

Function 在 TensorFlow 计算图中运行您的程序。但是,tf.Graph 不能代表您在 Eager TensorFlow 程序中编写的全部内容。例如,Python 支持多态,但是 tf.Graph 要求其输入具有指定的数据类型和维度。或者,您可能执行辅助任务,例如读取命令行参数、引发错误或使用更复杂的 Python 对象。这些内容均不能在 tf.Graph 中运行。

Function 通过将代码分为以下两个阶段填补了这一空缺:

  1. 第一阶段称为跟踪,在这一阶段中,Function 会创建新的 tf.Graph。Python 代码可以正常运行,但是所有 TensorFlow 运算(例如添加两个张量)都会被推迟:它们会被 tf.Graph 捕获而不运行。

  2. 在第二阶段中,将运行包含第一阶段中推迟的全部内容的 tf.Graph。此阶段比跟踪阶段快得多。

根据输入,Function 在调用时并非总会运行第一阶段。请参阅下方的跟踪规则以更好地了解其决定方式。跳过第一阶段并仅执行第二阶段,可以实现 TensorFlow 的高性能。

Function 决定跟踪时,在跟踪阶段完成后会立即运行第二阶段,因此调用 Function 会创建并运行 tf.Graph。稍后,您将了解如何使用 get_concrete_function 来仅运行跟踪阶段。

当您将不同类型的参数传递给 Function 时,两个阶段都将运行:

@tf.function
def double(a):
  """Return `a + a`; the Python `print` only fires while a new trace is built."""
  print("Tracing with", a)
  return a + a

# int32, float32 and string scalars: three dtypes, hence three traces.
for value in (1, 1.1, "a"):
  print(double(tf.constant(value)))
  print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

请注意,如果重复使用同一参数类型调用 Function,TensorFlow 会跳过跟踪阶段并重用之前跟踪的计算图,因为后面的调用生成的计算图可能相同。

# This doesn't print 'Tracing with ...'
# (a string scalar was already traced above, so that graph is reused).
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

您可以使用 pretty_printed_concrete_signatures() 查看所有可用跟踪记录:

# One entry per ConcreteFunction traced for `double` so far.
print(double.pretty_printed_concrete_signatures())
Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
  None

目前,您已经了解 tf.function 通过 TensorFlow 的计算图跟踪逻辑创建缓存的动态调度层。对于术语的含义,更具体的解释如下:

  • tf.Graph 与语言无关,是 TensorFlow 计算的原始可移植表示。

  • ConcreteFunction 封装 tf.Graph。

  • Function 管理 ConcreteFunction 的缓存,并为输入选择正确的缓存。

  • tf.function 封装 Python 函数,并返回一个 Function 对象。

  • 跟踪会创建 tf.Graph 并将其封装在 ConcreteFunction 中,后者也称为跟踪记录(trace)。

跟踪规则#

被调用时,Function 使用每个参数的 tf.types.experimental.TraceType 将调用参数与现有的 ConcreteFunction 匹配。如果找到匹配的 ConcreteFunction,则将调用分派给它。如果未找到匹配项,则跟踪新的 ConcreteFunction。

如果找到多个匹配项,则会选择最具体的签名。匹配是通过子类型化完成的,就像 C++ 或 Java 中的普通函数调用一样。例如,TensorShape([1, 2]) 是 TensorShape([None, None]) 的子类型,因此可以将使用 TensorShape([1, 2]) 对 tf.function 进行的调用分派到使用 TensorShape([None, None]) 生成的 ConcreteFunction。但是,如果具有 TensorShape([1, None]) 的 ConcreteFunction 也存在,那么它将被优先考虑,因为它更具体。

TraceType 由输入参数确定,具体如下所示:

  • 对于 Tensor,类型由 Tensor 的 dtype 和 shape 参数化;有秩形状是无秩形状的子类型;固定维度是未知维度的子类型。

  • 对于 Variable,类型类似于 Tensor,但还包括变量的唯一资源 ID,这是正确连接控制依赖项所必需的。

  • 对于 Python 基元值,类型对应于值本身。例如,值 3 的 TraceType 是 LiteralTraceType<3>,而不是 int。

  • 对于 list 和 tuple 等 Python 有序容器,类型是通过其元素的类型来参数化的;例如,[1, 2] 的类型是 ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>,而 [2, 1] 的类型是 ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>,两者不同。

  • 对于 dict 等 Python 映射,类型也是从相同的键到值类型(而不是实际值)的映射。例如,{1: 2, 3: 4} 的类型为 MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>。但是,与有序容器不同的是,{1: 2, 3: 4} 和 {3: 4, 1: 2} 具有等价的类型。

  • 对于实现 __tf_tracing_type__ 方法的 Python 对象,类型为该方法返回的任何内容。

  • 对于任何其他 Python 对象,类型是通用的 TraceType,匹配过程如下:

    • 首先,它检查该对象与先前跟踪中使用的对象是否相同(使用 id() 或 is)。请注意,即使对象已被修改,这仍然会匹配,因此如果您使用 Python 对象作为 tf.function 参数,最好使用不可变对象。

    • 接下来,它检查该对象是否等于先前跟踪中使用的对象(使用 python ==)。

    (请注意,此过程仅保留对象的弱引用,因此仅在对象处于作用域内/未被删除时有效。)

注:TraceType 基于 Function 输入参数,因此仅对全局变量和自由变量进行更改将不会创建新的跟踪记录。有关处理 Python 全局变量和自由变量的建议做法,请参阅本部分

控制回溯#

回溯即 Function 创建多个跟踪记录的过程,可以确保 TensorFlow 为每组输入生成正确的计算图。但是,跟踪非常消耗资源!如果 Function 为每一次调用都回溯新的计算图,您会发现代码的执行速度远不如不使用 tf.function 时快。

要控制跟踪行为,可以采用以下技巧:

将固定的 input_signature 传递给 tf.function#

# `input_signature` restricts inputs to 1-D int32 tensors of any length, so all
# accepted calls share one trace and anything else is rejected.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  """Apply one Collatz step element-wise: x // 2 if even, else 3 * x + 1."""
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'TypeError'>:
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/3657259638.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4]], dtype=int32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/3657259638.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).

使用未知维度以获得灵活性#

由于 TensorFlow 根据其形状匹配张量,因此,对于可变大小输入,使用 None 维度作为通配符可以让 Function 重复使用跟踪记录。对于每个批次,如果有不同长度的序列或不同大小的图像,则会出现可变大小输入(请参阅 Transformer 和 Deep Dream 教程了解示例)。

# The `None` dimension acts as a wildcard for the vector length.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  """Identity function that reports when it is traced."""
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

传递张量而不是 Python 文字#

通常,Python 参数用于控制超参数和计算图构造,例如 num_layers=10、training=True 或 nonlinearity='relu'。所以,如果 Python 参数改变,则有必要回溯计算图。

但是,Python 参数有可能并未用于控制计算图构造。在这些情况下,Python 值的改变可能触发非必要的回溯。例如,在此训练循环中,AutoGraph 会动态展开。尽管有多个跟踪,但生成的计算图实际上是相同的,所以没有必要进行回溯。

def train_one_step():
  """Placeholder for a single training step; deliberately does nothing."""

@tf.function
def train(num_steps):
  """Dummy training loop calling `train_one_step` `num_steps` times."""
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  # `tf.range` makes AutoGraph emit a graph loop whose trip count is read at run time.
  for _ in tf.range(num_steps):
    train_one_step()

# Python ints are part of the trace type (10 and 20 differ), so each new value
# retraces; int32 scalar tensors with different values share a single trace.
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

如果需要强制执行回溯,可以创建一个新的 Function。单独的 Function 对象肯定不会共享跟踪记录。

def f():
  """Plain Python function wrapped twice below to show per-Function trace caches."""
  print('Tracing!')
  tf.print('Executing')

# Each tf.function(f) is a separate Function with its own cache, so both
# calls trace and print 'Tracing!'.
tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

使用跟踪协议#

在可能的情况下,您应当首选将 Python 类型转换为 tf.experimental.ExtensionType。此外,ExtensionType 的 TraceType 是与其关联的 tf.TypeSpec。因此,如果需要,您只需重写默认的 tf.TypeSpec 即可控制 ExtensionType 的 Tracing Protocol。请参阅扩展类型指南中的“自定义 ExtensionType 的 TypeSpec”部分以了解详情。

否则,要直接控制 Function 何时应针对特定 Python 类型进行重新跟踪,您可以自行为其实现 Tracing Protocol

@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
  """Return the sum of the `flavor` tensors of two fruit objects."""
  return fruit_a.flavor + fruit_b.flavor

class Fruit:
  # Default flavor; each subclass overrides it with its own constant.
  flavor = tf.constant([0, 0])

class Apple(Fruit):
  flavor = tf.constant([1, 2])

class Mango(Fruit):
  flavor = tf.constant([3, 4])

# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out of scope by then and been deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again

# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.

class FruitTraceType(tf.types.experimental.TraceType):
  """TraceType that identifies a fruit argument by its Python class.

  Fruits of the same class yield equal trace types, so a tf.function called
  with a fresh instance can reuse the ConcreteFunction traced for an earlier
  instance of that class instead of retracing.
  """

  def __init__(self, fruit):
    # The concrete class is what distinguishes one trace from another.
    self.fruit_type = type(fruit)
    # Kept so the real object can be handed to the Python function when tracing.
    self.fruit_value = fruit

  def is_subtype_of(self, other):
    """Return True only for another FruitTraceType with the same fruit class."""
    if type(other) is not FruitTraceType:
      return False
    return self.fruit_type is other.fruit_type

  def most_specific_common_supertype(self, others):
    """Return `self` if every type in `others` equals it, otherwise None."""
    for other in others:
      if not self == other:
        return None
    return self

  def placeholder_value(self, placeholder_context=None):
    """Return the wrapped fruit itself for use while tracing."""
    return self.fruit_value

  def __eq__(self, other):
    if type(other) is not FruitTraceType:
      return False
    return self.fruit_type == other.fruit_type

  def __hash__(self):
    # Consistent with __eq__: equal trace types share the same fruit class.
    return hash(self.fruit_type)

class FruitWithTraceType:
  """Fruit base class that opts into the tf.function tracing protocol."""

  def __tf_tracing_type__(self, context):
    # `context` is not consulted; the trace type is derived from `self` alone.
    fruit_trace_type = FruitTraceType(self)
    return fruit_trace_type

class AppleWithTraceType(FruitWithTraceType):
  flavor = tf.constant([1, 2])  # Same flavor values as `Apple`.

class MangoWithTraceType(FruitWithTraceType):
  flavor = tf.constant([3, 4])  # Same flavor values as `Mango`.

# Now if you try calling it again:
# (FruitTraceType compares fruit classes, so a fresh pair of the same classes
# matches the existing trace.)
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>

获取具体函数#

每次跟踪函数时都会创建一个新的具体函数。您可以使用 get_concrete_function 直接获取具体函数。

print("Obtaining concrete trace")
# Trace (or fetch from the cache) the ConcreteFunction for a string scalar.
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
# The ConcreteFunction accepts the argument positionally or by keyword.
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on a tf.TensorSpec, which describes
# the input's shape and dtype without needing a concrete value.
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

打印 ConcreteFunction 会显示其输入参数(及类型)和输出类型的摘要。

# Shows input parameters, output type and captures of the ConcreteFunction.
print(double_strings)
ConcreteFunction Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
  None

您也可以直接检索具体函数的签名。

# Input signature is a (positional specs, keyword specs) pair.
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

对不兼容的类型使用具体跟踪记录会引发错误

# `double_strings` was traced for string scalars; an int32 input is rejected.
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
    bound_arguments = function_type.bind_with_defaults(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
    with_default_args[arg_name] = constraint.cast(
                                  ^^^^^^^^^^^^^^^^
TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1179, in _call_impl
    return self._call_with_structured_signature(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1259, in _call_with_structured_signature
    function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (<tf.Tensor: shape=(), dtype=int32, numpy=1>,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_189 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_189]

您可能会注意到,在具体函数的输入签名中对 Python 参数进行了特别处理。TensorFlow 2.3 之前的版本会将 Python 参数直接从具体函数的签名中移除。从 TensorFlow 2.3 开始,Python 参数会保留在签名中,但是会受到约束,只能获取在跟踪期间设置的值。

# NOTE(review): this name shadows Python's built-in `pow` for the rest of the
# notebook; it is kept because the next line and the printed signatures use it.
@tf.function
def pow(a, b):
  """Return `a ** b`."""
  return a ** b

# `b` is a plain Python value, so it is baked into this trace as Literal[2].
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
  b (POSITIONAL_OR_KEYWORD): Literal[2]
Output Type:
  TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)
Captures:
  None
assert square(tf.constant(10.0)) == 100

# `b` was fixed to 2 at trace time, so passing any other value is rejected.
with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs
    bound_arguments = function_type.bind_with_defaults(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults
    with_default_args[arg_name] = constraint.cast(
                                  ^^^^^^^^^^^^^^^^
ValueError: Can not cast 3 to Literal[2]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1179, in _call_impl
    return self._call_with_structured_signature(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1259, in _call_with_structured_signature
    function_type_utils.canonicalize_function_inputs(
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1182, in _call_impl
    return self._call_with_flat_signature(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1233, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None).
Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.

获取计算图#

每个具体函数都是 tf.Graph 的可调用封装容器。虽然一般不需要检索实际 tf.Graph 对象,不过,您可以从任何具体函数轻松获得实际对象。

graph = double_strings.graph
# Walk the serialized GraphDef and show each node's inputs next to its name.
graph_def = graph.as_graph_def()
for node in graph_def.node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

调试#

通常,在 Eager 模式下调试代码比在 tf.function 中简单。在使用 tf.function 进行装饰之前,您应该先确保代码可在 Eager 模式下无错误执行。为了帮助调试,您可以调用 tf.config.run_functions_eagerly(True) 来全局停用和重新启用 tf.function

追溯仅在 tf.function 中出现的问题时,可参考下面的几点提示:

  • 普通旧 Python print 调用仅在跟踪期间执行,可用于追溯(重新)跟踪函数的时间。

  • tf.print 调用每次都会执行,可用于追溯执行过程中产生的中间值。

  • 利用 tf.debugging.enable_check_numerics 很容易追溯到 NaN 和 Inf 在何处创建。

  • pdb(Python 调试器)可以帮助您理解跟踪的详细过程。(提醒:使用 pdb 调试时,您将进入经 AutoGraph 转换后的源代码。)

AutoGraph 转换#

AutoGraph 是一个库,在 tf.function 中默认处于启用状态。它可以将 Python Eager 代码的子集转换为与计算图兼容的 TensorFlow 运算。这包括 if、for、while 等控制流。

tf.cond 和 tf.while_loop 等 TensorFlow 运算仍然可以运行,但是使用 Python 编写时,控制流通常更易于编写,代码也更易于理解。

# A simple loop

@tf.function
def f(x):
  # `tf.reduce_sum(x) > 1` is a tensor, so AutoGraph converts this `while`
  # into a graph loop (its generated code is printed in the next cell).
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.10295248 0.722364306 0.462540388 0.685418 0.410427094]
[0.102590263 0.618371665 0.43215242 0.595030427 0.388835251]
[0.102231853 0.549993277 0.407118559 0.53350389 0.370355666]
[0.101877175 0.500515163 0.386023313 0.488054901 0.354302764]
[0.101526156 0.462522238 0.367926896 0.45267123 0.34018591]
[0.101178743 0.432137698 0.352177054 0.424092293 0.327643335]
[0.100834884 0.407106251 0.338304847 0.400372326 0.316401631]
[0.100494511 0.386012852 0.325963169 0.380267531 0.306249619]
[0.100157566 0.367917836 0.314888835 0.362939775 0.297021389]
[0.0998239741 0.352169096 0.304878056 0.347800791 0.288584381]
[0.0994937 0.338297784 0.295770347 0.334423721 0.280831337]
[0.0991666913 0.325956881 0.287437141 0.322490036 0.273674309]
[0.0988428891 0.314883202 0.279774219 0.311756641 0.267040461]
[0.098522231 0.30487296 0.272696108 0.302034318 0.260868818]
[0.0982046872 0.295765668 0.266131759 0.293173164 0.255107969]
[0.0978902 0.287432849 0.260021746 0.285052747 0.249714166]
[0.0975787044 0.279770285 0.254315883 0.277575016 0.244649917]
[0.0972701609 0.272692442 0.248971328 0.270659208 0.239882946]
[0.0969645381 0.266128331 0.24395144 0.264238119 0.235385165]
[0.0966617838 0.260018587 0.239224538 0.258255273 0.23113212]
[0.0963618383 0.254312903 0.234763145 0.252662897 0.227102354]
[0.0960646719 0.248968542 0.230543256 0.247420177 0.223276913]
[0.0957702398 0.243948802 0.226543784 0.242492035 0.219639108]
[0.0954785049 0.23922205 0.222746134 0.237848178 0.216174051]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.09518941, 0.23476078, 0.21913388, 0.23346221, 0.21286845],
      dtype=float32)>

如果您有兴趣,可以检查 AutoGraph 生成的代码。

# Print the source AutoGraph generates for the wrapped Python function of `f`.
print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            x, = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

条件语句#

AutoGraph 会将某些 if <condition> 语句转换为等效的 tf.cond 调用。如果 <condition> 是张量,则会执行这种替换,否则会将 if 语句作为 Python 条件语句执行。

Python 条件语句在跟踪时执行,因此会将该条件语句的一个分支添加到计算图。如果不使用 AutoGraph,当存在依赖于数据的控制流时,此跟踪计算图将无法选择替代分支。

tf.cond 跟踪并将条件的两个分支添加到计算图,在执行时动态选择分支。跟踪可能产生意外的副作用;请参阅 AutoGraph 跟踪作用以了解详情。

@tf.function
def fizzbuzz(n):
  """Print FizzBuzz for 1..n with tf.print (called below with int32 scalar tensors)."""
  # `tf.range` makes this a graph loop; since `i` is a tensor, each if/elif
  # becomes a tf.cond, so every branch is traced exactly once.
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

# The second call reuses the trace, so only tf.print output appears for it.
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

有关 AutoGraph 转换的 if 语句的其他限制,请参阅参考文档

循环#

AutoGraph 会将某些 for 和 while 语句转换为等效的 TensorFlow 循环运算,例如 tf.while_loop。如果不转换,则会将 for 或 while 循环作为 Python 循环执行。

以下情形会执行这种替换:

  • for x in y:如果 y 是一个张量,则转换为 tf.while_loop。在特殊情况下,如果 y 是 tf.data.Dataset,则会生成 tf.data.Dataset 运算的组合。

  • while <condition>:如果 <condition> 是张量,则转换为 tf.while_loop

Python 循环在跟踪时执行,因而循环每迭代一次,都会将额外的运算添加到 tf.Graph

TensorFlow 循环会跟踪循环体,并在执行时动态选择迭代的运行次数。循环体仅在生成的 tf.Graph 中出现一次。

有关 AutoGraph 转换的 for 和 while 语句的其他限制,请参阅参考文档。

在 Python 数据上循环#

一个常见陷阱是在 tf.function 中的 Python/Numpy 数据上循环。此循环在跟踪过程中执行,因而循环每迭代一次,都会将模型的一个副本添加到 tf.Graph

如果要在 tf.function 中封装整个训练循环,最安全的方式是将数据封装为 tf.data.Dataset,以便 AutoGraph 动态展开训练循环。

def measure_graph_size(f, *args):
  """Print how many nodes the graph traced by `f` for `args` contains.

  Args:
    f: A tf.function-style callable exposing `get_concrete_function`.
    *args: Arguments forwarded to `f.get_concrete_function`; their `str()`
      forms are echoed in the report.
  """
  graph_def = f.get_concrete_function(*args).graph.as_graph_def()
  rendered_args = ', '.join(str(arg) for arg in args)
  print(f"{f.__name__}({rendered_args}) contains {len(graph_def.node)} nodes in its graph")

@tf.function
def train(dataset):
  """Sum |y - x| over the (x, y) pairs in `dataset` (a Python list or tf.data.Dataset)."""
  loss = tf.constant(0)
  # A Python list is unrolled at trace time, so the graph grows with the data;
  # a tf.data.Dataset is turned into a single graph loop by AutoGraph.
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
# Python lists: the node count scales with the number of elements.
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

# Datasets: the node count does not depend on the amount of data.
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph

在数据集中封装 Python/NumPy 数据时,要注意 tf.data.Dataset.from_generator 与 tf.data.Dataset.from_tensors 的区别。前者将数据保留在 Python 中,并通过 tf.py_function 获取,这可能会影响性能;后者将数据的副本捆绑成计算图中的一个大 tf.constant() 节点,这可能会消耗较多内存。

通过 TFRecordDataset、CsvDataset 等从文件中读取数据是最高效的数据使用方式,因为这样 TensorFlow 就可以自行管理数据的异步加载和预提取,不必利用 Python。要了解详细信息,请参阅 tf.data:构建 TensorFlow 输入流水线指南。

累加循环值#

一种常见模式是不断累加循环的中间值。通常,这可以通过将元素追加到 Python 列表或将条目添加到 Python 字典来实现。但是,由于存在 Python 副作用,在动态展开循环中,这些方式无法达到预期效果。要从动态展开循环累加结果,可以使用 tf.TensorArray 来实现。

# Toy dimensions for the dynamic RNN demo: 2 sequences, 3 steps, 4 features.
batch_size, seq_len, feature_size = 2, 3, 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.04055965, 0.9615973 , 0.11646771, 0.37283874],
        [0.10482395, 0.99413276, 0.4927038 , 0.40920806],
        [0.79102623, 1.4659768 , 0.64847696, 0.44963348]],

       [[0.23183179, 0.27364147, 0.11120021, 0.5280762 ],
        [0.44744837, 0.5999154 , 0.2355485 , 0.9415574 ],
        [1.4402324 , 0.69936645, 0.7179471 , 1.5228264 ]]], dtype=float32)>

限制#

TensorFlow Function 有意设计了一些限制,在将 Python 函数转换为 Function 时需加以注意。

执行 Python 副作用#

副作用(如打印、附加到列表、改变全局变量)在 Function 内部可能会出现异常行为,有时会执行两次或完全无法执行。它们只会在您第一次使用一组输入调用 Function 时发生。之后,将重新执行跟踪的 tf.Graph,而不执行 Python 代码。

一般经验法则是避免在逻辑中依赖 Python 副作用,而仅将它们用于调试跟踪过程。否则,TensorFlow API(例如 tf.data、tf.print、tf.summary、tf.Variable.assign 和 tf.TensorArray)是确保在每次调用时 TensorFlow 运行时都能执行您的代码的最佳方式。

# `print` is a Python side effect and runs only while tracing;
# `tf.print` is recorded in the graph and runs on every call.
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)  # Same Python argument value: the existing trace is reused, no "Traced" line.
f(2)  # New Python argument value: triggers a new trace.
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

如果希望在每次调用 Function 时都执行 Python 代码,可以使用 tf.py_function 作为应急手段(escape hatch)。tf.py_function 的缺点是不可移植,性能不高,无法使用 SavedModel 保存,并且在分布式(多 GPU、TPU)设置中效果不佳。另外,由于 tf.py_function 必须连接到计算图中,它会将所有输入/输出转换为张量。

更改 Python 全局变量和自由变量#

更改 Python 全局变量和自由变量被视为 Python 副作用,因此仅在跟踪期间发生。

# A module-level Python list that the tf.function below mutates.
external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)  # Python side effect: runs only while tracing.

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

有时很难注意到意外行为。在下面的示例中,counter 旨在确保变量只被递增一次。然而,由于它是一个 Python 整数而不是 TensorFlow 对象,它的值在第一次跟踪期间被捕获。使用 tf.function 时,assign_add 将被无条件记录在底层计算图中。因此,每次调用 tf.function 时,v 都会增加 1。当用户尝试使用 tf.function 装饰器将其计算图模式 TensorFlow 代码迁移到 TensorFlow 2,并用 Python 副作用(示例中的 counter)来决定要运行哪些运算(示例中的 assign_add)时,此问题十分常见。通常,用户只有在看到可疑的数值结果或明显低于预期的性能(例如,如果受保护运算的开销非常大)后才会意识到这一点。

# Anti-pattern: a plain Python attribute (`counter`) decides whether a TF op runs.
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    # `counter` is a Python int, so this check is evaluated only while tracing;
    # the assign_add below ends up recorded unconditionally in the graph.
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

实现预期行为的一种解决方法是使用 tf.init_scope 将运算提升到函数计算图以外。这样可以确保变量增量在跟踪期间只执行一次。应当注意的是,init_scope 还有其他副作用,包括清除控制流和梯度带。有时 init_scope 的使用会变得过于复杂而无法实际管理。

# Workaround: perform the guarded increment inside tf.init_scope so it runs once,
# during tracing, instead of being recorded in the function graph.
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

总之,根据经验,您应避免改变整数或容器(如位于 Function 外部的列表)等 Python 对象,而应使用参数和 TF 对象。例如,“累加循环值”部分提供了一个如何实现类列表运算的示例。

在某些情况下,如果状态保存在 tf.Variable 中,则您可以捕获并修改该状态。Keras 模型的权重就是通过重复调用同一个 ConcreteFunction 以这种方式进行更新的。

使用 Python 迭代器和生成器#

很多 Python 功能(如生成器和迭代器)依赖 Python 运行时来跟踪状态。通常,虽然这些构造在 Eager 模式下可以正常工作,但它们是 Python 副作用的示例,因此仅在跟踪期间发生。

# Anti-pattern: a Python iterator is advanced only while tracing, so the value
# read during the first trace is reused on every later call.
@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

就像 TensorFlow 具有用于列表构造的专用 tf.TensorArray 一样,它也具有用于迭代构造的专用 tf.data.Iterator。有关概述,请参阅 AutoGraph 转换部分。此外,tf.data API 也可帮助实现生成器模式:

# A tf.data.Iterator is a TensorFlow object, so next() advances it on every call.
@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

tf.function 的所有输出都必须是返回值#

除了 tf.Variable 外,一个 tf.function 必须通过返回值返回其所有输出。不通过返回值而直接从函数中访问任何张量都会导致“泄漏”。

例如,下面的函数通过 Python 全局变量 x“泄漏”了在函数内部由 a 计算得到的张量(a + 1):

# Module-level global that the tf.function below assigns a graph tensor to.
x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
# After the call, `x` holds a symbolic (graph) tensor, which has no .numpy().
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'SymbolicTensor' object has no attribute 'numpy'

即使同时返回泄漏的值时也是如此:

@tf.function
def leaky_function(a):  # Leaks `a + 1` via the global `x` even though it also returns it.
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Returning is fine, but the global `x` still refers to the graph tensor.
# Only the returned value is usable outside the function.
correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
# A second tf.function that captures the leaked global `x`.
@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b
# Expected to fail: `x` belongs to leaky_function's graph, which is out of scope here.
with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'SymbolicTensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'add:0' shape=() dtype=int32> was defined here:
    File "<frozen runpy>", line 198, in _run_module_as_main
    File "<frozen runpy>", line 88, in _run_code
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel_launcher.py", line 17, in <module>
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 701, in start
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/asyncio/base_events.py", line 639, in run_forever
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/asyncio/base_events.py", line 1985, in _run_once
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/asyncio/events.py", line 88, in _run
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in dispatch_queue
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 523, in process_one
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 429, in dispatch_shell
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 429, in do_execute
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    File "/tmp/ipykernel_4126223/566849597.py", line 7, in <module>
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 833, in __call__
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 889, in _call
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
    File "/tmp/ipykernel_4126223/566849597.py", line 4, in leaky_function
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/override_binary_operator.py", line 113, in binary_op_wrapper
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/ops/tensor_math_operator_overrides.py", line 28, in _add_dispatch_factory
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/ops/math_ops.py", line 1701, in _add_dispatch
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/ops/gen_math_ops.py", line 490, in add_v2
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/ops.py", line 2682, in _create_op_internal
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/ops.py", line 1177, in from_node_def

The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=140693262518592), which is out of scope.

通常,当您使用 Python 语句或数据结构时,会发生此类泄漏。除了泄漏不可访问的张量之外,此类语句也可能是错误的,因为它们被视为 Python 副作用,而且不能保证在每次函数调用时都执行。

泄漏局部张量的常见方法还包括改变外部 Python 集合或对象:

class MyClass:
  """Plain Python object whose attribute can be used to smuggle a tensor out."""

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

# The decorator is what makes these writes leaks: without it, `a` would be an
# ordinary eager tensor and storing it externally would be harmless.
@tf.function
def leaky_function():
  """Show two ways a graph tensor escapes a tf.function (do not copy)."""
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

不支持递归 tf.functions#

不支持递归 Function,它们可能导致无限循环。例如:

@tf.function
def recursive_fn(n):  # Recursive tf.functions are not supported.
  if n > 0:
    return recursive_fn(n - 1)  # Tracing this branch calls recursive_fn again, without end.
  else:
    return 1
# Tensor argument: expected to hit Python's recursion limit while tracing.
with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/tmp/ipykernel_4126223/2233998312.py", line 3, in recursive_fn  *
        if n > 0:

    RecursionError: maximum recursion depth exceeded

即使递归 Function 看似有效,Python 函数也会被多次跟踪,并且可能会对性能产生影响。例如:

# With a Python int argument the recursion terminates, but every distinct value
# of `n` reached on the way down triggers its own trace ("tracing" is printed per trace).
@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

已知问题#

如果您的 Function 计算结果不正确,原因可能是以下已知问题之一(这些问题计划在将来修复)。

取决于 Python 全局变量和自由变量#

当使用 Python 参数的新值进行调用时,Function 会创建新的 ConcreteFunction。但是,对于该 Function 的 Python 闭包、全局变量或非局部变量,则不会创建。如果它们的值在调用 Function 之间发生变化,则 Function 仍将使用其在跟踪时所具有的值。这与常规 Python 函数的工作方式不同。

因此,建议采用通过参数传入数据的函数式编程风格,而不是通过闭包引用外部名称。

# Reads the global `foo`; the value seen at trace time is what the graph keeps using.
@tf.function
def buggy_add():
  return 1 + foo

# Takes `foo` as an argument, so each call sees the value actually passed in.
@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
# Rebind the global: buggy_add keeps returning the value captured when it was traced.
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

更新全局值的另一种方法是使其成为 tf.Variable 并改用 Variable.assign 方法。

# Same closure over a global, but `foo` is now a tf.Variable, whose current value
# is read each time the function runs.
@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
# Update the variable in place with assign (rebinding `foo` would not be seen).
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

依赖于 Python 对象#

支持将自定义 Python 对象作为参数传递给 tf.function,但有一定的限制。

为了获得最全面的功能支持,请考虑在将对象传递给 tf.function 之前将其转换为扩展类型。此外,您也可以使用 Python 基元以及与 tf.nest 兼容的结构。

但是,正如跟踪规则中所述,当自定义 Python 类未提供自定义 TraceType 时,tf.function 被迫使用基于实例的相等性,这意味着当您传递属性已被修改的同一个对象时,它不会重新跟踪。

# A model whose parameters are plain Python floats and which defines no custom TraceType.
class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

# The first call traces with the attribute values the model has at that moment.
@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
# Mutating a plain attribute of the same object does not trigger a retrace.
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

使用同一个 Function 来评估修改后的模型实例会产生错误的结果,因为该实例的基于实例的 TraceType 仍与原始模型相同。

因此,建议您编写不依赖可变对象属性的 Function,或者为对象实现跟踪协议,以便将此类属性的变化告知 Function。

如果这不可行,则一种解决方法是,每次修改对象后都创建新的 Function,以强制重新跟踪:

# Undecorated version: a fresh tf.function is built from it whenever the model changes.
def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
# Modify the model, then trace a brand-new Function so the new bias is captured.
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

重新跟踪的开销可能很大。为了在无需重新跟踪的情况下实现类似效果,您可以使用 tf.Variable 作为对象属性:可以原地修改它的值(例如调用 assign),但不要将该属性重新赋值为新对象,请注意!

# Parameters are tf.Variables, so evaluate() reads their current values on each call.
class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
# Update the variable in place; the existing trace picks up the new value.
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

创建 tf.Variables#

Function 仅支持在第一次调用时创建一次,并且在后续函数调用中重复使用的单例 tf.Variable。下面的代码段会在每个函数调用中创建一个新的 tf.Variable,这会导致 ValueError 异常。

示例:

@tf.function
def f(x):  # Anti-pattern: creates a tf.Variable inside the function body.
  v = tf.Variable(1.0)  # Bad - not a singleton variable created only on the first call.
  return v
# Expected: ValueError (tf.function only supports singleton tf.Variables).
with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_4126223/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

用于解决这种限制的常见模式是从 Python None 值开始,随后,在值为 None 时,有条件地创建 tf.Variable

# Workaround: start from None and create the variable lazily, so it is created
# only on the first call and reused by later calls.
class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

与多个 Keras 优化器一起使用#

将多个 Keras 优化器与 tf.function 一起使用时,您可能会遇到 ValueError: tf.function only supports singleton tf.Variables created on the first call.。发生此错误的原因是优化器在首次应用梯度时会在内部创建 tf.Variables

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function
def train_step(w, x, y, optimizer):  # One squared-error gradient step on `w`.
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(w * x - y))
  gradients = tape.gradient(loss, [w])
  optimizer.apply_gradients(zip(gradients, [w]))  # First use creates optimizer variables.

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)  # opt2 would create new tf.Variables on a later call.
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_4126223/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_4126223/950644149.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_4126223/950644149.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py", line 291, in apply_gradients  **
        self.apply(grads, trainable_variables)
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py", line 330, in apply
        self.build(trainable_variables)
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/adam.py", line 97, in build
        self.add_variable_from_reference(
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/backend/tensorflow/optimizer.py", line 36, in add_variable_from_reference
        return super().add_variable_from_reference(
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py", line 227, in add_variable_from_reference
        return self.add_variable(
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py", line 201, in add_variable
        variable = backend.Variable(
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/backend/common/variables.py", line 163, in __init__
        self._initialize_with_initializer(initializer)
    File "/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/backend/tensorflow/core.py", line 40, in _initialize_with_initializer
        self._value = tf.Variable(

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

如果您需要在训练期间更改优化器,一种解决方法是为每个优化器分别创建一个新的 Function(从而各自拥有独立的 ConcreteFunction),使每个优化器的变量都在其对应 Function 的首次调用中创建:

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
  """Apply one gradient step of a squared-error loss on `w` using `optimizer`."""
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(w * x - y))
  gradients = tape.gradient(loss, [w])
  optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Wrap the same Python function once per optimizer. Each optimizer creates its
# variables on the first call of its own Function, so neither Function ever has
# variables created on a later call.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y, opt1)
  else:
    train_step_2(w, x, y, opt2)

与多个 Keras 模型一起使用#

将不同的模型实例传递给同一 Function 时,您也可能会遇到 ValueError: tf.function only supports singleton tf.Variables created on the first call.

发生此错误的原因是 Keras 模型(未定义其输入形状)和 Keras 层会在首次调用时创建 tf.Variables。您可能正在尝试在已调用的 Function 中初始化这些变量。为避免此错误,请在训练模型之前尝试调用 model.build(input_shape) 以初始化所有权重。

延伸阅读#

要了解如何导出和加载 Function,请参阅 SavedModel 指南。要详细了解跟踪后执行的计算图优化,请参阅 Grappler 指南。要了解如何优化数据流水线和剖析模型性能,请参阅 Profiler 指南