##### Copyright 2021 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

扩展程序类型#

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本

安装#

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

扩展程序类型#

用户定义的类型可以使项目的可读性、模块化、可维护程度更高。但是,大多数 TensorFlow API 对于用户定义的 Python 类型的支持却非常有限。这包括高级 API(如 Kerastf.functiontf.SavedModel)和低级 API(如 tf.while_looptf.concat)。TensorFlow 扩展程序类型可用于创建能够与 TensorFlow 的 API 无缝协作的用户定义的面向对象类型。要创建扩展程序类型,只需定义一个以 tf.experimental.ExtensionType 为基础的 Python 类,并使用类型注解来指定每个字段的类型。

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

tf.experimental.ExtensionType 基类的工作方式类似于标准 Python 库中的 typing.NamedTuple@dataclasses.dataclass。特别是,它会根据字段类型注解自动添加构造函数和特殊方法(例如 __repr____eq__)。

通常,扩展程序类型往往属于以下两个类别之一:

  • 数据结构,会将一组相关的值组合在一起,并且可以基于这些值提供有用的运算。数据结构可以十分常规(例如上面的 TensorGraph 示例),也可以针对特定模型进行高度定制。

  • 类张量类型,限定或延伸了“张量”的概念。此类别中的类型具有 rankshape,通常还有 dtype;并且将它们与张量运算(例如 tf.stacktf.addtf.matmul)一起使用是合理的。MaskedTensorCSRSparseMatrix 是类张量类型的示例。

支持的 API#

以下 TensorFlow API 支持扩展程序类型:

  • Keras:扩展程序类型可以用作 Keras ModelsLayers 的输入和输出。

  • tf.data.Dataset:扩展程序类型可以包含在 Datasets 中,并由数据集 Iterators 返回。

  • TensorFlow Hub:扩展程序类型可以用作 tf.hub 模块的输入和输出。

  • SavedModel:扩展程序类型可以用作 SavedModel 函数的输入和输出。

  • tf.function:扩展程序类型可以用作使用 @tf.function 装饰器包装的函数的参数和返回值。

  • While 循环:扩展程序类型可以用作 tf.while_loop 中的循环变量,也可以用作 while 循环体的参数和返回值。

  • 条件:可以使用 tf.condtf.case 有条件地选择扩展程序类型。

  • tf.py_function:扩展程序类型可以用作 tf.py_function 的参数以及针对 func 参数的返回值。

  • 张量运算:扩展程序类型可扩展以支持大多数接受张量输入的 TensorFlow 运算(例如,tf.matmultf.gathertf.reduce_sum)。如需了解详情,请转到下面的调度部分。

  • 分布策略:扩展程序类型可以用作按副本值。

有关详情,请参阅下面的“支持 ExtensionType 的 TensorFlow API”部分。

要求#

字段类型#

必须声明所有字段(实例变量),并且必须为每个字段提供类型注解。支持以下类型注解:

类型

示例

Python 整数

i: int

Python 浮点数

f: float

Python 字符串

s: str

Python 布尔值

b: bool

Python None

n: None

张量形状

shape: tf.TensorShape

张量数据类型

dtype: tf.DType

张量

t: tf.Tensor

扩展程序类型

mt: MyMaskedTensor

不规则张量

rt: tf.RaggedTensor

稀疏张量

st: tf.SparseTensor

索引切片

s: tf.IndexedSlices

可选张量

o: tf.experimental.Optional

类型联合

int_or_float: typing.Union[int, float]

元组

params: typing.Tuple[int, float, tf.Tensor, int]

可变长度元组

lengths: typing.Tuple[int, ...]

映射

tags: typing.Mapping[str, tf.Tensor]

可选值

weight: typing.Optional[tf.Tensor]

可变性#

扩展程序类型必须是不可变的。这可以确保它们能够被 TensorFlow 的计算图跟踪机制正确跟踪。如果您发现自己想要改变扩展程序类型值,请考虑改为定义用于转换值的方法。例如,与其定义 set_mask 方法来改变 MaskedTensor,您可以定义用于返回新的 MaskedTensorset_mask 方法:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

ExtensionType 添加的功能#

ExtensionType 基类提供了以下功能:

  • 构造函数 (__init__)。

  • 可打印表示方法 (__repr__)。

  • 相等和不等运算符 (__eq__)。

  • 验证方法 (__validate__)。

  • 强制不变性。

  • 嵌套 TypeSpec

  • 张量 API 调度支持。

有关自定义此功能的更多信息,请转到下面的“自定义 ExtensionType”部分。

构造函数#

ExtensionType 添加的构造函数会将每个字段作为命名参数(按照它们在类定义中的排列顺序)。此构造函数将对每个形参进行类型检查,并在必要时对其进行转换。特别是,Tensor 字段会使用 tf.convert_to_tensor 进行转换;Tuple 字段会被转换为 tupleMapping 字段会被转换为不可变字典。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)

如果字段值无法转换为其声明的类型,构造函数将引发 TypeError

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")

可以通过在类级别设置字段的值来指定字段的默认值:

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(length=0.5, color="blue")

可打印表示#

ExtensionType 添加了一个默认的可打印表示方法 (__repr__),其中包括类名和每个字段的值:

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))

相等运算符#

ExtensionType 添加了默认相等运算符 (__eq____ne__),如果两个值具有相同的类型并且其所有字段都相等,则认为二者相等。如果张量字段具有相同的形状并且对所有元素均符合逐元素相等,则认为张量字段相等。

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")

:如果任何字段包含 Tensor,则 __eq__ 可能会返回标量布尔 Tensor(而非 Python 布尔值)。

验证方法#

ExtensionType 添加了一个 __validate__ 方法,此方法可重写以对字段执行验证检查。它会在调用构造函数之后,以及在字段经过类型检查并转换为其声明的类型之后运行,因此它可以假定所有字段都具有其声明的类型。

以下示例会更新 MaskedTensor 以验证其字段的 shapedtype

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")

强制不变性#

ExtensionType 会重写 __setattr____delattr__ 方法以防止变更,从而确保扩展程序类型值不可变。

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")

嵌套 TypeSpec#

每个 ExtensionType 类都有一个对应的 TypeSpec 类,它会自动创建并存储为 <extension_type_name>.Spec

此类会从值中捕获所有信息,除了任何嵌套张量的值。特别是,值的 TypeSpec 是通过将任何嵌套张量、ExtensionType 或 CompositeTensor 替换为其 TypeSpec 来创建的。

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.

TypeSpec 值可以显式构造,也可以使用 tf.type_spec_from_valueExtensionType 值构造:

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TensorFlow 会使用 TypeSpec 将值划分为静态组件动态组件

  • 静态组件(在计算图构建时固定不变)使用 tf.TypeSpec 进行编码。

  • 动态组件(每次运行计算图时都会发生变化)被编码为 tf.Tensor 的列表。

例如,每当参数具有以前未见过的 TypeSpec 时,tf.function 都会回溯它的包装函数:

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))

有关详情,请参阅 tf.function 指南

自定义 ExtensionType#

除了简单地声明字段及其类型外,扩展程序类型还可以:

  • 重写默认可打印表示 (__repr__)。

  • 定义方法。

  • 定义类方法和静态方法。

  • 定义属性。

  • 重写默认构造函数 (__init__)。

  • 重写默认相等运算符 (__eq__)。

  • 定义运算符(例如 __add____lt__)。

  • 声明字段的默认值。

  • 定义子类。

重写默认可打印表示#

您可以为扩展程序类型重写此默认字符串转换运算符。以下示例会更新 MaskedTensor 类以在 Eager 模式下打印值时生成更具可读性的字符串表示。

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)

定义方法#

与任何常规 Python 类一样,扩展程序类型也可以定义方法。例如,MaskedTensor 类型可以定义 with_default 方法,该方法会返回一个 self 的副本,其中掩码值会被替换为给定的 default 值。可以选择使用 @tf.function 装饰器注解方法。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)

定义类方法和静态方法#

扩展程序类型可以使用 @classmethod@staticmethod 装饰器定义方法。例如,MaskedTensor 类型可以定义能够使用给定值来遮盖任何元素的工厂方法:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values != value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)

定义属性#

与任何常规 Python 类一样,扩展程序类型也可以使用 @property 装饰器定义属性。例如,MaskedTensor 类型可以定义 dtype 属性,它是值的数据类型的简写形式:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype

重写默认构造函数#

您可以重写扩展程序类型的默认构造函数。自定义构造函数必须为每个声明的字段均设置一个值;并且在自定义构造函数返回后,所有字段都将进行类型检查,并将按上述方式转换值。

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!

或者,您可以考虑保留默认构造函数,但添加一个或多个工厂方法。例如:

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))

重写默认相等运算符 (__eq__)#

您可以重写扩展程序类型的默认 __eq__ 运算符。以下示例会更新 MaskedTensor 以在比较相等性时忽略遮盖元素。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)

:您通常不需要重写 __ne__,因为其默认实现只需调用 __eq__ 并对结果求反。

使用前向引用#

如果字段的类型尚未定义,您可以改用包含类型名称的字符串。在以下示例中,字符串 "Node" 用于注解 children 字段,因为 Node 类型尚未(完全)定义。

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])

定义子类#

扩展程序类型可以使用标准 Python 语法进行子类化。扩展程序类型子类可以添加新字段、方法和属性;并且可以重写构造函数、可打印表示和相等运算符。以下示例定义了一个基本的 TensorGraph 类,使用三个 Tensor 字段来编码节点之间的一组边。然后,它会定义一个子类,添加一个 Tensor 字段来记录每个节点的“特征值”。该子类还会定义一个沿着边传播特征值的方法。

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)

定义私有字段#

扩展程序类型的字段可以通过在前面加上下划线来标记为私有(遵循标准 Python 惯例)。这不会影响 TensorFlow 处理字段的任何方式;但只为向扩展程序类型的任何用户表明这些字段为私有。

自定义 ExtensionType 的 TypeSpec#

每个 ExtensionType 类都有一个对应的 TypeSpec 类,后者是自动创建的并被存储为 <extension_type_name>.Spec。有关详情,请参阅上面的“嵌套 TypeSpec”部分。

要自定义 TypeSpec,只需定义您自己的名为 Spec 的嵌套类,ExtensionType 将使用它作为自动构造的 TypeSpec 的基础。您可以通过以下方式自定义 Spec 类:

  • 重写默认可打印表示。

  • 重写默认构造函数。

  • 定义方法、类方法、静态方法和属性。

以下示例自定义了 MaskedTensor.Spec 类以使其更加易于使用:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

:自定义 Spec 类不能使用任何未在原始 ExtensionType 中声明的实例变量。

张量 API 调度#

扩展程序类型可以是“类张量”,因为它们限定或延伸了 tf.Tensor 类型定义的接口。类张量扩展程序类型的示例包括 RaggedTensorSparseTensorMaskedTensor。当应用于类张量扩展程序类型时,调度装饰器可用于重写 TensorFlow 运算的默认行为。TensorFlow 目前定义了三个调度装饰器:

  • @tf.experimental.dispatch_for_api(tf_api)

  • @tf.experimental.dispatch_for_unary_elementwise_apis(x_type)

  • @tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)

单个 API 的调度#

在使用指定签名进行调用时,tf.experimental.dispatch_for_api 装饰器会重写指定 TensorFlow 运算的默认行为。例如,您可以使用此装饰器来指定 tf.stack 应如何处理 MaskedTensor 值:

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

每当使用 MaskedTensor 值的列表调用tf.stack 时,这都会重写它的默认实现(因为 values 参数使用 typing.List[MaskedTensor] 注解):

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])

要允许 tf.stack 处理混合的 MaskedTensorTensor 值的列表,您可以优化 values 形参的类型注解并适当地更新函数体:

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])

有关可重写 API 的列表,请参阅 tf.experimental.dispatch_for_api 的 API 文档。

所有一元逐元素 API 的调度#

只要第一个参数(通常命名为 x)的值与类型注解 x_type 相匹配,tf.experimental.dispatch_for_unary_elementwise_apis 装饰器就会重写所有一元逐元素运算(例如 tf.math.cos)的默认行为。装饰函数应接受两个参数:

  • api_func:接受单个形参并执行逐元素运算的函数(例如 tf.abs)。

  • x:逐元素运算的第一个参数。

以下示例会更新所有一元逐元素运算以处理 MaskedTensor 类型:

 @tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

现在,只要在 MaskedTensor 上调用一元逐元素运算,就会使用此函数。

 x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
print(tf.ones_like(x, dtype=tf.float32))

所有二进制逐元素 API 的调度#

同样,tf.experimental.dispatch_for_binary_elementwise_apis 可用于更新所有二进制逐元素运算以处理 MaskedTensor 类型:

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)

有关被重写的逐元素 API 的列表,请转到 tf.experimental.dispatch_for_unary_elementwise_apistf.experimental.dispatch_for_binary_elementwise_apis 的 API 文档。

可批处理 ExtensionType#

如果单个实例可用于表示一批值,则 ExtensionType可批处理。通常,这可以通过向所有嵌套 Tensor 添加批量维度来实现。以下 TensorFlow API 要求任何扩展程序类型的输入都可批处理:

  • tf.data.Datasetbatchunbatchfrom_tensor_slices

  • tf.kerasfitevaluatepredict

  • tf.map_fn

默认情况下,BatchableExtensionType 会通过批处理任何嵌套的 TensorCompositeTensorExtensionType 来创建批处理值。如果这不适合您的类,那么您将需要使用 tf.experimental.ExtensionTypeBatchEncoder 来重写此默认行为。例如,通过简单地堆叠各个稀疏张量的 valuesindicesdense_shape 字段来创建一批 tf.SparseTensor 值是不合适的 – 在大多数情况下,您不能堆叠这些张量,因为它们具有不兼容的形状;即便可以,结果也不会是有效的 SparseTensor

BatchableExtensionType 不会自动为 tf.stacktf.concattf.slice 等定义调度器。如果您的类需要这些 API 的支持,请使用上述调度装饰器。

BatchableExtensionType 示例:Network#

例如,请思考用于负载均衡的简单 Network 类,用于跟踪每个节点还有多少剩余工作,以及有多少带宽可用于在节点之间移动工作:

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

要使此类型可批处理,请将基本类型更改为 BatchableExtensionType,并调整每个字段的形状来包含可选的批次维度。以下示例还添加了一个 shape 字段来跟踪批次形状。tf.data.Datasettf.map_fn 不需要此 shape 字段,但 tf.keras 需要

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")

然后,您可以使用 tf.data.Dataset 迭代一批网络:

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")

您还可以使用 map_fn 对每个批处理元素应用函数:

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)

支持 ExtensionType 的 TensorFlow API#

@tf.function#

tf.function 是预计算 Python 函数 TensorFlow 计算图的装饰器,可以大幅改善 TensorFlow 代码的性能。扩展程序类型能够透明地与 @tf.function 装饰的函数一起使用。

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)

如果您希望为 tf.function 明确指定 input_signature,则可以使用扩展程序类型的 TypeSpec 执行此操作。

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)

具体函数#

具体函数封装通过 tf.function 构建的各个跟踪计算图。扩展程序类型可以透明地与具体函数一起使用。

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)

控制流运算#

TensorFlow 的控制流运算支持扩展程序类型:

  • tf.cond

  • tf.case

  • tf.while_loop

  • tf.identity

# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])

Autograph 控制流#

tf.function 中的控制流语句也支持扩展程序类型(使用 autograph)。在以下示例中,if 语句和 for 语句会自动转换为支持扩展程序类型的 tf.condtf.while_loop 运算。

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))

Keras#

tf.keras 是 TensorFlow 用于构建和训练深度学习模型的高级 API。扩展程序类型可以作为输入传递给 Keras 模型,在 Keras 层之间传递,并由 Keras 模型返回。Keras 目前对扩展程序类型具有两项要求:

  • 它们必须可批处理(请转到上面的“可批处理 ExtensionType”)。

  • 它们必须具有名为 shape 的字段或属性。假定shape[0] 为批次维度。

以下两个小节提供了展示如何将扩展程序类型与 Keras 一起使用的示例。

Keras 示例:Network#

对于第一个示例,请思考上面“可批处理 ExtensionType”部分定义的 Network 类,它可以用于节点之间的负载均衡工作。这里再次给出它的定义:

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network with 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

您可以定义用于处理 Network 的新 Keras 层。

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above in the "Batchable `ExtensionType`s" section.
    return balance_work_greedy(inputs)

然后,您可以使用这些层来创建一个简单的模型。要将 ExtensionType 馈送给模型,您可以使用 tf.keras.layer.Input 层并将 type_spec 设置为扩展程序类型的 TypeSpec。如果 Keras 模型将用于处理批次,那么 type_spec 必须包含批次维度。

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

最后,您可以将模型应用于单个网络和一批网络。

model(single_network)
model(batch_of_networks)

Keras 示例:MaskedTensor#

在此示例中,MaskedTensor 进行了扩展以支持 Kerasshape 定义为从 values 字段计算的属性。Keras 要求您将此属性添加到扩展程序类型及其 TypeSpecMaskedTensor 还定义了 SavedModel 序列化所需的 __name__ 变量(如下)。

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

接下来,调度装饰器会用于重写多个 TensorFlow API 的默认行为。由于这些 API 会由标准 Keras 层(例如 Dense 层)使用,对其进行重写,我们就能够将这些层与 MaskedTensor 一起使用。出于本示例的目的,我们定义了掩码张量的 matmul 以将掩码值视为零(即,不将它们包含在乘积中)。

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

然后,您可以使用标准 Keras 层构建一个接受 MaskedTensor 输入的 Keras 模型:

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))

SavedModel#

SavedModel 是序列化 TensorFlow 程序,包括权重和计算。它可以通过 Keras 模型或自定义模型构建。在任何一种情况下,扩展程序类型都可以透明地与 SavedModel 定义的函数和方法一起使用。

SavedModel 可以保存用于处理扩展程序类型的模型、层和函数,只要扩展程序类型具有 __name__ 字段即可。此名称用于注册扩展程序类型,以便在加载模型时进行定位。

示例:保存 Keras 模型#

可以使用 SavedModel 来保存使用扩展程序类型的 Keras 模型。

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)

示例:保存自定义模型#

SavedModel 还可用于保存包含用于处理扩展程序类型的函数的自定义 tf.Module 子类。

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))

在 ExtensionType 不可用时加载 SavedModel#

如果您加载使用 ExtensionTypeSavedModel,但该 ExtensionType 不可用(即尚未导入),您将看到一条警告,并且 TensorFlow 将回退到使用“匿名扩展程序类型”对象。此对象将具有与原始类型相同的字段,但将缺少您为该类型添加的任何后续自定义内容,例如自定义方法或属性。

ExtensionType 与 TensorFlow Serving 一起使用#

目前,TensorFlow Serving(以及 SavedModel“签名”字典的其他使用者)要求所有输入和输出都是原始张量。如果您希望将 TensorFlow Serving 与使用扩展程序类型的模型一起使用,可以添加用于组合或分解张量的扩展程序类型值的封装容器方法。 例如:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)

数据集#

tf.data 是一个 API,可用于通过简单的可重用代码块构建复杂的输入流水线。它的核心数据结构是 tf.data.Dataset,表示一系列元素,每个元素包含一个或多个分量。

使用扩展程序类型构建数据集#

可以使用 Dataset.from_tensorsDataset.from_tensor_slicesDataset.from_generator 从扩展程序类型值构建数据集:

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)

使用扩展程序类型批处理和取消批处理数据集#

可以使用 Dataset.batchDataset.unbatch 对具有扩展程序类型的数据集进行批处理和取消批处理。

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)