{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "ISubpr_SSsiM" }, "outputs": [], "source": [ "##### Copyright 2020 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "3jTMb1dySr3V" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/media/pc/data/lxw/ai/d2py/doc/libs/tf-chaos/guide\n" ] } ], "source": [ "%cd ..\n", "from set_env import temp_dir" ] }, { "cell_type": "markdown", "metadata": { "id": "6DWfyNThSziV" }, "source": [ "# Better performance with tf.function\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "J122XQYG7W6w" }, "source": [ "In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.\n", "\n", "You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This helps you create performant and portable models, and it is required to use `SavedModel`.\n", "\n", "This guide explains how `tf.function` works under the hood, so that you can use it effectively.\n", "\n", "The main takeaways and recommendations are:\n", "\n", "- Debug in eager mode, then decorate with `@tf.function`.\n", "- Don't rely on Python side effects like object mutation or list appends.\n", "- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "SjvqpgepHJPd" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "otIdN1TS8N7S" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "I0xDjO4SHLUD" }, "source": [ "Define a helper function to demonstrate the kinds of errors you might encounter:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "D25apou9IOXa" }, "outputs": [], "source": [ "import traceback\n", "import contextlib\n", "\n", "# Some helper code to demonstrate the kinds of errors you might encounter.\n", "@contextlib.contextmanager\n", "def assert_raises(error_class):\n", "  try:\n", "    yield\n", "  except error_class as e:\n", "    print('Caught expected exception \\n {}:'.format(error_class))\n", "    traceback.print_exc(limit=2)\n", "  except Exception as e:\n", "    raise e\n", "  else:\n", "    raise Exception('Expected {} to be raised but no error was raised!'.format(\n", "        error_class))" ] }, { "cell_type": "markdown", "metadata": { "id": "WPSfepzTHThq" }, "source": [ "## Basics" ] }, { "cell_type": "markdown", "metadata": { "id": "CNwYTIJ8r56W" }, "source": [ "### Usage\n", "\n", "A `Function` you define (for example, by applying the `@tf.function` decorator) is just like a core TensorFlow operation: you can execute it eagerly, compute gradients through it, and so on." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "SbtT1-Wm70F2" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } 
], "source": [ "@tf.function  # The decorator converts `add` into a `Function`.\n", "def add(a, b):\n", "  return a + b\n", "\n", "add(tf.ones([2, 2]), tf.ones([2, 2]))  # [[2., 2.], [2., 2.]]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "uP-zUelB8DbX" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v = tf.Variable(1.0)\n", "with tf.GradientTape() as tape:\n", "  result = add(v, 1.0)\n", "tape.gradient(result, v)" ] }, { "cell_type": "markdown", "metadata": { "id": "ocWZvqrmHnmX" }, "source": [ "You can use `Function`s inside other `Function`s." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "l5qRjdbBVdU6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def dense_layer(x, w, b):\n", "  return add(tf.matmul(x, w), b)\n", "\n", "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))" ] }, { "cell_type": "markdown", "metadata": { "id": "piBhz7gYsHqU" }, "source": [ "`Function`s can be faster than eager code, especially for graphs with many small ops. For graphs with a few expensive ops (like convolutions), however, you may not see much speedup.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "zuXt4wRysI03" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eager conv: 0.013796081067994237\n", "Function conv: 0.007262098835781217\n", "Note how there's not much difference in performance for convolutions\n" ] } ], "source": [ "import timeit\n", "conv_layer = tf.keras.layers.Conv2D(100, 3)\n", "\n", "@tf.function\n", "def conv_fn(image):\n", "  return conv_layer(image)\n", "\n", "image = tf.zeros([1, 200, 200, 100])\n", "# Warm up\n", "conv_layer(image); conv_fn(image)\n", "print(\"Eager conv:\", timeit.timeit(lambda: conv_layer(image), number=10))\n", "print(\"Function conv:\", timeit.timeit(lambda: conv_fn(image), number=10))\n", "print(\"Note how there's not much difference in performance for convolutions\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "uZ4Do2AV80cO" }, "source": [ "### Tracing\n", "\n", "This section exposes how `Function` works under the hood, including implementation details *which may change in the future*. However, once you understand why and when tracing happens, it is much easier to use `tf.function` effectively!" ] }, { "cell_type": "markdown", "metadata": { "id": "nhpUtRqsXoyM" }, "source": [ "#### What is \"tracing\"?\n", "\n", "A `Function` runs your program in a [TensorFlow Graph](https://tensorflow.google.cn/guide/intro_to_graphs#what_are_graphs). However, a `tf.Graph` cannot represent everything that you would write in an eager TensorFlow program. For instance, Python supports polymorphism, but `tf.Graph` requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a `tf.Graph`.\n", "\n", "`Function` bridges this gap by separating your code into two stages:\n", "\n", "1. In the first stage, referred to as **tracing**, `Function` creates a new `tf.Graph`. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are *deferred*: they are captured by the `tf.Graph` and not run.\n", "\n", "2. 
In the second stage, a `tf.Graph` that contains everything deferred in the first stage is run. This stage is much faster than the tracing stage.\n", "\n", "Depending on its inputs, `Function` will not always run the first stage when it is called. See [Rules of tracing](#rules_of_tracing) below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.\n", "\n", "When `Function` does decide to trace, the tracing stage is immediately followed by the second stage, so calling the `Function` both creates and runs the `tf.Graph`. Later you will see how to run only the tracing stage with [`get_concrete_function`](#obtaining_concrete_functions)." ] }, { "cell_type": "markdown", "metadata": { "id": "K7scSzLx662f" }, "source": [ "When you pass arguments of different types into a `Function`, both stages are run:\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "kojmJrgq8U9v" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"a:0\", shape=(), dtype=int32)\n", "tf.Tensor(2, shape=(), dtype=int32)\n", "\n", "Tracing with Tensor(\"a:0\", shape=(), dtype=float32)\n", "tf.Tensor(2.2, shape=(), dtype=float32)\n", "\n", "Tracing with Tensor(\"a:0\", shape=(), dtype=string)\n", "tf.Tensor(b'aa', shape=(), dtype=string)\n", "\n" ] } ], "source": [ "@tf.function\n", "def double(a):\n", "  print(\"Tracing with\", a)\n", "  return a + a\n", "\n", "print(double(tf.constant(1)))\n", "print()\n", "print(double(tf.constant(1.1)))\n", "print()\n", "print(double(tf.constant(\"a\")))\n", "print()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "QPfouGUQrcNb" }, "source": [ "Note that if you repeatedly call a `Function` with the same argument type, TensorFlow skips the tracing stage and reuses a previously traced graph, because the generated graph would be identical." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "hFccbWFRrsBp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(b'bb', shape=(), dtype=string)\n" ] } ], "source": [ "# This doesn't print 'Tracing with ...'\n", "print(double(tf.constant(\"b\")))" ] }, { "cell_type": "markdown", "metadata": { "id": "fgIO_XEzcB9o" }, "source": [ "You can use `pretty_printed_concrete_signatures()` to see all of the available traces:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "IiQc4IKAb-NX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input Parameters:\n", " a 
(POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.int32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "print(double.pretty_printed_concrete_signatures())" ] }, { "cell_type": "markdown", "metadata": { "id": "rKQ92VEWI7n8" }, "source": [ "So far, you have seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:\n", "\n", "- A `tf.Graph` is the raw, language-agnostic, portable representation of a TensorFlow computation.\n", "- A `ConcreteFunction` wraps a `tf.Graph`.\n", "- A `Function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.\n", "- `tf.function` wraps a Python function, returning a `Function` object.\n", "- **Tracing** creates a `tf.Graph` and wraps it in a `ConcreteFunction`, also known as a **trace**.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "129-iRsPS-gY" }, "source": [ "#### Rules of tracing\n", "\n", "When called, a `Function` matches the call arguments to existing `ConcreteFunction`s using the `tf.types.experimental.TraceType` of each argument. If a matching `ConcreteFunction` is found, the call is dispatched to it. If no match is found, a new `ConcreteFunction` is traced.\n", "\n", "If multiple matches are found, the most specific signature is chosen. Matching is done by [subtyping](https://en.wikipedia.org/wiki/Subtyping), much like normal function calls in C++ or Java. For example, `TensorShape([1, 2])` is a subtype of `TensorShape([None, None])`, so a call to the tf.function with `TensorShape([1, 2])` can be dispatched to the `ConcreteFunction` produced with `TensorShape([None, None])`. However, if a `ConcreteFunction` with `TensorShape([1, None])` also exists, it is prioritized because it is more specific.\n", "\n", "The `TraceType` is determined from input arguments as follows:\n", "\n", "- For `Tensor`, the type is parameterized by the `Tensor`'s `dtype` and `shape`; ranked shapes are a subtype of unranked shapes; fixed dimensions are a subtype of unknown dimensions\n", "\n", "- For `Variable`, the type is similar to `Tensor`, but also includes a unique resource ID of the variable, which is necessary to wire control dependencies correctly\n", "\n", "- For Python primitive values, the type corresponds to the **value** itself. For example, the `TraceType` of the value `3` is 
`LiteralTraceType<3>`, not `int`.\n", "\n", "- For Python ordered containers such as `list` and `tuple`, the type is parameterized by the types of their elements; for example, the type of `[1, 2]` is `ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>` and the type of `[2, 1]` is `ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>`, which is different.\n", "\n", "- For Python mappings such as `dict`, the type is also a mapping from the same keys, but to the types of the values instead of the actual values. For example, the type of `{1: 2, 3: 4}` is `MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>`. However, unlike ordered containers, `{1: 2, 3: 4}` and `{3: 4, 1: 2}` have equivalent types.\n", "\n", "- For Python objects which implement the `__tf_tracing_type__` method, the type is whatever that method returns\n", "\n", "- For any other Python objects, the type is a generic `TraceType`, and the matching procedure is:\n", "\n", "  - First, it checks whether the object is the same object used in the previous trace (using Python `id()` or `is`). Note that this will still match if the object has changed, so if you use Python objects as `tf.function` arguments it is best to use *immutable* ones.\n", "  - Next, it checks whether the object is equal to the object used in the previous trace (using Python `==`).\n", "\n", "  Note that this procedure only keeps a [weakref](https://docs.python.org/3/library/weakref.html) to the object, and hence only works as long as the object is in scope and has not been deleted.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GNNN4lgRzpIs" }, "source": [ "Note: `TraceType` is based on the `Function` input parameters, so changes to global and free variables alone will not create a new trace. For recommended practices when dealing with Python global and free variables, see [this section](https://docs.python.org/3/reference/executionmodel.html#binding-of-names)." ] }, { "cell_type": "markdown", "metadata": { "id": "PEDwbumO32Wh" }, "source": [ "### Controlling retracing\n", "\n", "Retracing, which is when your `Function` creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your `Function` retraces a new graph for every call, you will find that your code executes more slowly than if you did not use `tf.function`.\n", "\n", "To control the tracing behavior, you can use the following techniques:" ] }, { "cell_type": "markdown", "metadata": { "id": "EUtycWJa34TT" }, "source": [ "#### Pass a fixed `input_signature` to `tf.function`" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "_BDMIRmu1RGB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n", "tf.Tensor([4 1], shape=(2,), dtype=int32)\n", "Caught expected exception \n", " :\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/3657259638.py\", line 9, in \n", " next_collatz(tf.constant([[1, 2], [3, 4]]))\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n", "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/3657259638.py\", line 13, in \n", " next_collatz(tf.constant([1.0, 2.0]))\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n" ] } ], "source": [ "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n", "def next_collatz(x):\n", "  print(\"Tracing with\", x)\n", "  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)\n", "\n", "print(next_collatz(tf.constant([1, 2])))\n", "# You specified a 1-D tensor in the input signature, so this should fail.\n", "with assert_raises(TypeError):\n", "  next_collatz(tf.constant([[1, 2], [3, 4]]))\n", "\n", "# You specified an int32 dtype in the input signature, so this should fail.\n", "with assert_raises(TypeError):\n", "  next_collatz(tf.constant([1.0, 2.0]))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ocxX-HVk7P2o" }, "source": [ "#### Use unknown dimensions for flexibility\n", "\n", "Since TensorFlow matches tensors based on their shape, using a `None` dimension as a wildcard allows `Function`s to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch (see the [Transformer](../tutorials/text/transformer.ipynb) and [Deep Dream](../tutorials/generative/deepdream.ipynb) tutorials for examples)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "4Viun7dh7PmF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n", "tf.Tensor([1 2 3], shape=(3,), dtype=int32)\n", "tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)\n" ] } ], "source": [ "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n", "def g(x):\n", "  print('Tracing with', x)\n", "  return x\n", "\n", "# No retrace!\n", "print(g(tf.constant([1, 2, 3])))\n", "print(g(tf.constant([1, 2, 3, 4, 5])))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "AY5oiQN0XIyA" }, "source": [ "#### Pass tensors instead of Python literals\n", "\n", "Often, Python arguments are used to control hyperparameters and graph construction, for example `num_layers=10`, `training=True`, or `nonlinearity='relu'`. So, if the Python argument changes, it makes sense that the graph has to be retraced.\n", "\n", "However, it is possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "uydzR5JYUU8H" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Retracing occurs for different Python arguments.\n", "Tracing with num_steps = 10\n", "Executing with num_steps = 10\n", "Tracing with num_steps = 20\n", "Executing with num_steps = 20\n", "\n", "Traces are reused for Tensor arguments.\n", "Tracing with num_steps = Tensor(\"num_steps:0\", shape=(), dtype=int32)\n", "Executing with num_steps = 10\n", "Executing with num_steps = 20\n" ] } ], "source": [ "def train_one_step():\n", "  pass\n", "\n", "@tf.function\n", "def train(num_steps):\n", "  print(\"Tracing with num_steps = \", num_steps)\n", "  tf.print(\"Executing with num_steps = \", num_steps)\n", "  for _ in tf.range(num_steps):\n", "    train_one_step()\n", "\n", "print(\"Retracing occurs for different Python arguments.\")\n", "train(num_steps=10)\n", "train(num_steps=20)\n", "\n", "print()\n", "print(\"Traces are reused for Tensor arguments.\")\n", "train(num_steps=tf.constant(10))\n", "train(num_steps=tf.constant(20))" ] }, { "cell_type": "markdown", "metadata": { "id": "4pJqkDR_Q2wz" }, "source": 
[ "If you need to force retracing, create a new `Function`. Separate `Function` objects are guaranteed not to share traces." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "uHp4ousu4DdN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "Executing\n", "Tracing!\n", "Executing\n" ] } ], "source": [ "def f():\n", "  print('Tracing!')\n", "  tf.print('Executing')\n", "\n", "tf.function(f)()\n", "tf.function(f)()" ] }, { "cell_type": "markdown", "metadata": { "id": "-tZoWrA6INvc" }, "source": [ "#### Use the tracing protocol\n", "\n", "Where possible, you should prefer converting the Python type into a `tf.experimental.ExtensionType` instead. Moreover, the `TraceType` of an `ExtensionType` is the `tf.TypeSpec` associated with it. Therefore, if needed, you can simply override the default `tf.TypeSpec` to take control of an `ExtensionType`'s `Tracing Protocol`. Refer to the *Customizing the ExtensionType's TypeSpec* section in the [Extension types](extension_type.ipynb) guide for details.\n", "\n", "Otherwise, for direct control over when `Function` should retrace for a particular Python type, you can implement the `Tracing Protocol` for it yourself." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "gZkIh7UaIKc6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def get_mixed_flavor(fruit_a, fruit_b):\n", "  return fruit_a.flavor + fruit_b.flavor\n", "\n", "class Fruit:\n", "  flavor = tf.constant([0, 0])\n", "\n", "class Apple(Fruit):\n", "  flavor = tf.constant([1, 2])\n", "\n", "class Mango(Fruit):\n", "  flavor = tf.constant([3, 4])\n", "\n", "# As described in the above rules, a generic TraceType for `Apple` and `Mango`\n", "# is generated (and a corresponding ConcreteFunction is traced) but it fails to\n", "# match the second function call since the first pair of Apple() and Mango()\n", "# have gone out of scope by then and been deleted.\n", "get_mixed_flavor(Apple(), Mango())  # Traces a new concrete function\n", "get_mixed_flavor(Apple(), Mango())  # Traces a new concrete function again\n", "\n", "# However, each subclass of the `Fruit` class has a fixed flavor, and you\n", "# can reuse an existing traced concrete function if it was 
the same\n", "# subclass. Avoiding such unnecessary tracing of concrete functions\n", "# can have significant performance benefits.\n", "\n", "class FruitTraceType(tf.types.experimental.TraceType):\n", "  def __init__(self, fruit):\n", "    self.fruit_type = type(fruit)\n", "    self.fruit_value = fruit\n", "\n", "  def is_subtype_of(self, other):\n", "    # True if self subtypes `other` and `other`'s type matches FruitTraceType.\n", "    return (type(other) is FruitTraceType and\n", "            self.fruit_type is other.fruit_type)\n", "\n", "  def most_specific_common_supertype(self, others):\n", "    # `self` is the specific common supertype if all input types match it.\n", "    return self if all(self == other for other in others) else None\n", "\n", "  def placeholder_value(self, placeholder_context=None):\n", "    # Use the fruit itself instead of the type for correct tracing.\n", "    return self.fruit_value\n", "\n", "  def __eq__(self, other):\n", "    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n", "\n", "  def __hash__(self):\n", "    return hash(self.fruit_type)\n", "\n", "class FruitWithTraceType:\n", "\n", "  def __tf_tracing_type__(self, context):\n", "    return FruitTraceType(self)\n", "\n", "class AppleWithTraceType(FruitWithTraceType):\n", "  flavor = tf.constant([1, 2])\n", "\n", "class MangoWithTraceType(FruitWithTraceType):\n", "  flavor = tf.constant([3, 4])\n", "\n", "# Now if you try calling it again:\n", "get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType())  # Traces a new concrete function\n", "get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType())  # Re-uses the traced concrete function" ] }, { "cell_type": "markdown", "metadata": { "id": "96IxS2WR37fF" }, "source": [ "### Obtaining concrete functions\n", "\n", "Every time a function is traced, a new concrete function is created. You can obtain a concrete function directly by using `get_concrete_function`.\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "mHg2CGtPQ3Hz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Obtaining concrete trace\n", "Executing traced function\n", "tf.Tensor(b'aa', shape=(), dtype=string)\n", "tf.Tensor(b'bb', shape=(), dtype=string)\n" ] } ], "source": [ "print(\"Obtaining concrete trace\")\n", "double_strings = double.get_concrete_function(tf.constant(\"a\"))\n", "print(\"Executing traced function\")\n", "print(double_strings(tf.constant(\"a\")))\n", "print(double_strings(a=tf.constant(\"b\")))\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "6IVZ-NVf9vsx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(b'cc', shape=(), dtype=string)\n" ] } ], "source": [ "# You can also call get_concrete_function on a TensorSpec\n", "double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))\n", "print(double_strings_from_inputspec(tf.constant(\"c\")))" ] }, { "cell_type": "markdown", "metadata": { "id": "iR4fVmG34xvF" }, "source": [ "Printing a `ConcreteFunction` displays a summary of its input arguments (with types) and its output type." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "o3-JbkIk41r8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ConcreteFunction Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "print(double_strings)" ] }, { "cell_type": "markdown", "metadata": { "id": "QtqfvljZeuOV" }, "source": [ "You can also directly retrieve a concrete function's signature." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "nzbrqFABe0zG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})\n", "Tensor(\"Identity:0\", shape=(), dtype=string)\n" ] } ], "source": [ "print(double_strings.structured_input_signature)\n", "print(double_strings.structured_outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "lar5A_5m5IG1" }, "source": [ "Using a concrete trace with incompatible types raises an error." ] }, { 
"cell_type": "code", "execution_count": 21, "metadata": { "id": "G5eeTK-T5KYj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py\", line 442, in bind_function_inputs\n", " bound_arguments = function_type.bind_with_defaults(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/core/function/polymorphism/function_type.py\", line 277, in bind_with_defaults\n", " with_default_args[arg_name] = constraint.cast(\n", " ^^^^^^^^^^^^^^^^\n", "TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)\n", "\n", "The above exception was the direct cause of the following exception:\n", "\n", "Traceback (most recent call last):\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1179, in _call_impl\n", " return self._call_with_structured_signature(args, kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1259, in _call_with_structured_signature\n", " function_type_utils.canonicalize_function_inputs(\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. 
Received args: (,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None).\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/3196284684.py\", line 2, in \n", " double_strings(tf.constant(1))\n", "tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_189 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_189]\n" ] } ], "source": [ "with assert_raises(tf.errors.InvalidArgumentError):\n", "  double_strings(tf.constant(1))" ] }, { "cell_type": "markdown", "metadata": { "id": "st2L9VNQVtSG" }, "source": [ "You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "U_QyPSGoaC35" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ConcreteFunction Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=, dtype=tf.float32, name=None)\n", " b (POSITIONAL_OR_KEYWORD): Literal[2]\n", "Output Type:\n", " TensorSpec(shape=, dtype=tf.float32, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "@tf.function\n", "def pow(a, b):\n", "  return a ** b\n", "\n", "square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)\n", "print(square)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "E76vIDhQbXIb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File 
\"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py\", line 442, in bind_function_inputs\n", " bound_arguments = function_type.bind_with_defaults(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/core/function/polymorphism/function_type.py\", line 277, in bind_with_defaults\n", " with_default_args[arg_name] = constraint.cast(\n", " ^^^^^^^^^^^^^^^^\n", "ValueError: Can not cast 3 to Literal[2]\n", "\n", "The above exception was the direct cause of the following exception:\n", "\n", "Traceback (most recent call last):\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1179, in _call_impl\n", " return self._call_with_structured_signature(args, kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1259, in _call_with_structured_signature\n", " function_type_utils.canonicalize_function_inputs(\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. 
Received args: (,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=, dtype=tf.float32, name=None).\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1182, in _call_impl\n", " return self._call_with_flat_signature(args, kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1233, in _call_with_flat_signature\n", " raise TypeError(f\"{self._flat_signature_summary()} got unexpected \"\n", "TypeError: pow(a) got unexpected keyword arguments: b.\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/2310937119.py\", line 4, in \n", " square(tf.constant(10.0), b=3)\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. 
Received args: (,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=, dtype=tf.float32, name=None).\n", "Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.\n" ] } ], "source": [ "assert square(tf.constant(10.0)) == 100\n", "\n", "with assert_raises(TypeError):\n", "  square(tf.constant(10.0), b=3)" ] },
{ "cell_type": "markdown", "metadata": { "id": "41gJh_JGIfuA" }, "source": [ "### Obtaining graphs\n", "\n", "Each concrete function is a callable wrapper around a `tf.Graph`. Although you will rarely need to retrieve the actual `tf.Graph` object, you can obtain it easily from any concrete function." ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "id": "5UENeGHfaX8g" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[] -> a\n", "['a', 'a'] -> add\n", "['add'] -> Identity\n" ] } ], "source": [ "graph = double_strings.graph\n", "for node in graph.as_graph_def().node:\n", "  print(f'{node.input} -> {node.name}')\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "aIKkgr6qdtp4" }, "source": [ "### Debugging\n", "\n", "In general, debugging code is easier in eager mode than inside `tf.function`. Make sure your code executes error-free in eager mode before decorating it with `tf.function`. To assist in debugging, call `tf.config.run_functions_eagerly(True)` to globally disable `tf.function`, and `tf.config.run_functions_eagerly(False)` to reenable it.\n", "\n", "When tracking down issues that only appear within `tf.function`, here are some tips:\n", "\n", "- Plain old Python `print` calls only execute during tracing, helping you track down when your function gets (re)traced.\n", "- `tf.print` calls execute every time, and can help you track down intermediate values during execution.\n", "- `tf.debugging.enable_check_numerics` is an easy way to track down where NaNs and Infs are created.\n", "- `pdb` (the [Python debugger](https://docs.python.org/3/library/pdb.html)) can help you understand what's going on during tracing. (Caveat: `pdb` drops you into AutoGraph-transformed source code, not your original code.)" ] },
{ "cell_type": "markdown", "metadata": { "id": "5f05Vr_YBUCz" }, "source": [ "## AutoGraph transformations\n", "\n", "AutoGraph is a library that is on by default in `tf.function`. It transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, and `while`.\n", "\n", "TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python." ] },
{ "cell_type": "code", "execution_count": 25, "metadata": { "id": "yCQTtTPTW3WF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.10295248 0.722364306 0.462540388 0.685418 0.410427094]\n", "[0.102590263 0.618371665 0.43215242 0.595030427 0.388835251]\n", "[0.102231853 0.549993277 0.407118559 0.53350389 0.370355666]\n", "[0.101877175 0.500515163 0.386023313 0.488054901 0.354302764]\n", "[0.101526156 0.462522238 0.367926896 0.45267123 0.34018591]\n", "[0.101178743 0.432137698 0.352177054 0.424092293 0.327643335]\n", "[0.100834884 0.407106251 0.338304847 0.400372326 0.316401631]\n", "[0.100494511 0.386012852 0.325963169 0.380267531 0.306249619]\n", "[0.100157566 0.367917836 0.314888835 0.362939775 0.297021389]\n", "[0.0998239741 0.352169096 0.304878056 0.347800791 0.288584381]\n", "[0.0994937 0.338297784 0.295770347 0.334423721 0.280831337]\n", "[0.0991666913 0.325956881 0.287437141 0.322490036 0.273674309]\n", "[0.0988428891 0.314883202 0.279774219 0.311756641 0.267040461]\n", "[0.098522231 0.30487296 0.272696108 0.302034318 0.260868818]\n", "[0.0982046872 0.295765668 0.266131759 0.293173164 0.255107969]\n", "[0.0978902 0.287432849 0.260021746 0.285052747 0.249714166]\n", "[0.0975787044 0.279770285 0.254315883 0.277575016 0.244649917]\n", "[0.0972701609 0.272692442 0.248971328 0.270659208 0.239882946]\n", "[0.0969645381 0.266128331 0.24395144 0.264238119 0.235385165]\n", "[0.0966617838 0.260018587 0.239224538 0.258255273 0.23113212]\n", "[0.0963618383 0.254312903 0.234763145 0.252662897 0.227102354]\n", "[0.0960646719 0.248968542 0.230543256 0.247420177 0.223276913]\n", "[0.0957702398 0.243948802 0.226543784 0.242492035 0.219639108]\n", "[0.0954785049 0.23922205 0.222746134 0.237848178 0.216174051]\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A simple loop\n", "\n", "@tf.function\n", "def f(x):\n", "  while tf.reduce_sum(x) > 1:\n", "    tf.print(x)\n", "    x = tf.tanh(x)\n", "  return x\n", "\n", "f(tf.random.uniform([5]))" ] },
{ "cell_type": "markdown", "metadata": { "id": "KxwJ8znPI0Cg" }, "source": [ "If you're curious, you can inspect the code that AutoGraph generates." ] },
{ "cell_type": "code", "execution_count": 26, "metadata": { "id": "jlQD1ffRXJhl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def tf__f(x):\n", "    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n", "        do_return = False\n", "        retval_ = ag__.UndefinedReturnValue()\n", "\n", "        def get_state():\n", "            return (x,)\n", "\n", "        def set_state(vars_):\n", "            nonlocal x\n", "            x, = vars_\n", "\n", "        def loop_body():\n", "            nonlocal x\n", "            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)\n", "            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)\n", "\n", "        def loop_test():\n", "            return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1\n", "        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})\n", "        try:\n", "            do_return = True\n", "            retval_ = ag__.ld(x)\n", "        except:\n", "            do_return = False\n", "            raise\n", "        return fscope.ret(retval_, do_return)\n", "\n" ] } ], "source": [ "print(tf.autograph.to_code(f.python_function))" ] },
{ "cell_type": "markdown", "metadata": { "id": "xgKmkrNTZSyz" }, "source": [ "### Conditionals\n", "\n", "AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.\n", "\n", "A Python conditional executes during tracing, so exactly one branch of the conditional is added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.\n", "\n", "`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more information." ] },
{ "cell_type": "code", "execution_count": 27, "metadata": { "id": "BOQl8PMq2Sf3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing for loop\n", "Tracing fizzbuzz branch\n", "Tracing fizz branch\n", "Tracing buzz branch\n", "Tracing default branch\n", "1\n", "2\n", "fizz\n", "4\n", "buzz\n", "1\n", "2\n", "fizz\n", "4\n", "buzz\n", "fizz\n", "7\n", "8\n", "fizz\n", "buzz\n", "11\n", "fizz\n", "13\n", "14\n", "fizzbuzz\n", "16\n", "17\n", "fizz\n", "19\n", "buzz\n" ] } ], "source": [ "@tf.function\n", "def fizzbuzz(n):\n", "  for i in tf.range(1, n + 1):\n", "    print('Tracing for loop')\n", "    if i % 15 == 0:\n", "      print('Tracing fizzbuzz branch')\n", "      tf.print('fizzbuzz')\n", "    elif i % 3 == 0:\n", "      print('Tracing fizz branch')\n", "      tf.print('fizz')\n", "    elif i % 5 == 0:\n", "      print('Tracing buzz branch')\n", "      tf.print('buzz')\n", "    else:\n", "      print('Tracing default branch')\n", "      tf.print(i)\n", "\n", "fizzbuzz(tf.constant(5))\n", "fizzbuzz(tf.constant(20))" ] },
{ "cell_type": "markdown", "metadata": { "id": "4rBO5AQ15HVC" }, "source": [ "See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements) for additional restrictions on AutoGraph-converted `if` statements." ] },
{ "cell_type": "markdown", "metadata": { "id": "yho4J0a0ZkQS" }, "source": [ "### Loops\n", "\n", "AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow looping ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.\n", "\n", "This substitution is made in the following situations:\n", "\n", "- `for x in y`: if `y` is a Tensor, convert to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops is generated.\n", "- `while <condition>`: if `<condition>` is a Tensor, convert to `tf.while_loop`.\n", "\n", "A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.\n", "\n", "A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated `tf.Graph`.\n", "\n", "See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements) for additional restrictions on AutoGraph-converted `for` and `while` statements." ] },
{ "cell_type": "markdown", "metadata": { "id": "sp4rbIdfbM6s" }, "source": [ "#### Looping over Python data\n", "\n", "A common pitfall is to loop over Python/NumPy data within a `tf.function`. This loop executes during tracing, so every iteration adds another copy of your model to the `tf.Graph`.\n", "\n", "If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop." ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "id": "WGZ19LspbZ27" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph\n", "train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph\n", "train(<_FlatMapDataset element_spec=(TensorSpec(shape=, dtype=tf.int32, name=None), TensorSpec(shape=, dtype=tf.int32, name=None))>) contains 6 nodes in its graph\n", "train(<_FlatMapDataset element_spec=(TensorSpec(shape=, dtype=tf.int32, name=None), TensorSpec(shape=, dtype=tf.int32, name=None))>) contains 6 nodes in its graph\n" ] } ], "source": [ "def measure_graph_size(f, *args):\n", "  g = f.get_concrete_function(*args).graph\n", "  print(\"{}({}) contains {} nodes in its graph\".format(\n", "      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n", "\n", "@tf.function\n", "def train(dataset):\n", "  loss = tf.constant(0)\n", "  for x, y in dataset:\n", "    loss += tf.abs(y - x) # Some dummy computation.\n", "  return loss\n", "\n", "small_data = [(1, 1)] * 3\n", "big_data = [(1, 1)] * 10\n", "measure_graph_size(train, small_data)\n", "measure_graph_size(train, big_data)\n", "\n", "measure_graph_size(train, tf.data.Dataset.from_generator(\n", "    lambda: small_data, (tf.int32, tf.int32)))\n", "measure_graph_size(train, tf.data.Dataset.from_generator(\n", "    lambda: big_data, (tf.int32, tf.int32)))" ] },
{ "cell_type": "markdown", "metadata": { "id": "JeD2U-yrbfVb" }, "source": [ "When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensors`. The former keeps the data in Python and fetches it via `tf.py_function`, which can have performance implications, whereas the latter bundles a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.\n", "\n", "Reading data from files via `TFRecordDataset`, `CsvDataset`, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the
[`tf.data`: Build TensorFlow input pipelines](../../guide/data) guide." ] },
{ "cell_type": "markdown", "metadata": { "id": "hyksHW9TCukR" }, "source": [ "#### Accumulating values in a loop\n", "\n", "A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, because these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop." ] },
{ "cell_type": "code", "execution_count": 29, "metadata": { "id": "HJ3Vb3dXfefN" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 2\n", "seq_len = 3\n", "feature_size = 4\n", "\n", "def rnn_step(inp, state):\n", "  return inp + state\n", "\n", "@tf.function\n", "def dynamic_rnn(rnn_step, input_data, initial_state):\n", "  # [batch, time, features] -> [time, batch, features]\n", "  input_data = tf.transpose(input_data, [1, 0, 2])\n", "  max_seq_len = input_data.shape[0]\n", "\n", "  states = tf.TensorArray(tf.float32, size=max_seq_len)\n", "  state = initial_state\n", "  for i in tf.range(max_seq_len):\n", "    state = rnn_step(input_data[i], state)\n", "    states = states.write(i, state)\n", "  return tf.transpose(states.stack(), [1, 0, 2])\n", "\n", "dynamic_rnn(rnn_step,\n", "            tf.random.uniform([batch_size, seq_len, feature_size]),\n", "            tf.zeros([batch_size, feature_size]))" ] },
{ "cell_type": "markdown", "metadata": { "id": "i2MVoIVaNApG" }, "source": [ "## Limitations\n", "\n", "TensorFlow `Function` has a few limitations by design that you should be aware of when converting a Python function to a `Function`." ] },
{ "cell_type": "markdown", "metadata": { "id": "EJqHGFSVLIKl" }, "source": [ "### Executing Python side effects\n", "\n", "Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a `Function`, sometimes executing twice or not at all. They only happen the first time you call a `Function` with a set of inputs. Afterwards, the traced `tf.Graph` is reexecuted, without executing the Python code.\n", "\n", "The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like `tf.data`, `tf.print`, `tf.summary`, `tf.Variable.assign`, and `tf.TensorArray` are the best way to ensure your code is executed by the TensorFlow runtime with each call." ] },
{ "cell_type": "code", "execution_count": 30, "metadata": { "id": "w2sACuZ9TTRk" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Traced with 1\n", "Executed with 1\n", "Executed with 1\n", "Traced with 2\n", "Executed with 2\n" ] } ], "source": [ "@tf.function\n", "def f(x):\n", "  print(\"Traced with\", x)\n", "  tf.print(\"Executed with\", x)\n", "\n", "f(1)\n", "f(1)\n", "f(2)\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "e1I0dPiqTV8H" }, "source": [ "If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawbacks of `tf.py_function` are that it is not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors." ] },
{ "cell_type": "markdown", "metadata": { "id": "bOW1v9WVKGgH" }, "source": [ "#### Changing Python global and free variables\n", "\n", "Changing Python global and [free variables](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) counts as a Python side effect, so it only happens during tracing.\n" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": { "id": "7aJD--9qTWmg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python side effect\n" ] } ], "source": [ "external_list = []\n", "\n", "@tf.function\n", "def side_effect(x):\n", "  print('Python side effect')\n", "  external_list.append(x)\n", "\n", "side_effect(1)\n", "side_effect(1)\n", "side_effect(1)\n", "# The list append only happened once!\n", "assert len(external_list) == 1" ] },
{ "cell_type": "markdown", "metadata": { "id": "5eZTFRv_k_nR" }, "source": [ "Sometimes unexpected behaviors are very hard to notice. In the example below, the `counter` is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the `tf.function` is used, the `assign_add` is recorded unconditionally in the underlying graph. Therefore `v` increases by 1 every time the `tf.function` is called. This issue is common among users who try to migrate their graph-mode TensorFlow code to TensorFlow 2 using `tf.function` decorators, when Python side effects (the `counter` in the example) are used to determine what ops to run (`assign_add` in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (for example, if the guarded operation is very costly)." ] },
{ "cell_type": "code", "execution_count": 32, "metadata": { "id": "5r6p7-9jk_3L" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "2\n", "3\n" ] } ], "source": [ "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.v = tf.Variable(0)\n", "    self.counter = 0\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    if self.counter == 0:\n", "      # A python side-effect\n", "      self.counter += 1\n", "      self.v.assign_add(1)\n", "\n", "    return self.v\n", "\n", "m = Model()\n", "for n in range(3):\n", "  print(m().numpy()) # prints 1, 2, 3" ] },
{ "cell_type": "markdown", "metadata": { "id": "tXCTcHoVcxhX" }, "source": [ "A workaround to achieve the expected behavior is using [`tf.init_scope`](https://tensorflow.google.cn/api_docs/python/tf/init_scope) to lift the operations outside of the function graph. This ensures that the variable increment is only done once, during tracing. Note that `init_scope` has other side effects, including cleared control flow and gradient tape. Sometimes the usage of `init_scope` can become too complex to manage realistically." ] },
{ "cell_type": "code", "execution_count": 33, "metadata": { "id": "An4MrIbrcvi8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "1\n", "1\n" ] } ], "source": [ "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.v = tf.Variable(0)\n", "    self.counter = 0\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    if self.counter == 0:\n", "      # Lifts ops out of function-building graphs\n", "      with tf.init_scope():\n", "        self.counter += 1\n", "        self.v.assign_add(1)\n", "\n", "    return self.v\n", "\n", "m = Model()\n", "for n in range(3):\n", "  print(m().numpy()) # prints 1, 1, 1" ] },
{ "cell_type": "markdown", "metadata": { "id": "pbFG5CX4LwQA" }, "source": [ "In summary, as a rule of thumb, avoid mutating Python objects, such as integers or containers like lists, that live outside the `Function`. Instead, use arguments and TF objects. For example, the [Accumulating values in a loop](#accumulating_values_in_a_loop) section shows one way to implement list-like operations.\n", "\n", "In some cases, you can capture and manipulate state if it is a [`tf.Variable`](https://tensorflow.google.cn/guide/variable). This is how the weights of Keras models are updated with repeated calls to the same `ConcreteFunction`." ] },
{ "cell_type": "markdown", "metadata": { "id": "X_oNNGrAqPJ1" }, "source": [ "#### Using Python iterators and generators" ] },
{ "cell_type": "markdown", "metadata": { "id": "msTmv-oyUNaf" }, "source": [ "Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing." ] },
{ "cell_type": "code", "execution_count": 34, "metadata": { "id": "FNPD4unZUedH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n", "Value: 1\n", "Value: 1\n" ] } ], "source": [ "@tf.function\n", "def buggy_consume_next(iterator):\n", "  tf.print(\"Value:\", next(iterator))\n", "\n", "iterator = iter([1, 2, 3])\n", "buggy_consume_next(iterator)\n", "# This reuses the first value from the iterator, rather than consuming the next value.\n", "buggy_consume_next(iterator)\n", "buggy_consume_next(iterator)\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "wcS3TAgCjTWR" }, "source": [ "Just as TensorFlow has a specialized `tf.TensorArray` for list constructs, it has a specialized `tf.data.Iterator` for iteration constructs. See the section on [AutoGraph transformations](#autograph_transformations) for an overview. The [`tf.data`](https://tensorflow.google.cn/guide/data) API can also help implement generator patterns:\n" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": { "id": "8D_iKetXW6VE" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n", "Value: 2\n", "Value: 3\n" ] } ], "source": [ "@tf.function\n", "def good_consume_next(iterator):\n", "  # This is ok, iterator is a tf.data.Iterator\n", "  tf.print(\"Value:\", next(iterator))\n", "\n", "ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])\n", "iterator = iter(ds)\n", "good_consume_next(iterator)\n", "good_consume_next(iterator)\n", "good_consume_next(iterator)" ] },
{ "cell_type": "markdown", "metadata": { "id": "i8YAMYb6KEh4" }, "source": [ "### All outputs of a tf.function must be return values\n", "\n", "With the exception of `tf.Variable`s, a `tf.function` must return all its outputs. Attempting to directly access any tensors from a function without going through return values causes \"leaks\".\n", "\n", "For example, the function below \"leaks\" the tensor `a` through the Python global `x`:" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": { "id": "zrdp4rjxg6jo" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n", "'SymbolicTensor' object has no attribute 'numpy'\n" ] } ], "source": [ "x = None\n", "\n", "@tf.function\n", "def leaky_function(a):\n", "  global x\n", "  x = a + 1 # Bad - leaks local tensor\n", "  return a + 2\n", "\n", "correct_a = leaky_function(tf.constant(1))\n", "\n", "print(correct_a.numpy()) # Good -
value obtained from function's returns\n", "try:\n", "  x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n", "except AttributeError as expected:\n", "  print(expected)" ] },
{ "cell_type": "markdown", "metadata": { "id": "-d4_J_DC5rxX" }, "source": [ "This is true even if the leaked value is also returned:" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": { "id": "PrcpPB8C5s9T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2\n", "'SymbolicTensor' object has no attribute 'numpy'\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/566849597.py\", line 21, in \n", " captures_leaked_tensor(tf.constant(2))\n", "TypeError: is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.\n", "Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.\n", "\n", " was defined here:\n", " File \"\", line 198, in _run_module_as_main\n", " File \"\", line 88, in _run_code\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel_launcher.py\", line 17, in \n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelapp.py\", line 701, in start\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tornado/platform/asyncio.py\", line 205, in start\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/asyncio/base_events.py\", line 639, in run_forever\n", " File
\"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/asyncio/base_events.py\", line 1985, in _run_once\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/asyncio/events.py\", line 88, in _run\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3075, in run_cell\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3130, in _run_cell\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/async_helpers.py\", line 128, in _pseudo_sync_runner\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3334, in run_cell_async\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3517, in run_ast_nodes\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3577, in run_code\n", " File \"/tmp/ipykernel_4126223/566849597.py\", line 7, 
in \n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 833, in __call__\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 889, in _call\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 696, in _initialize\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 178, in trace_function\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 283, in _maybe_define_function\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 310, in _create_concrete_function\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/func_graph.py\", line 1059, in func_graph_from_py_func\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 599, in wrapped_fn\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py\", line 41, in autograph_handler\n", " File \"/tmp/ipykernel_4126223/566849597.py\", line 4, in leaky_function\n", " File 
\"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/override_binary_operator.py\", line 113, in binary_op_wrapper\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/ops/tensor_math_operator_overrides.py\", line 28, in _add_dispatch_factory\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/util/dispatch.py\", line 1260, in op_dispatch_handler\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/ops/math_ops.py\", line 1701, in _add_dispatch\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/ops/gen_math_ops.py\", line 490, in add_v2\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/op_def_library.py\", line 796, in _apply_op_helper\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/func_graph.py\", line 670, in _create_op_internal\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/ops.py\", line 2682, in _create_op_internal\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tensorflow/python/framework/ops.py\", line 1177, in from_node_def\n", "\n", "The tensor cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=140693262518592), which is out of scope.\n" ] } ], "source": [ "@tf.function\n", "def leaky_function(a):\n", " global x\n", " x = a 
+ 1 # Bad - leaks local tensor\n", " return x # Good - uses local tensor\n", "\n", "correct_a = leaky_function(tf.constant(1))\n", "\n", "print(correct_a.numpy()) # Good - value obtained from function's returns\n", "try:\n", " x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n", "except AttributeError as expected:\n", " print(expected)\n", "\n", "@tf.function\n", "def captures_leaked_tensor(b):\n", " b += x # Bad - `x` is leaked from `leaky_function`\n", " return b\n", "\n", "with assert_raises(TypeError):\n", " captures_leaked_tensor(tf.constant(2))" ] },
{ "cell_type": "markdown", "metadata": { "id": "Sm2ghjyy50D4" }, "source": [ "Usually, leaks such as these occur when you use Python statements or data structures. In addition to leaking inaccessible tensors, such statements are also likely wrong because they count as Python side effects, and are not guaranteed to execute at every function call.\n", "\n", "Common ways to leak local tensors also include mutating an external Python collection or object:" ] },
{ "cell_type": "code", "execution_count": 38, "metadata": { "id": "D7bLe8y652wU" }, "outputs": [], "source": [ "class MyClass:\n", "\n", "  def __init__(self):\n", "    self.field = None\n", "\n", "external_list = []\n", "external_object = MyClass()\n", "\n", "def leaky_function():\n", "  a = tf.constant(1)\n", "  external_list.append(a) # Bad - leaks tensor\n", "  external_object.field = a # Bad - leaks tensor" ] },
{ "cell_type": "markdown", "metadata": { "id": "g-XVQcD-wf5K" }, "source": [ "### Recursive tf.functions are not supported\n", "\n", "Recursive `Function`s are not supported and could cause infinite loops. For example:" ] },
{ "cell_type": "code", "execution_count": 39, "metadata": { "id": "QSN-T1m5EFcR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 9, in \n", " recursive_fn(tf.constant(5)) # Bad - maximum recursion error.\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "tensorflow.python.autograph.impl.api.StagingError: in user code:\n",
"\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 
0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 
0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 
0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", " File \"/tmp/ipykernel_4126223/2233998312.py\", line 3, in recursive_fn *\n", " if n > 0:\n", "\n", " RecursionError: maximum recursion depth exceeded\n", "\n" ] } ], "source": [ "@tf.function\n", "def recursive_fn(n):\n", " if n > 0:\n", " return recursive_fn(n - 1)\n", " 
else:\n", "    return 1\n", "\n", "with assert_raises(Exception):\n", "  recursive_fn(tf.constant(5)) # Bad - maximum recursion error." ] }, { "cell_type": "markdown", "metadata": { "id": "LyRyooKGUxNV" }, "source": [ "Even if a recursive `Function` seems to work, the Python function is traced multiple times, which can hurt performance. For example:" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "id": "7FlmTqfMUwmT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tracing\n", "tracing\n", "tracing\n", "tracing\n", "tracing\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def recursive_fn(n):\n", "  if n > 0:\n", "    print('tracing')\n", "    return recursive_fn(n - 1)\n", "  else:\n", "    return 1\n", "\n", "recursive_fn(5) # Warning - multiple tracings" ] }, { "cell_type": "markdown", "metadata": { "id": "-D6nh3QirXAd" }, "source": [ "## Known issues\n", "\n", "If your `Function` does not evaluate correctly, one of the following known issues, all of which are planned to be fixed in the future, may explain the problem." ] }, { "cell_type": "markdown", "metadata": { "id": "ZoPg5w1Pjqna" }, "source": [ "### Depending on Python global and free variables\n", "\n", "`Function` creates a new `ConcreteFunction` when called with a new value of a Python argument. However, it does not do so for the Python closure, globals, or nonlocals of that `Function`. If their values change between calls to the `Function`, the `Function` still uses the values they had when it was traced. This differs from how regular Python functions work.\n", "\n", "For that reason, follow a functional programming style that passes values as arguments instead of closing over outer names." ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "id": "oeJMdXd3M0cM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Buggy: tf.Tensor(2, shape=(), dtype=int32)\n", "Correct: tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def buggy_add():\n", "  return 1 + foo\n", "\n", "@tf.function\n", "def recommended_add(foo):\n", "  return 1 + foo\n", "\n", "foo = 1\n", "print(\"Buggy:\", buggy_add())\n", "print(\"Correct:\", recommended_add(foo))" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "id": "L3q7sUJWZOSU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updating the 
value of `foo` to 100!\n", "Buggy: tf.Tensor(2, shape=(), dtype=int32)\n", "Correct: tf.Tensor(101, shape=(), dtype=int32)\n" ] } ], "source": [ "print(\"Updating the value of `foo` to 100!\")\n", "foo = 100\n", "print(\"Buggy:\", buggy_add()) # Did not change!\n", "print(\"Correct:\", recommended_add(foo))" ] }, { "cell_type": "markdown", "metadata": { "id": "ZoPg5w1Pjqnb" }, "source": [ "Another way to update a global value is to make it a `tf.Variable` and use the `Variable.assign` method instead.\n" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "id": "oeJMdXd3M0cc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable: tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def variable_add():\n", "  return 1 + foo\n", "\n", "foo = tf.Variable(1)\n", "print(\"Variable:\", variable_add())\n" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "id": "L3q7sUJWZOSd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updating the value of `foo` to 100!\n", "Variable: tf.Tensor(101, shape=(), dtype=int32)\n" ] } ], "source": [ "print(\"Updating the value of `foo` to 100!\")\n", "foo.assign(100)\n", "print(\"Variable:\", variable_add())" ] }, { "cell_type": "markdown", "metadata": { "id": "hvwe9gTIWfx6" }, "source": [ "### Depending on Python objects" ] }, { "cell_type": "markdown", "metadata": { "id": "BJkZS-SwPvOQ" }, "source": [ "Passing custom Python objects as arguments to `tf.function` is supported, but with certain limitations.\n", "\n", "For maximum feature coverage, consider converting the objects into [extension types](extension_type.ipynb) before passing them to `tf.function`. You can also use Python primitives and `tf.nest`-compatible structures.\n", "\n", "However, as covered in the [rules of tracing](#rules_of_tracing), when the custom Python class does not provide a custom `TraceType`, `tf.function` is forced to use instance-based equality. This means it will **not create a new trace** when you pass the **same object with modified attributes**." ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "id": "ux8KJESVWDxX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "class SimpleModel(tf.Module):\n", "  def __init__(self):\n", "    # These 
values are *not* tf.Variables.\n", "    self.bias = 0.\n", "    self.weight = 2.\n", "\n", "@tf.function\n", "def evaluate(model, x):\n", "  return model.weight * x + model.bias\n", "\n", "simple_model = SimpleModel()\n", "x = tf.constant(10.)\n", "print(evaluate(simple_model, x))" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "id": "mUxRF4ghZZvX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "simple_model.bias += 5.0\n", "print(evaluate(simple_model, x)) # Didn't change :(" ] }, { "cell_type": "markdown", "metadata": { "id": "Ytcgg2qFWaBF" }, "source": [ "Using the same `Function` to evaluate the modified instance of the model gives stale results, because the instance still has the [same instance-based TraceType](#rules_of_tracing) as the original model.\n", "\n", "For that reason, write your `Function` so that it does not depend on mutable object attributes, or implement the [tracing protocol](#use_the_tracing_protocol) for your objects to inform `Function` about such attributes.\n", "\n", "If that is not possible, one workaround is to create a new `Function` each time you modify your object, which forces retracing:" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "id": "pFvWmWAAQjrv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "def evaluate(model, x):\n", "  return model.weight * x + model.bias\n", "\n", "new_model = SimpleModel()\n", "evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n", "# Don't pass in `new_model`, `Function` already captured its state during tracing.\n", "print(evaluate_no_bias(x))" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "id": "bdU2-jF4ZH0B" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(25.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "new_model.bias += 5.0\n", "# Create new Function and ConcreteFunction since you modified new_model.\n", "evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n", "print(evaluate_with_bias(x)) # Don't pass in 
`new_model`." ] }, { "cell_type": "markdown", "metadata": { "id": "uFgEZClsZrEi" }, "source": [ "[Retracing can be expensive](https://tensorflow.google.cn/guide/intro_to_graphs#tracing_and_performance), so you can instead use `tf.Variable`s as object attributes. These can be mutated (but not reassigned, careful!) to achieve a similar effect without retracing.\n" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "id": "daAP_lucwS6w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "class BetterModel:\n", "\n", "  def __init__(self):\n", "    self.bias = tf.Variable(0.)\n", "    self.weight = tf.Variable(2.)\n", "\n", "@tf.function\n", "def evaluate(model, x):\n", "  return model.weight * x + model.bias\n", "\n", "better_model = BetterModel()\n", "print(evaluate(better_model, x))\n" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "id": "ktqwMJBqwTFj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(25.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5\n", "print(evaluate(better_model, x)) # This works!" 
] }, { "cell_type": "markdown", "metadata": { "id": "lPr_6mK_AQWL" }, "source": [ "### Creating tf.Variables\n", "\n", "`Function` only supports singleton `tf.Variable`s that are created once, on the first call, and reused across subsequent function calls. The code snippet below creates a new `tf.Variable` on every function call, which raises a `ValueError` exception.\n", "\n", "Example:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "id": "Tx0Vvnb_9OB-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/3018268426.py\", line 7, in \n", " f(1.0)\n", "ValueError: in user code:\n", "\n", " File \"/tmp/ipykernel_4126223/3018268426.py\", line 3, in f *\n", " v = tf.Variable(1.0)\n", "\n", " ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. 
See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n", "\n" ] } ], "source": [ "@tf.function\n", "def f(x):\n", "  v = tf.Variable(1.0)\n", "  return v\n", "\n", "with assert_raises(ValueError):\n", "  f(1.0)" ] }, { "cell_type": "markdown", "metadata": { "id": "KYm6-5GCILXQ" }, "source": [ "A common pattern for working around this limitation is to start with a Python `None` value, then conditionally create the `tf.Variable` if the value is `None`:" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "id": "HQrG5_kOiKl_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(1, shape=(), dtype=int32)\n", "tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "class Count(tf.Module):\n", "  def __init__(self):\n", "    self.count = None\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    if self.count is None:\n", "      self.count = tf.Variable(0)\n", "    return self.count.assign_add(1)\n", "\n", "c = Count()\n", "print(c())\n", "print(c())" ] }, { "cell_type": "markdown", "metadata": { "id": "7uD6qI7aJwbR" }, "source": [ "#### Using with multiple Keras optimizers\n", "\n", "You may encounter `ValueError: tf.function only supports singleton tf.Variables created on the first call.` when using more than one Keras optimizer with a `tf.function`. This error occurs because optimizers internally create `tf.Variable`s when they apply gradients for the first time." ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "id": "yWQ3-r99Jvze" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calling `train_step` with different optimizer...\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmp/ipykernel_4126223/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmp/ipykernel_4126223/950644149.py\", line 18, in \n", " train_step(w, x, y, opt2)\n", "ValueError: in user code:\n", "\n", " File \"/tmp/ipykernel_4126223/950644149.py\", line 9, in train_step *\n", " optimizer.apply_gradients(zip(gradients, [w]))\n", " File 
\"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py\", line 291, in apply_gradients **\n", " self.apply(grads, trainable_variables)\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py\", line 330, in apply\n", " self.build(trainable_variables)\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/adam.py\", line 97, in build\n", " self.add_variable_from_reference(\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/backend/tensorflow/optimizer.py\", line 36, in add_variable_from_reference\n", " return super().add_variable_from_reference(\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py\", line 227, in add_variable_from_reference\n", " return self.add_variable(\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/optimizers/base_optimizer.py\", line 201, in add_variable\n", " variable = backend.Variable(\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/backend/common/variables.py\", line 163, in __init__\n", " self._initialize_with_initializer(initializer)\n", " File \"/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/keras/src/backend/tensorflow/core.py\", line 40, in _initialize_with_initializer\n", " self._value = tf.Variable(\n", "\n", " ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. 
See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n", "\n" ] } ], "source": [ "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n", "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n", "\n", "@tf.function\n", "def train_step(w, x, y, optimizer):\n", "  with tf.GradientTape() as tape:\n", "    L = tf.reduce_sum(tf.square(w*x - y))\n", "  gradients = tape.gradient(L, [w])\n", "  optimizer.apply_gradients(zip(gradients, [w]))\n", "\n", "w = tf.Variable(2.)\n", "x = tf.constant([-1.])\n", "y = tf.constant([2.])\n", "\n", "train_step(w, x, y, opt1)\n", "print(\"Calling `train_step` with different optimizer...\")\n", "with assert_raises(ValueError):\n", "  train_step(w, x, y, opt2)" ] }, { "cell_type": "markdown", "metadata": { "id": "7Q8BRPCThTjB" }, "source": [ "If you need to change the optimizer during training, a workaround is to create a new `Function` for each optimizer, calling the [`ConcreteFunction`](#obtaining_concrete_functions) directly." ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "id": "YV5F2Gy9hSI3" }, "outputs": [], "source": [ "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n", "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n", "\n", "# Not a tf.function.\n", "def train_step(w, x, y, optimizer):\n", "  with tf.GradientTape() as tape:\n", "    L = tf.reduce_sum(tf.square(w*x - y))\n", "  gradients = tape.gradient(L, [w])\n", "  optimizer.apply_gradients(zip(gradients, [w]))\n", "\n", "w = tf.Variable(2.)\n", "x = tf.constant([-1.])\n", "y = tf.constant([2.])\n", "\n", "# Make a new Function and ConcreteFunction for each optimizer.\n", "train_step_1 = tf.function(train_step)\n", "train_step_2 = tf.function(train_step)\n", "for i in range(10):\n", "  if i % 2 == 0:\n", "    train_step_1(w, x, y, opt1)\n", "  else:\n", "    train_step_2(w, x, y, opt2)" ] }, { "cell_type": "markdown", "metadata": { "id": "Xjnz5CcuqQac" }, "source": [ "#### Using with multiple Keras models\n", "\n", "When passing different model instances to the same `Function`, you may also encounter `ValueError: tf.function only supports singleton tf.Variables created on the 
first call.`.\n", "\n", "This error occurs because Keras models (which [do not have their input shape defined](https://tensorflow.google.cn/guide/keras/custom_layers_and_models#best_practice_deferring_weight_creation_until_the_shape_of_the_inputs_is_known)) and Keras layers create `tf.Variable`s when they are first called. You may be attempting to initialize those variables inside a `Function` that has already been called. To avoid this error, try calling `model.build(input_shape)` to initialize all the weights before training the model.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "IKyrEY5GVX3M" }, "source": [ "## Further reading\n", "\n", "To learn how to export and load a `Function`, see the [SavedModel guide](https://render.githubusercontent.com/guide/saved_model). To learn more about the graph optimizations performed after tracing, see the [Grappler guide](https://render.githubusercontent.com/guide/graph_optimization). To learn how to optimize your data pipeline and profile your model's performance, see the [Profiler guide](https://render.githubusercontent.com/guide/profiler.md)." ] } ], "metadata": { "colab": { "name": "function.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "xxx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }