TF-NumPy 类型提升#
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 Github 上查看源代码 | 下载笔记本 |
文本特征向量#
TensorFlow 中的类型提升有 4 个选项。
默认情况下,TensorFlow 会引发错误,而不是提升混合类型运算的类型。
运行
tf.numpy.experimental_enable_numpy_behavior()
会将 TensorFlow 切换为使用 NumPy 类型提升规则。本文档介绍了 TensorFlow 2.15(或目前为
tf-nightly
)中提供的两个新选项:
from set_env import temp_dir
# !pip install -q tf_nightly
注:experimental_enable_numpy_behavior
会更改所有 TensorFlow 的行为。
安装#
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
print("Using TensorFlow version %s" % tf.__version__)
Using TensorFlow version 2.17.0
启用新类型提升#
为了在 TF-Numpy 中使用类似 JAX 的类型提升,请在为 TensorFlow 启用 NumPy 行为时指定 'all'
或 'safe'
作为数据类型转换模式。
此新系统 (dtype_conversion_mode="all"
) 可结合、可交换,并且可以轻松控制最终的浮点数宽度(它不会自动转换为更宽的浮点数)。它确实引入了一些溢出和精度损失的风险,但 dtype_conversion_mode="safe"
会强制您显式处理这些情况。下一部分将更详细地解释这两种模式。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
两种模式:ALL 模式与 SAFE 模式#
在新提升系统中,我们引入了两种模式:ALL
模式和 SAFE
模式。SAFE
模式用于减轻可能导致精度损失或位加宽的“风险”提升的担忧。
数据类型#
为简洁起见,我们将使用以下缩写。
b
表示tf.bool
u8
表示tf.uint8
i16
表示tf.int16
i32
表示tf.int32
bf16
表示tf.bfloat16
f32
表示tf.float32
f64
表示tf.float64
i32*
表示 Pythonint
或弱类型i32
f32*
表示 Pythonfloat
浮点型或弱类型f32
c128*
表示 Pythoncomplex
或弱类型c128
星号 (*) 表示相应的类型是“弱类型”- 此类数据类型是由系统临时推断的,可以遵从其他数据类型。此处更详细地解释了这个概念。
精度损失运算示例#
在以下示例中,ALL
模式下允许使用 i32
+ f32
,但由于精度损失的风险,SAFE
模式下不允许使用。
# i32 + f32 returns a f32 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
a + b # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
<tf.Tensor: shape=(), dtype=float32, numpy=15.0>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int32'>, weak=False) and (<dtype: 'float32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
位加宽运算示例#
在以下示例中,ALL 模式下允许使用 i8
+ u32
,但由于位加宽,SAFE 模式下不允许使用,这意味着使用的位数多于输入中的位数。请注意,新的类型提升语义仅允许必要的位加宽。
# i8 + u32 returns an i64 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
a + b
<tf.Tensor: shape=(), dtype=int64, numpy=15>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int8'>, weak=False) and (<dtype: 'uint32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
基于点阵的系统#
类型提升点阵#
新的类型提升行为通过以下类型提升点阵来确定:
更具体地说,任何两种类型之间的提升是通过查找两个节点的第一个公共子节点(包括节点本身)来确定的。
例如,在上图中,i8
和 i32
的第一个公共子节点是 i32
,因为沿着箭头方向,这两个节点在 i32
处第一次相交。
类似地,在另一个示例中,u64
和 f16
之间的结果提升类型为 f16
。
类型提升表#
按照点阵行进会生成下面的二进制提升表:
注:SAFE
不允许高亮显示的单元格。ALL
模式允许全部情况。
新类型提升的优点#
我们针对新类型提升采用类似 JAX 的基于点阵的系统,它具有以下优点:
基于点阵的系统的优点#
首先,使用基于点阵的系统可以确保三个非常重要的属性:
存在性:任何类型的组合都存在唯一的结果提升类型。
交换性:
a + b = b + a
结合性:
a + (b + c) = (a + b) = c
这三个属性对于构建一致且可预测的类型提升语义至关重要。
类似 JAX 的点阵系统的优点#
类似 JAX 的点阵系统的另一个重要优点是,除了无符号整数之外,它避免了所有超出必要范围的提升。这意味着没有 64 位输入就无法获得 64 位结果。这对于加速器上的工作特别有利,因为它可以避免不必要的 64 位值,这在旧类型提升中十分常见。
不过,这需要一定的权衡:混合浮点/整数提升很容易导致精度损失。例如,在下面的示例中,i64
+ f16
会导致将 i64
提升为 f16
。
# The first input is promoted to f16 in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
<tf.Tensor: shape=(), dtype=float16, numpy=4.2>
为了缓解这种担忧,我们引入了 SAFE
模式,此模式会禁止这些“风险”提升。
注:要详细了解构造点阵系统的设计注意事项,请参阅 JAX 的类型提升语义设计。
WeakTensor#
概述#
WeakTensor 是“弱类型”的张量,类似于 JAX 中的概念。
WeakTensor
的数据类型是由系统临时推断的,并且可以遵从其他数据类型。在新类型提升中引入此概念的目的是防止 TF 值与没有用户显式指定类型的值(例如 Python 标量文字)之间的二进制运算中出现不需要的类型提升。
例如,在下面的示例中,tf.constant(1.2)
被视为“弱”,因为它没有特定的数据类型。因此,tf.constant(1.2)
遵从 tf.constant(3.1, tf.float16)
的类型,产生 f16
输出。
tf.constant(1.2) + tf.constant(3.1, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>
<tf.Tensor: shape=(), dtype=float16, numpy=4.3>
WeakTensor 构造#
如果您创建张量而不指定数据类型,则会创建 WeakTensor。可以通过检查张量字符串表示末尾的弱特性来检查张量是否为“弱”张量。
第一种情况:使用没有用户指定数据类型的输入调用 tf.constant
时。
tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
tf.constant([5.0, 10.0, 3]) # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
# A normal Tensor is created when dtype arg is specified.
tf.constant(5, tf.int32) # <tf.Tensor: shape=(), dtype=int32, numpy=5>
<tf.Tensor: shape=(), dtype=int32, numpy=5>
第二种情况:当没有用户指定数据类型的输入被传递到支持 WeakTensor 的 API 中时。
tf.math.abs([100.0, 4.0]) # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
##开启新类型提升的效果
以下是由于开启新类型提升而引起的更改的非详尽清单。
提升结果更一致且可预测。
降低位加宽的风险。
tf.Tensor
数学 dunder 方法使用新类型提升。tf.constant
可以返回WeakTensor
。当传入一个数据类型与
dtype
参数不同的张量输入时,tf.constant
允许隐式转换。tf.Variable
就地运算(assign
、assign-add
、assign-sub
)允许隐式转换。tnp.array(1)
和tnp.array(1.0)
返回 32 位 WeakTensor。将创建
WeakTensor
用于支持 WeakTensor 的一元和二元 API。
提升结果更一致且可预测性提升#
使用基于点阵的系统允许新类型提升产生一致且可预测的类型提升结果。
旧类型提升#
使用旧类型提升更改运算顺序会产生不一致的结果。
# Setup
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
print(f'{type(e)}: {e}') # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>
新类型提升#
无论顺序如何,新类型提升都会产生一致的结果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c returns a f16 result.
tf.add(tf.add(a, b), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns a f16 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
降低位加宽的风险#
旧类型提升#
旧类型提升通常会产生 64 位结果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
新类型提升#
新类型提升返回所需位数最少的结果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>
<tf.Tensor: shape=(), dtype=float16, numpy=54.2>
tf.Tensor 数学 dunder 方法#
所有 tf.Tensor
数学 dunder 方法都将遵循新类型提升。
-tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>
tf.Variable 就地运算#
tf.Variable
就地运算中允许隐式转换。
注:任何导致数据类型与变量的原始数据类型不同的提升都是不允许的。原因是 tf.Variable
不能更改其数据类型。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, tf.int32)
a.assign_add(tf.constant(5, tf.int16)) # <tf.Variable shape=() dtype=int32, numpy=15>
<tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>
tf.constant 隐式转换#
在旧类型提升中,tf.constant
要求输入张量与数据类型参数具有相同的数据类型。不过,在新类型提升中,我们将张量隐式转换为指定的数据类型。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>
TF-NumPy 数组#
对于使用新类型提升的 Python 输入,tnp.array
默认为 i32*
和 f32*
。
tnp.array(1) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>
##输入类型推断
下面是在新类型提升中推断不同输入类型的方式。
tf.Tensor
:由于tf.Tensor
具有数据类型属性,我们不做进一步的推断。NumPy 类型:包括
np.array(1)
、np.int16(1)
和np.float
等类型。由于 NumPy 输入也具有数据类型属性,我们将数据类型属性作为结果推断类型。请注意,NumPy 默认为i64
和f64
。Python 标量/嵌套类型:包括
1
、[1, 2, 3]
和(1.0, 2.0)
等类型。Python
int
被推断为i32*
。Python
float
被推断为f32*
。Python
complex
被推断为c128*
。
如果输入不属于上述任何类别,但具有数据类型属性,我们将数据类型属性作为结果推断类型。
延伸阅读#
新类型提升与 JAX-NumPy 的类型提升非常相似。如果想了解有关新类型提升和设计选择的更多详细信息,请查阅以下资源。
参考#
支持 WeakTensor 的 API#
以下是支持 WeakTensor
的 API 列表。
对于一元运算,这意味着如果传入没有用户指定类型的输入,它将返回 WeakTensor
。
对于二元运算,它将遵循此处的提升表。它可能会也可能不会返回 WeakTensor
,具体取决于两个输入的提升结果。
注:支持所有数学运算(+
、-
、*
、…)。
tf.bitwise.invert
tf.clip_by_value
tf.debugging.check_numerics
tf.expand_dims
tf.identity
tf.image.adjust_brightness
tf.image.adjust_gamma
tf.image.extract_patches
tf.image.random_brightness
tf.image.stateless_random_brightness
tf.linalg.diag
tf.linalg.diag_part
tf.linalg.matmul
tf.linalg.matrix_transpose
tf.linalg.tensor_diag_part
tf.linalg.trace
tf.math.abs
tf.math.acos
tf.math.acosh
tf.math.add
tf.math.angle
tf.math.asin
tf.math.asinh
tf.math.atan
tf.math.atanh
tf.math.ceil
tf.math.conj
tf.math.cos
tf.math.cosh
tf.math.digamma
tf.math.divide_no_nan
tf.math.divide
tf.math.erf
tf.math.erfc
tf.math.erfcinv
tf.math.erfinv
tf.math.exp
tf.math.expm1
tf.math.floor
tf.math.floordiv
tf.math.floormod
tf.math.imag
tf.math.lgamma
tf.math.log1p
tf.math.log_sigmoid
tf.math.log
tf.math.multiply_no_nan
tf.math.multiply
tf.math.ndtri
tf.math.negative
tf.math.pow
tf.math.real
tf.math.real
tf.math.reciprocal_no_nan
tf.math.reciprocal
tf.math.reduce_euclidean_norm
tf.math.reduce_logsumexp
tf.math.reduce_max
tf.math.reduce_mean
tf.math.reduce_min
tf.math.reduce_prod
tf.math.reduce_std
tf.math.reduce_sum
tf.math.reduce_variance
tf.math.rint
tf.math.round
tf.math.rsqrt
tf.math.scalar_mul
tf.math.sigmoid
tf.math.sign
tf.math.sin
tf.math.sinh
tf.math.softplus
tf.math.special.bessel_i0
tf.math.special.bessel_i0e
tf.math.special.bessel_i1
tf.math.special.bessel_i1e
tf.math.special.bessel_j0
tf.math.special.bessel_j1
tf.math.special.bessel_k0
tf.math.special.bessel_k0e
tf.math.special.bessel_k1
tf.math.special.bessel_k1e
tf.math.special.bessel_y0
tf.math.special.bessel_y1
tf.math.special.dawsn
tf.math.special.expint
tf.math.special.fresnel_cos
tf.math.special.fresnel_sin
tf.math.special.spence
tf.math.sqrt
tf.math.square
tf.math.subtract
tf.math.tan
tf.math.tanh
tf.nn.depth_to_space
tf.nn.elu
tf.nn.gelu
tf.nn.leaky_relu
tf.nn.log_softmax
tf.nn.relu6
tf.nn.relu
tf.nn.selu
tf.nn.softsign
tf.nn.space_to_depth
tf.nn.swish
tf.ones_like
tf.realdiv
tf.reshape
tf.squeeze
tf.stop_gradient
tf.transpose
tf.truncatediv
tf.truncatemod
tf.zeros_like
tf.experimental.numpy.abs
tf.experimental.numpy.absolute
tf.experimental.numpy.amax
tf.experimental.numpy.amin
tf.experimental.numpy.angle
tf.experimental.numpy.arange
tf.experimental.numpy.arccos
tf.experimental.numpy.arccosh
tf.experimental.numpy.arcsin
tf.experimental.numpy.arcsinh
tf.experimental.numpy.arctan
tf.experimental.numpy.arctanh
tf.experimental.numpy.around
tf.experimental.numpy.array
tf.experimental.numpy.asanyarray
tf.experimental.numpy.asarray
tf.experimental.numpy.ascontiguousarray
tf.experimental.numpy.average
tf.experimental.numpy.bitwise_not
tf.experimental.numpy.cbrt
tf.experimental.numpy.ceil
tf.experimental.numpy.conj
tf.experimental.numpy.conjugate
tf.experimental.numpy.copy
tf.experimental.numpy.cos
tf.experimental.numpy.cosh
tf.experimental.numpy.cumprod
tf.experimental.numpy.cumsum
tf.experimental.numpy.deg2rad
tf.experimental.numpy.diag
tf.experimental.numpy.diagflat
tf.experimental.numpy.diagonal
tf.experimental.numpy.diff
tf.experimental.numpy.empty_like
tf.experimental.numpy.exp2
tf.experimental.numpy.exp
tf.experimental.numpy.expand_dims
tf.experimental.numpy.expm1
tf.experimental.numpy.fabs
tf.experimental.numpy.fix
tf.experimental.numpy.flatten
tf.experimental.numpy.flip
tf.experimental.numpy.fliplr
tf.experimental.numpy.flipud
tf.experimental.numpy.floor
tf.experimental.numpy.full_like
tf.experimental.numpy.imag
tf.experimental.numpy.log10
tf.experimental.numpy.log1p
tf.experimental.numpy.log2
tf.experimental.numpy.log
tf.experimental.numpy.max
tf.experimental.numpy.mean
tf.experimental.numpy.min
tf.experimental.numpy.moveaxis
tf.experimental.numpy.nanmean
tf.experimental.numpy.negative
tf.experimental.numpy.ones_like
tf.experimental.numpy.positive
tf.experimental.numpy.prod
tf.experimental.numpy.rad2deg
tf.experimental.numpy.ravel
tf.experimental.numpy.real
tf.experimental.numpy.reciprocal
tf.experimental.numpy.repeat
tf.experimental.numpy.reshape
tf.experimental.numpy.rot90
tf.experimental.numpy.round
tf.experimental.numpy.signbit
tf.experimental.numpy.sin
tf.experimental.numpy.sinc
tf.experimental.numpy.sinh
tf.experimental.numpy.sort
tf.experimental.numpy.sqrt
tf.experimental.numpy.square
tf.experimental.numpy.squeeze
tf.experimental.numpy.std
tf.experimental.numpy.sum
tf.experimental.numpy.swapaxes
tf.experimental.numpy.tan
tf.experimental.numpy.tanh
tf.experimental.numpy.trace
tf.experimental.numpy.transpose
tf.experimental.numpy.triu
tf.experimental.numpy.vander
tf.experimental.numpy.var
tf.experimental.numpy.zeros_like