{ "cells": [ { "cell_type": "markdown", "metadata": { "cellView": "form", "id": "yOYx6tzSnWQ3" }, "source": [ "````{admonition} Copyright 2020 The TensorFlow Authors.\n", "```\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "```\n", "````" ] }, { "cell_type": "markdown", "metadata": { "id": "6xgB0Oz5eGSQ" }, "source": [ "# 计算图和 `tf.function` 简介" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import set_env\n", "!export TF_FORCE_GPU_ALLOW_GROWTH=true" ] }, { "cell_type": "markdown", "metadata": { "id": "w4zzZVZtQb1w" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看在 Google Colab 中运行 在 GitHub 上查看源代码\n", "下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "RBKqnXI9GOax" }, "source": [ "## 概述\n", "\n", "本指南深入探究 TensorFlow 和 Keras 以演示 TensorFlow 的工作原理。如果您想立即开始使用 Keras,请查看 [Keras 指南合集](https://tensorflow.google.cn/guide/keras/)。\n", "\n", "在本指南中,您将了解 TensorFlow 如何让您对代码进行简单的更改来获取计算图、计算图的存储和表示方式以及如何使用它们来加速您的模型。\n", "\n", "注:对于那些只熟悉 TensorFlow 1.x 的用户来说,本指南演示了迥然不同的计算图视图。\n", "\n", "**这是一个整体概述,涵盖了 `tf.function` 如何让您从 Eager Execution 切换到计算图执行。**有关 `tf.function` 的更完整规范,请转到使用 `tf.function` 提高性能指南。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "v0DdlfacAdTZ" }, "source": [ "### 什么是计算图?\n", "\n", "在前三篇指南中,您以 **Eager** 模式运行了 TensorFlow。这意味着 TensorFlow 运算由 Python 逐个执行,然后将结果返回给 Python。\n", "\n", "虽然 Eager Execution 具有多个独特的优势,但计算图执行在 Python 外部实现了可移植性,并且往往提供更出色的性能。**计算图执行**意味着张量计算作为 *TensorFlow 计算图*执行,后者有时被称为 `tf.Graph` 或简称为“计算图”。\n", "\n", "**计算图是包含一组 `tf.Operation` 对象(表示计算单元)和 `tf.Tensor` 对象(表示在运算之间流动的数据单元)的数据结构。**计算图在 `tf.Graph` 上下文中定义。由于这些计算图是数据结构,无需原始 Python 代码即可保存、运行和恢复它们。\n", "\n", "下面是一个表示两层神经网络的 TensorFlow 计算图在 TensorBoard 中呈现后的样子:" ] }, { "cell_type": "markdown", "metadata": { "id": "FvQ5aBuRGT1o" }, "source": [ "\"一个简单的 " ] }, { "cell_type": "markdown", "metadata": { "id": "DHpY3avXGITP" }, "source": [ "### 计算图的优点\n", "\n", "使用计算图,您将拥有极大的灵活性。您可以在移动应用、嵌入式设备和后端服务器等没有 Python 解释器的环境中使用 TensorFlow 计算图。当 TensorFlow 从 Python 导出计算图时,它会将这些计算图用作[已保存模型](./saved_model.ipynb)的格式。\n", "\n", "计算图的优化也十分轻松,允许编译器进行如下转换:\n", "\n", "- 通过在计算中折叠常量节点来静态推断张量的值*(“常量折叠”)*。\n", "- 分离独立的计算子部分,并在线程或设备之间进行拆分。\n", "- 通过消除通用子表达式来简化算术运算。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "o1x1EOD9GjnB" }, "source": [ "有一个完整的优化系统 [Grappler](./graph_optimization.ipynb) 来执行这种加速和其他加速。\n", "\n", "简而言之,计算图极为有用,它可以使 TensorFlow **快速**运行、**并行**运行以及**在多个设备上**高效运行。\n", "\n", "但是,为了方便起见,您仍然需要在 Python 中定义我们的机器学习模型(或其他计算),然后在需要时自动构造计算图。" ] }, { "cell_type": "markdown", "metadata": { "id": "k-6Qi0thw2i9" }, "source": [ "## 安装" ] }, { "cell_type": "markdown", "metadata": { "id": "0d1689fa928f" }, "source": [ "导入一些所需的库:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "goZwOXp_xyQj" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import timeit\n", "from datetime import datetime" ] }, { "cell_type": "markdown", "metadata": { "id": "pSZebVuWxDXu" }, "source": [ "## 利用计算图\n", "\n", "您可以使用 `tf.function` 在 TensorFlow 中创建和运行计算图,要么作为直接调用,要么作为装饰器。`tf.function` 将一个常规函数作为输入并返回一个 `Function`。`Function` 是一个 Python 可调用对象,它通过 Python 函数构建 TensorFlow 计算图。您可以按照与其 Python 等价函数相同的方式使用 `Function`。\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "HKbLeJ1y0Umi" }, "outputs": [], "source": [ "# Define a Python function.\n", "def a_regular_function(x, y, b):\n", " x = tf.matmul(x, y)\n", " x = x + b\n", " return x\n", "\n", "# `a_function_that_uses_a_graph` is a TensorFlow `Function`.\n", "a_function_that_uses_a_graph = tf.function(a_regular_function)\n", "\n", "# Make some tensors.\n", "x1 = tf.constant([[1.0, 2.0]])\n", "y1 = tf.constant([[2.0], [3.0]])\n", "b1 = tf.constant(4.0)\n", "\n", "orig_value = a_regular_function(x1, y1, b1).numpy()\n", "# Call a `Function` like a Python function.\n", "tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()\n", "assert(orig_value == tf_function_value)" ] }, { "cell_type": "markdown", "metadata": { "id": "PNvuAYpdrTOf" }, "source": [ "在外部,一个 `Function` 看起来就像您使用 TensorFlow 运算编写的常规函数​​。然而,[在底层](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/def_function.py),它*迥然不同*。一个 `Function` **在一个 API 后面封装了多个 tf.Graph**(在*多态性*部分了解详情)。这就是 `Function` 能够为您提供计算图执行的好处(例如速度和可部署性,请参阅上面的*计算图的优点*)。" ] }, { "cell_type": "markdown", "metadata": { "id": "MT7U8ozok0gV" }, "source": [ "`tf.function` 适用于一个函数*及其调用的所有其他函数*:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "rpz08iLplm9F" }, "outputs": [ { "data": { "text/plain": [ "array([[12.]], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def inner_function(x, y, b):\n", " x = tf.matmul(x, y)\n", " x = x + b\n", " return x\n", "\n", "# Use the decorator to make `outer_function` a `Function`.\n", "@tf.function\n", "def outer_function(x):\n", " y = tf.constant([[2.0], [3.0]])\n", " b = tf.constant(4.0)\n", "\n", " return inner_function(x, y, b)\n", "\n", "# Note that the callable will create a graph that\n", "# includes `inner_function` as well as `outer_function`.\n", "outer_function(tf.constant([[1.0, 2.0]])).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "P88fOr88qgCj" }, "source": [ "如果您使用过 TensorFlow 1.x,会发现根本不需要定义 `Placeholder` 或 `tf.Sesssion`。" ] }, { "cell_type": "markdown", "metadata": { "id": "wfeKf0Nr1OEK" }, "source": [ "### 将 Python 函数转换为计算图\n", "\n", "您使用 TensorFlow 编写的任何函数都将包含内置 TF 运算和 Python 逻辑的混合,例如 `if-then` 子句、循环、`break`、`return`、`continue` 等。虽然 TensorFlow 运算很容易被 `tf.Graph` 捕获,但特定于 Python 的逻辑需要经过额外的步骤才能成为计算图的一部分。`tf.function` 使用称为 AutoGraph (`tf.autograph`) 的库将 Python 代码转换为计算图生成代码。\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "PFObpff1BMEb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First branch, with graph: 1\n", "Second branch, with graph: 0\n" ] } ], "source": [ "def simple_relu(x):\n", " if tf.greater(x, 0):\n", " return x\n", " else:\n", " return 0\n", "\n", "# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.\n", "tf_simple_relu = tf.function(simple_relu)\n", "\n", "print(\"First branch, with graph:\", tf_simple_relu(tf.constant(1)).numpy())\n", "print(\"Second branch, with graph:\", tf_simple_relu(tf.constant(-1)).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "hO4DBUNZBMwQ" }, "source": [ "虽然您不太可能需要直接查看计算图,但您可以检查输出以验证确切的结果。这些结果都不太容易阅读,因此不需要看得太仔细!" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "lAKaat3w0gnn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def tf__simple_relu(x):\n", " with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n", " do_return = False\n", " retval_ = ag__.UndefinedReturnValue()\n", "\n", " def get_state():\n", " return (do_return, retval_)\n", "\n", " def set_state(vars_):\n", " nonlocal retval_, do_return\n", " do_return, retval_ = vars_\n", "\n", " def if_body():\n", " nonlocal retval_, do_return\n", " try:\n", " do_return = True\n", " retval_ = ag__.ld(x)\n", " except:\n", " do_return = False\n", " raise\n", "\n", " def else_body():\n", " nonlocal retval_, do_return\n", " try:\n", " do_return = True\n", " retval_ = 0\n", " except:\n", " do_return = False\n", " raise\n", " ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)\n", " return fscope.ret(retval_, do_return)\n", "\n" ] } ], "source": [ "# This is the graph-generating output of AutoGraph.\n", "print(tf.autograph.to_code(simple_relu))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "8x6RAqza1UWf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "node {\n", " name: \"x\"\n", " op: \"Placeholder\"\n", " attr {\n", " key: \"shape\"\n", " value {\n", " shape {\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " attr {\n", " key: \"_user_specified_name\"\n", " value {\n", " s: \"x\"\n", " }\n", " }\n", "}\n", "node {\n", " name: \"Greater/y\"\n", " op: \"Const\"\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_INT32\n", " tensor_shape {\n", " }\n", " int_val: 0\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", "}\n", "node {\n", " name: \"Greater\"\n", " op: \"Greater\"\n", " input: \"x\"\n", " input: \"Greater/y\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", "}\n", "node {\n", " name: \"cond\"\n", " op: \"StatelessIf\"\n", " input: \"Greater\"\n", " input: \"x\"\n", " attr {\n", " key: \"then_branch\"\n", " value {\n", " func {\n", " name: \"cond_true_30\"\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"output_shapes\"\n", " value {\n", " list {\n", " shape {\n", " }\n", " shape {\n", " }\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"else_branch\"\n", " value {\n", " func {\n", " name: \"cond_false_31\"\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"_read_only_resource_inputs\"\n", " value {\n", " list {\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"_lower_using_switch_merge\"\n", " value {\n", " b: true\n", " }\n", " }\n", " attr {\n", " key: \"Tout\"\n", " value {\n", " list {\n", " type: DT_BOOL\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"Tin\"\n", " value {\n", " list {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"Tcond\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", "}\n", "node {\n", " name: \"cond/Identity\"\n", " op: \"Identity\"\n", " input: \"cond\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", "}\n", "node {\n", " name: \"cond/Identity_1\"\n", " op: \"Identity\"\n", " input: \"cond:1\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", "}\n", "node {\n", " name: \"Identity\"\n", " op: \"Identity\"\n", " input: \"cond/Identity_1\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", "}\n", "versions {\n", " producer: 1882\n", " min_consumer: 12\n", "}\n", "library {\n", " function {\n", " signature {\n", " name: \"cond_false_31\"\n", " input_arg {\n", " name: \"cond_placeholder\"\n", " type: DT_INT32\n", " }\n", " output_arg {\n", " name: \"cond_identity\"\n", " type: DT_BOOL\n", " }\n", " output_arg {\n", " name: \"cond_identity_1\"\n", " type: DT_INT32\n", " }\n", " }\n", " attr {\n", " key: \"_construction_context\"\n", " value {\n", " s: \"kEagerRuntime\"\n", " }\n", " }\n", " arg_attr {\n", " key: 0\n", " value {\n", " attr {\n", " key: \"_output_shapes\"\n", " value {\n", " list {\n", " shape {\n", " }\n", " }\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const\"\n", " op: \"Const\"\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_1\"\n", " op: \"Const\"\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_2\"\n", " op: \"Const\"\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_INT32\n", " tensor_shape {\n", " }\n", " int_val: 0\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_3\"\n", " op: \"Const\"\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity\"\n", " op: \"Identity\"\n", " input: \"cond/Const_3:output:0\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_4\"\n", " op: \"Const\"\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_INT32\n", " tensor_shape {\n", " }\n", " int_val: 0\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity_1\"\n", " op: \"Identity\"\n", " input: \"cond/Const_4:output:0\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " ret {\n", " key: \"cond_identity\"\n", " value: \"cond/Identity:output:0\"\n", " }\n", " ret {\n", " key: \"cond_identity_1\"\n", " value: \"cond/Identity_1:output:0\"\n", " }\n", " }\n", " function {\n", " signature {\n", " name: \"cond_true_30\"\n", " input_arg {\n", " name: \"cond_identity_1_x\"\n", " type: DT_INT32\n", " }\n", " output_arg {\n", " name: \"cond_identity\"\n", " type: DT_BOOL\n", " }\n", " output_arg {\n", " name: \"cond_identity_1\"\n", " type: DT_INT32\n", " }\n", " }\n", " attr {\n", " key: \"_construction_context\"\n", " value {\n", " s: \"kEagerRuntime\"\n", " }\n", " }\n", " arg_attr {\n", " key: 0\n", " value {\n", " attr {\n", " key: \"_user_specified_name\"\n", " value {\n", " s: \"x\"\n", " }\n", " }\n", " attr {\n", " key: \"_output_shapes\"\n", " value {\n", " list {\n", " shape {\n", " }\n", " }\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const\"\n", " op: \"Const\"\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity\"\n", " op: \"Identity\"\n", " input: \"cond/Const:output:0\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity_1\"\n", " op: \"Identity\"\n", " input: \"cond_identity_1_x\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " ret {\n", " key: \"cond_identity\"\n", " value: \"cond/Identity:output:0\"\n", " }\n", " ret {\n", " key: \"cond_identity_1\"\n", " value: \"cond/Identity_1:output:0\"\n", " }\n", " }\n", "}\n", "\n" ] } ], "source": [ "# This is the graph itself.\n", "print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())" ] }, { "cell_type": "markdown", "metadata": { "id": "GZ4Ieg6tBE6l" }, "source": [ "大多数情况下,`tf.function` 无需特殊考虑即可工作。但是,有一些注意事项,`tf.function` 指南以及[完整的 AutoGraph 参考](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md)可以提供帮助。" ] }, { "cell_type": "markdown", "metadata": { "id": "sIpc_jfjEZEg" }, "source": [ "### 多态性:一个 `Function`,多个计算图\n", "\n", "`tf.Graph` 专门用于特定类型的输入(例如,具有特定 [`dtype`](https://tensorflow.google.cn/api_docs/python/tf/dtypes/DType) 的张量或具有相同 [`id()`](https://docs.python.org/3/library/functions.html#id%5D) 的对象)。\n", "\n", "每次使用一组无法由现有的任何计算图处理的参数(例如具有新 `dtypes` 或不兼容形状的参数)调用 `Function` 时,`Function` 都会创建一个专门用于这些新参数的新 `tf.Graph`。`tf.Graph` 输入的类型规范被称为它的**输入签名**或**签名**。如需详细了解何时生成新的 `tf.Graph` 以及如何控制它,请转到[使用 `tf.function` 提高性能](./function.ipynb)指南的*回溯规则*部分。\n", "\n", "`Function` 在 `ConcreteFunction` 中存储与该签名对应的 `tf.Graph`。`ConcreteFunction` 是围绕 `tf.Graph` 的封装容器。\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "LOASwhbvIv_T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(5.5, shape=(), dtype=float32)\n", "tf.Tensor([1. 0.], shape=(2,), dtype=float32)\n", "tf.Tensor([3. 0.], shape=(2,), dtype=float32)\n" ] } ], "source": [ "@tf.function\n", "def my_relu(x):\n", " return tf.maximum(0., x)\n", "\n", "# `my_relu` creates new graphs as it observes more signatures.\n", "print(my_relu(tf.constant(5.5)))\n", "print(my_relu([1, -1]))\n", "print(my_relu(tf.constant([3., -3.])))" ] }, { "cell_type": "markdown", "metadata": { "id": "1qRtw7R4KL9X" }, "source": [ "如果已经使用该签名调用了 `Function`,则该 `Function` 不会创建新的 `tf.Graph`。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "TjjbnL5OKNDP" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.0, shape=(), dtype=float32)\n", "tf.Tensor([0. 1.], shape=(2,), dtype=float32)\n" ] } ], "source": [ "# These two calls do *not* create new graphs.\n", "print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.\n", "print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`." ] }, { "cell_type": "markdown", "metadata": { "id": "UohRmexhIpvQ" }, "source": [ "由于它由多个计算图提供支持,因此 `Function` 是**多态的**。这样,它便能够支持比单个 `tf.Graph` 可以表示的更多的输入类型,并优化每个 `tf.Graph` 来获得更出色的性能。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "dxzqebDYFmLy" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input Parameters:\n", " x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " x (POSITIONAL_OR_KEYWORD): List[Literal[1], Literal[-1]]\n", "Output Type:\n", " TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.\n", "# The `ConcreteFunction` also knows the return type and shape!\n", "print(my_relu.pretty_printed_concrete_signatures())" ] }, { "cell_type": "markdown", "metadata": { "id": "V11zkxU22XeD" }, "source": [ "## 使用 `tf.function`\n", "\n", "到目前为止,您已经学习了如何使用 `tf.function` 作为装饰器或包装容器将 Python 函数简单地转换为计算图。但在实践中,让 `tf.function` 正常工作可能相当棘手!在下面的部分中,您将了解如何使用 `tf.function` 使代码按预期工作。" ] }, { "cell_type": "markdown", "metadata": { "id": "yp_n0B5-P0RU" }, "source": [ "### 计算图执行与 Eager Execution\n", "\n", "`Function` 函数中的代码既能以 Eager 模式执行,也可以作为计算图执行。默认情况下,`Function` 将其代码作为计算图执行:\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "_R0BOvBFxqVZ" }, "outputs": [], "source": [ "@tf.function\n", "def get_MSE(y_true, y_pred):\n", " sq_diff = tf.pow(y_true - y_pred, 2)\n", " return tf.reduce_mean(sq_diff)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "zikMVPGhmDET" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([4 7 4 9 0], shape=(5,), dtype=int32)\n", "tf.Tensor([7 4 1 6 3], shape=(5,), dtype=int32)\n" ] } ], "source": [ "y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n", "y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n", "print(y_true)\n", "print(y_pred)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "07r08Dh158ft" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_MSE(y_true, y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "cyZNCRcQorGO" }, "source": [ "要验证 `Function` 计算图是否与其等效 Python 函数执行相同的计算,您可以使用 `tf.config.run_functions_eagerly(True)` 使其以 Eager 模式执行。这是一个开关,用于关闭 `Function` 创建和运行计算图的能力,无需正常执行代码。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "lKoF6NjPoI8w" }, "outputs": [], "source": [ "tf.config.run_functions_eagerly(True)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "9ZLqTyn0oKeM" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_MSE(y_true, y_pred)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "cV7daQW9odn-" }, "outputs": [], "source": [ "# Don't forget to set it back when you are done.\n", "tf.config.run_functions_eagerly(False)" ] }, { "cell_type": "markdown", "metadata": { "id": "DKT3YBsqy0x4" }, "source": [ "但是,`Function` 在计算图执行和 Eager Execution 下的行为可能有所不同。Python [`print`](https://docs.python.org/3/library/functions.html#print) 函数是这两种模式不同之处的一个示例。我们看看当您将 `print` 语句插入到您的函数并重复调用它时会发生什么。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "BEJeVeBEoGjV" }, "outputs": [], "source": [ "@tf.function\n", "def get_MSE(y_true, y_pred):\n", " print(\"Calculating MSE!\")\n", " sq_diff = tf.pow(y_true - y_pred, 2)\n", " return tf.reduce_mean(sq_diff)" ] }, { "cell_type": "markdown", "metadata": { "id": "3sWTGwX3BzP1" }, "source": [ "观察打印的内容:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "3rJIeBg72T9n" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calculating MSE!\n" ] } ], "source": [ "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "WLMXk1uxKQ44" }, "source": [ "输出很令人惊讶?尽管 **`get_MSE` 被调用了 *3* 次,但它只打印了一次。**\n", "\n", "解释一下,`print` 语句在 `Function` 运行原始代码时执行,以便在称为“跟踪”(请参阅 [`tf.function` 指南](./function.ipynb)的*跟踪*部分)的过程中创建计算图。跟踪将 TensorFlow 运算捕获到计算图中,而计算图中未捕获 `print`。随后对全部三个调用执行该计算图,**而没有再次运行 Python 代码**。\n", "\n", "作为健全性检查,我们关闭计算图执行来比较:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "oFSxRtcptYpe" }, "outputs": [], "source": [ "# Now, globally set everything to run eagerly to force eager execution.\n", "tf.config.run_functions_eagerly(True)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "qYxrAtvzNgHR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calculating MSE!\n", "Calculating MSE!\n", "Calculating MSE!\n" ] } ], "source": [ "# Observe what is printed below.\n", "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "_Df6ynXcAaup" }, "outputs": [], "source": [ "tf.config.run_functions_eagerly(False)" ] }, { "cell_type": "markdown", "metadata": { "id": "PUR7qC_bquCn" }, "source": [ "`print` 是 *Python 的副作用*,在将函数转换为 `Function` 时,您还应注意其他差异。请在[使用 `tf.function` 提升性能](./function.ipynb)指南中的*限制*部分中了解详情。" ] }, { "cell_type": "markdown", "metadata": { "id": "oTZJfV_tccVp" }, "source": [ "注:如果您想同时在 Eager Execution 和计算图执行中打印值,请改用 `tf.print`。" ] }, { "cell_type": "markdown", "metadata": { "id": "rMT_Xf5yKn9o" }, "source": [ "### 非严格执行\n", "\n", "\n", "\n", "计算图执行仅执行产生可观察效果所需的运算,这包括:\n", "\n", "- 函数的返回值\n", "- 已记录的著名副作用,例如:\n", " - 输入/输出运算,如 `tf.print`\n", " - 调试运算,如 `tf.debugging` 中的断言函数\n", " - `tf.Variable` 的突变\n", "\n", "这种行为通常称为“非严格执行”,与 Eager Execution 不同,后者会分步执行所有程序运算,无论是否需要。\n", "\n", "特别是,运行时错误检查不计为可观察效果。如果一个运算因为不必要而被跳过,它不会引发任何运行时错误。\n", "\n", "在下面的示例中,计算图执行期间跳过了“不必要的”运算 `tf.gather`,因此不会像在 Eager Execution 中那样引发运行时错误 `InvalidArgumentError`。切勿依赖执行计算图时引发的错误。" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "OdN0nKlUwj7M" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([0.], shape=(1,), dtype=float32)\n" ] } ], "source": [ "def unused_return_eager(x):\n", " # Get index 1 will fail when `len(x) == 1`\n", " tf.gather(x, [1]) # unused \n", " return x\n", "\n", "try:\n", " print(unused_return_eager(tf.constant([0.0])))\n", "except tf.errors.InvalidArgumentError as e:\n", " # All operations are run during eager execution so an error is raised.\n", " print(f'{type(e).__name__}: {e}')" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "d80Fob4MwhTs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([0.], shape=(1,), dtype=float32)\n" ] } ], "source": [ "@tf.function\n", "def unused_return_graph(x):\n", " tf.gather(x, [1]) # unused\n", " return x\n", "\n", "# Only needed operations are run during graph execution. The error is not raised.\n", "print(unused_return_graph(tf.constant([0.0])))" ] }, { "cell_type": "markdown", "metadata": { "id": "def6MupG9R0O" }, "source": [ "### `tf.function` 最佳做法\n", "\n", "可能需要花一些时间来习惯 `Function` 的行为。为了快速上手,初次使用的用户应当使用 `@tf.function` 来装饰简单函数,以获得从 Eager Execution 转换到计算图执行的经验。\n", "\n", "*为 `tf.function` 设计*可能是您编写与计算图兼容的 TensorFlow 程序的最佳选择。以下是一些提示:\n", "\n", "- 尽早并经常使用 `tf.config.run_functions_eagerly` 在 Eager Execution 和计算图执行之间切换,以查明两种模式是否/何时出现分歧。\n", "- 在 Python 函数外部创建 `tf.Variable` 并在内部修改它们。对使用 `tf.Variable` 的对象也如此操作,例如 `tf.keras.layers`、`tf.keras.Model` 和 `tf.keras.optimizers`。\n", "- 避免编写依赖于外部 Python 变量的函数,不包括 `tf.Variable` 和 Keras 对象。请在 [tf.function 指南](./function.ipynb)的依赖于 Python 全局变量和自由变量中了解详情。\n", "- 尽可能编写以张量和其他 TensorFlow 类型作为输入的函数。您可以传入其他对象类型,但务必小心!请在 [tf.function 指南](./function.ipynb)的依赖于 Python 对象中了解详情。\n", "- 在 `tf.function` 下包含尽可能多的计算以最大程度提高性能收益。例如,装饰整个训练步骤或整个训练循环。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ViM3oBJVJrDx" }, "source": [ "## 见证加速" ] }, { "cell_type": "markdown", "metadata": { "id": "A6NHDp7vAKcJ" }, "source": [ "`tf.function` 通常可以提高代码的性能,但加速的程度取决于您运行的计算种类。小型计算可能以调用计算图的开销为主。您可以按如下方式衡量性能上的差异:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "jr7p1BBjauPK" }, "outputs": [], "source": [ "x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)\n", "\n", "def power(x, y):\n", " result = tf.eye(10, dtype=tf.dtypes.int32)\n", " for _ in range(y):\n", " result = tf.matmul(x, result)\n", " return result" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "ms2yJyAnUYxK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eager execution: 8.420402663061395 seconds\n" ] } ], "source": [ "print(\"Eager execution:\", timeit.timeit(lambda: power(x, 100), number=1000), \"seconds\")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "gUB2mTyRYRAe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Graph execution: 1.5151319501455873 seconds\n" ] } ], "source": [ "power_as_graph = tf.function(power)\n", "print(\"Graph execution:\", timeit.timeit(lambda: power_as_graph(x, 100), number=1000), \"seconds\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Q1Pfo5YwwILi" }, "source": [ "`tf.function` 通常用于加速训练循环,您可以在使用 Keras [从头开始编写训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch)指南的使用 tf.function 加速训练步骤部分中了解详情。\n", "\n", "注:您也可以尝试 tf.function(jit_compile=True) 以获得更显著的性能提升,特别是当您的代码非常依赖于 TF 控制流并且使用许多小张量时。请在 [XLA 概述](https://tensorflow.google.cn/xla)的使用 tf.function(jit_compile=True) 显式编译部分中了解详情。" ] }, { "cell_type": "markdown", "metadata": { "id": "sm0bNFp8PX53" }, "source": [ "### 性能和权衡\n", "\n", "计算图可以加速您的代码,但创建它们的过程有一些开销。对于某些函数,计算图的创建比计算图的执行花费更长的时间。**这种投资通常会随着后续执行的性能提升而迅速得到回报,但重要的是要注意,由于跟踪的原因,任何大型模型训练的前几步可能会较慢。**\n", "\n", "无论您的模型有多大,您都应该避免频繁跟踪。[tf.function 指南](./function.ipynb)在*控制重新跟踪*部分探讨了如何设置输入规范并使用张量参数来避免重新跟踪。如果您发现自己的性能异常糟糕,最好检查一下是否发生了意外重新跟踪。" ] }, { "cell_type": "markdown", "metadata": { "id": "F4InDaTjwmBA" }, "source": [ "## `Function` 何时进行跟踪?\n", "\n", "要确定您的 `Function` 何时进行跟踪,请在其代码中添加一条 `print` 语句。根据经验,`Function` 将在每次跟踪时执行该 `print` 语句。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "hXtwlbpofLgW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "tf.Tensor(6, shape=(), dtype=int32)\n", "tf.Tensor(11, shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def a_function_with_python_side_effect(x):\n", " print(\"Tracing!\") # An eager-only side effect.\n", " return x * x + tf.constant(2)\n", "\n", "# This is traced the first time.\n", "print(a_function_with_python_side_effect(tf.constant(2)))\n", "# The second time through, you won't see the side effect.\n", "print(a_function_with_python_side_effect(tf.constant(3)))" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "inzSg8yzfNjl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "tf.Tensor(6, shape=(), dtype=int32)\n", "Tracing!\n", "tf.Tensor(11, shape=(), dtype=int32)\n" ] } ], "source": [ "# This retraces each time the Python argument changes,\n", "# as a Python argument could be an epoch count or other\n", "# hyperparameter.\n", "print(a_function_with_python_side_effect(2))\n", "print(a_function_with_python_side_effect(3))" ] }, { "cell_type": "markdown", "metadata": { "id": "rtN8NW6AfKye" }, "source": [ "新的 Python 参数总是会触发新计算图的创建,因此需要额外的跟踪。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "D1kbr5ocpS6R" }, "source": [ "## 后续步骤\n", "\n", "您可以在 API 参考页面上详细了解 `tf.function`,并遵循使用 `tf.function` 提升性能指南。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "intro_to_graphs.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "xxx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }