{ "cells": [ { "cell_type": "markdown", "metadata": { "cellView": "form", "id": "yOYx6tzSnWQ3" }, "source": [ "````{admonition} Copyright 2020 The TensorFlow Authors.\n", "```\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "```\n", "````" ] }, { "cell_type": "markdown", "metadata": { "id": "6xgB0Oz5eGSQ" }, "source": [ "# Introduction to graphs and `tf.function`" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import set_env\n", "!export TF_FORCE_GPU_ALLOW_GROWTH=true" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import timeit" ] }, { "cell_type": "markdown", "metadata": { "id": "w4zzZVZtQb1w" }, "source": [ "A `Function` encapsulates several `tf.Graph`s behind one API **(learn more in the *Polymorphism* section)**. That is how a `Function` is able to give you the benefits of graph execution, like speed and deployability (refer to *The benefits of graphs* above)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MT7U8ozok0gV"
},
"source": [
    "`tf.function` applies to a function *and all other functions it calls*:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "rpz08iLplm9F"
},
"outputs": [
{
"data": {
"text/plain": [
"array([[12.]], dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def inner_function(x, y, b):\n",
" x = tf.matmul(x, y)\n",
" x = x + b\n",
" return x\n",
"\n",
"# Use the decorator to make `outer_function` a `Function`.\n",
"@tf.function\n",
"def outer_function(x):\n",
" y = tf.constant([[2.0], [3.0]])\n",
" b = tf.constant(4.0)\n",
"\n",
" return inner_function(x, y, b)\n",
"\n",
"# Note that the callable will create a graph that\n",
"# includes `inner_function` as well as `outer_function`.\n",
"outer_function(tf.constant([[1.0, 2.0]])).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P88fOr88qgCj"
},
"source": [
    "If you have used TensorFlow 1.x, you will notice that at no point did you need to define a `Placeholder` or `tf.Session`."
]
},
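 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "For contrast, here is a rough sketch of what the same matmul would have looked like in TF 1.x style, written via the `tf.compat.v1` compatibility module (for illustration only; `tf.function` makes the placeholder and session bookkeeping unnecessary):\n",
   "\n",
   "```python\n",
   "import tensorflow as tf\n",
   "\n",
   "# Build a graph by hand, TF1-style, via the compatibility module.\n",
   "g = tf.Graph()\n",
   "with g.as_default():\n",
   "    x = tf.compat.v1.placeholder(tf.float32, shape=[1, 2])\n",
   "    y = tf.constant([[2.0], [3.0]])\n",
   "    out = tf.matmul(x, y)\n",
   "\n",
   "# Explicitly run the graph in a session, feeding the placeholder.\n",
   "with tf.compat.v1.Session(graph=g) as sess:\n",
   "    result = sess.run(out, feed_dict={x: [[1.0, 2.0]]})\n",
   "print(result)\n",
   "```"
  ]
 },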
{
"cell_type": "markdown",
"metadata": {
"id": "wfeKf0Nr1OEK"
},
"source": [
    "### Converting Python functions to graphs\n",
    "\n",
    "Any function you write with TensorFlow will contain a mixture of built-in TF operations and Python logic, such as `if-then` clauses, loops, `break`, `return`, `continue`, and more. While TensorFlow operations are easily captured by a `tf.Graph`, Python-specific logic needs to undergo an extra step in order to become part of the graph. `tf.function` uses a library called AutoGraph (`tf.autograph`) to convert Python code into graph-generating code.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "PFObpff1BMEb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First branch, with graph: 1\n",
"Second branch, with graph: 0\n"
]
}
],
"source": [
"def simple_relu(x):\n",
" if tf.greater(x, 0):\n",
" return x\n",
" else:\n",
" return 0\n",
"\n",
"# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.\n",
"tf_simple_relu = tf.function(simple_relu)\n",
"\n",
"print(\"First branch, with graph:\", tf_simple_relu(tf.constant(1)).numpy())\n",
"print(\"Second branch, with graph:\", tf_simple_relu(tf.constant(-1)).numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hO4DBUNZBMwQ"
},
"source": [
    "Though it is unlikely that you will need to view graphs directly, you can inspect the outputs to check the exact results. These are not easy to read, so no need to look too carefully!"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "lAKaat3w0gnn"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def tf__simple_relu(x):\n",
" with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n",
" do_return = False\n",
" retval_ = ag__.UndefinedReturnValue()\n",
"\n",
" def get_state():\n",
" return (do_return, retval_)\n",
"\n",
" def set_state(vars_):\n",
" nonlocal retval_, do_return\n",
" do_return, retval_ = vars_\n",
"\n",
" def if_body():\n",
" nonlocal retval_, do_return\n",
" try:\n",
" do_return = True\n",
" retval_ = ag__.ld(x)\n",
" except:\n",
" do_return = False\n",
" raise\n",
"\n",
" def else_body():\n",
" nonlocal retval_, do_return\n",
" try:\n",
" do_return = True\n",
" retval_ = 0\n",
" except:\n",
" do_return = False\n",
" raise\n",
" ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)\n",
" return fscope.ret(retval_, do_return)\n",
"\n"
]
}
],
"source": [
"# This is the graph-generating output of AutoGraph.\n",
"print(tf.autograph.to_code(simple_relu))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "8x6RAqza1UWf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"node {\n",
" name: \"x\"\n",
" op: \"Placeholder\"\n",
" attr {\n",
" key: \"shape\"\n",
" value {\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_user_specified_name\"\n",
" value {\n",
" s: \"x\"\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"Greater/y\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_INT32\n",
" tensor_shape {\n",
" }\n",
" int_val: 0\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"Greater\"\n",
" op: \"Greater\"\n",
" input: \"x\"\n",
" input: \"Greater/y\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"cond\"\n",
" op: \"StatelessIf\"\n",
" input: \"Greater\"\n",
" input: \"x\"\n",
" attr {\n",
" key: \"then_branch\"\n",
" value {\n",
" func {\n",
" name: \"cond_true_30\"\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" }\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"else_branch\"\n",
" value {\n",
" func {\n",
" name: \"cond_false_31\"\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_read_only_resource_inputs\"\n",
" value {\n",
" list {\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_lower_using_switch_merge\"\n",
" value {\n",
" b: true\n",
" }\n",
" }\n",
" attr {\n",
" key: \"Tout\"\n",
" value {\n",
" list {\n",
" type: DT_BOOL\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"Tin\"\n",
" value {\n",
" list {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"Tcond\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"cond/Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"cond/Identity_1\"\n",
" op: \"Identity\"\n",
" input: \"cond:1\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond/Identity_1\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
"}\n",
"versions {\n",
" producer: 1882\n",
" min_consumer: 12\n",
"}\n",
"library {\n",
" function {\n",
" signature {\n",
" name: \"cond_false_31\"\n",
" input_arg {\n",
" name: \"cond_placeholder\"\n",
" type: DT_INT32\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity\"\n",
" type: DT_BOOL\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity_1\"\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_construction_context\"\n",
" value {\n",
" s: \"kEagerRuntime\"\n",
" }\n",
" }\n",
" arg_attr {\n",
" key: 0\n",
" value {\n",
" attr {\n",
" key: \"_output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_1\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_2\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_INT32\n",
" tensor_shape {\n",
" }\n",
" int_val: 0\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_3\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond/Const_3:output:0\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_4\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_INT32\n",
" tensor_shape {\n",
" }\n",
" int_val: 0\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity_1\"\n",
" op: \"Identity\"\n",
" input: \"cond/Const_4:output:0\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" ret {\n",
" key: \"cond_identity\"\n",
" value: \"cond/Identity:output:0\"\n",
" }\n",
" ret {\n",
" key: \"cond_identity_1\"\n",
" value: \"cond/Identity_1:output:0\"\n",
" }\n",
" }\n",
" function {\n",
" signature {\n",
" name: \"cond_true_30\"\n",
" input_arg {\n",
" name: \"cond_identity_1_x\"\n",
" type: DT_INT32\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity\"\n",
" type: DT_BOOL\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity_1\"\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_construction_context\"\n",
" value {\n",
" s: \"kEagerRuntime\"\n",
" }\n",
" }\n",
" arg_attr {\n",
" key: 0\n",
" value {\n",
" attr {\n",
" key: \"_user_specified_name\"\n",
" value {\n",
" s: \"x\"\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond/Const:output:0\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity_1\"\n",
" op: \"Identity\"\n",
" input: \"cond_identity_1_x\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" ret {\n",
" key: \"cond_identity\"\n",
" value: \"cond/Identity:output:0\"\n",
" }\n",
" ret {\n",
" key: \"cond_identity_1\"\n",
" value: \"cond/Identity_1:output:0\"\n",
" }\n",
" }\n",
"}\n",
"\n"
]
}
],
"source": [
"# This is the graph itself.\n",
"print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZ4Ieg6tBE6l"
},
"source": [
    "Most of the time, `tf.function` will work without special considerations. However, there are some caveats, and the `tf.function` guide can help here, as well as the [complete AutoGraph reference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sIpc_jfjEZEg"
},
"source": [
    "### Polymorphism: one `Function`, many graphs\n",
    "\n",
    "A `tf.Graph` is specialized to a specific type of inputs (for example, tensors with a specific [`dtype`](https://tensorflow.google.cn/api_docs/python/tf/dtypes/DType) or objects with the same [`id()`](https://docs.python.org/3/library/functions.html#id)).\n",
    "\n",
    "Each time you invoke a `Function` with a set of arguments that can't be handled by any of its existing graphs (such as arguments with new `dtypes` or incompatible shapes), it creates a new `tf.Graph` specialized to those new arguments. The type specification of a `tf.Graph`'s inputs is known as its **input signature** or just a **signature**. For more information regarding when a new `tf.Graph` is generated and how that can be controlled, go to the *Rules of tracing* section of the [Better performance with `tf.function`](./function.ipynb) guide.\n",
    "\n",
    "The `Function` stores the `tf.Graph` corresponding to that signature in a `ConcreteFunction`. A `ConcreteFunction` is a wrapper around a `tf.Graph`.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "LOASwhbvIv_T"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(5.5, shape=(), dtype=float32)\n",
"tf.Tensor([1. 0.], shape=(2,), dtype=float32)\n",
"tf.Tensor([3. 0.], shape=(2,), dtype=float32)\n"
]
}
],
"source": [
"@tf.function\n",
"def my_relu(x):\n",
" return tf.maximum(0., x)\n",
"\n",
"# `my_relu` creates new graphs as it observes more signatures.\n",
"print(my_relu(tf.constant(5.5)))\n",
"print(my_relu([1, -1]))\n",
"print(my_relu(tf.constant([3., -3.])))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1qRtw7R4KL9X"
},
"source": [
    "If the `Function` has already been called with that signature, it does not create a new `tf.Graph`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "TjjbnL5OKNDP"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(0.0, shape=(), dtype=float32)\n",
"tf.Tensor([0. 1.], shape=(2,), dtype=float32)\n"
]
}
],
"source": [
"# These two calls do *not* create new graphs.\n",
"print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.\n",
"print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UohRmexhIpvQ"
},
"source": [
    "Because it's backed by multiple graphs, a `Function` is **polymorphic**. That enables it to support more input types than a single `tf.Graph` could represent, and to optimize each `tf.Graph` for better performance."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "dxzqebDYFmLy"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input Parameters:\n",
" x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(), dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n",
"\n",
"Input Parameters:\n",
" x (POSITIONAL_OR_KEYWORD): List[Literal[1], Literal[-1]]\n",
"Output Type:\n",
" TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n",
"\n",
"Input Parameters:\n",
" x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n"
]
}
],
"source": [
"# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.\n",
"# The `ConcreteFunction` also knows the return type and shape!\n",
"print(my_relu.pretty_printed_concrete_signatures())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V11zkxU22XeD"
},
"source": [
    "## Using `tf.function`\n",
    "\n",
    "So far, you've learned how to convert a Python function into a graph simply by using `tf.function` as a decorator or wrapper. But in practice, getting `tf.function` to work correctly can be tricky! In the following sections, you'll learn how you can make your code work as expected with `tf.function`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yp_n0B5-P0RU"
},
"source": [
    "### Graph execution vs. eager execution\n",
    "\n",
    "The code in a `Function` can be executed both eagerly and as a graph. By default, `Function` executes its code as a graph:\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "_R0BOvBFxqVZ"
},
"outputs": [],
"source": [
"@tf.function\n",
"def get_MSE(y_true, y_pred):\n",
" sq_diff = tf.pow(y_true - y_pred, 2)\n",
" return tf.reduce_mean(sq_diff)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "zikMVPGhmDET"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([4 7 4 9 0], shape=(5,), dtype=int32)\n",
"tf.Tensor([7 4 1 6 3], shape=(5,), dtype=int32)\n"
]
}
],
"source": [
"y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n",
"y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n",
"print(y_true)\n",
"print(y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "07r08Dh158ft"
},
   "outputs": [],
   "source": [
    "get_MSE(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When designing `Function`s, keep these tips in mind:\n",
    "\n",
    "- Avoid writing functions that depend on outer Python variables. Learn more in the *Depending on Python global and free variables* section of the [`tf.function` guide](./function.ipynb).\n",
    "- Prefer writing functions that take tensors and other TensorFlow types as input. You can pass in other object types, but be careful! Learn more in the *Depending on Python objects* section of the [`tf.function` guide](./function.ipynb).\n",
    "- Include as much computation as possible under a `tf.function` to maximize the performance gain. For example, decorate a whole training step or an entire training loop.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ViM3oBJVJrDx"
},
"source": [
    "## Seeing the speed-up"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A6NHDp7vAKcJ"
},
"source": [
    "`tf.function` usually improves the performance of your code, but the amount of speed-up depends on the kind of computation you run. Small computations can be dominated by the overhead of calling a graph. You can measure the difference in performance like so:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "jr7p1BBjauPK"
},
"outputs": [],
"source": [
"x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)\n",
"\n",
"def power(x, y):\n",
" result = tf.eye(10, dtype=tf.dtypes.int32)\n",
" for _ in range(y):\n",
" result = tf.matmul(x, result)\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "ms2yJyAnUYxK"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Eager execution: 8.420402663061395 seconds\n"
]
}
],
"source": [
"print(\"Eager execution:\", timeit.timeit(lambda: power(x, 100), number=1000), \"seconds\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "gUB2mTyRYRAe"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Graph execution: 1.5151319501455873 seconds\n"
]
}
],
"source": [
"power_as_graph = tf.function(power)\n",
"print(\"Graph execution:\", timeit.timeit(lambda: power_as_graph(x, 100), number=1000), \"seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q1Pfo5YwwILi"
},
"source": [
    "`tf.function` is commonly used to speed up training loops. Learn more in the *Speeding up your training step with `tf.function`* section of the Keras [Writing a training loop from scratch](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch) guide.\n",
    "\n",
    "Note: You can also try `tf.function(jit_compile=True)` for a more significant performance boost, especially if your code is heavy on TF control flow and uses many small tensors. Learn more in the *Explicit compilation with `tf.function(jit_compile=True)`* section of the [XLA overview](https://tensorflow.google.cn/xla)."
]
},
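 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "As a minimal sketch of that note (assuming your TensorFlow build includes XLA support, which standard builds do), `jit_compile=True` is just another argument to `tf.function`:\n",
   "\n",
   "```python\n",
   "import tensorflow as tf\n",
   "\n",
   "# Compile the whole function with XLA instead of executing op-by-op.\n",
   "@tf.function(jit_compile=True)\n",
   "def fused_op(x):\n",
   "    return tf.reduce_sum(x * x + 1.0)\n",
   "\n",
   "result = fused_op(tf.constant([1.0, 2.0]))\n",
   "print(result)\n",
   "```"
  ]
 },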
{
"cell_type": "markdown",
"metadata": {
"id": "sm0bNFp8PX53"
},
"source": [
    "### Performance and trade-offs\n",
    "\n",
    "Graphs can speed up your code, but the process of creating them has some overhead. For some functions, the creation of the graph takes more time than its execution. **This investment is usually quickly paid back with the performance boost of subsequent executions, but it's important to be aware that the first few steps of any large model training can be slower due to tracing.**\n",
    "\n",
    "No matter how large your model, you should avoid tracing frequently. The *Controlling retracing* section of the [`tf.function` guide](./function.ipynb) discusses how to set input specifications and use tensor arguments to avoid retracing. If you find your performance is unusually poor, it's a good idea to check whether you are retracing accidentally."
]
},
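 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "One way to pin down the input specification, sketched here under the assumption that your inputs only vary in their leading dimension, is the `input_signature` argument of `tf.function`, which keeps every matching tensor on a single traced graph:\n",
   "\n",
   "```python\n",
   "import tensorflow as tf\n",
   "\n",
   "# Accept any 1-D float32 tensor with a single traced graph.\n",
   "@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])\n",
   "def relu_fixed(x):\n",
   "    return tf.maximum(0., x)\n",
   "\n",
   "relu_fixed(tf.constant([1., -1.]))\n",
   "relu_fixed(tf.constant([2., -2., 3.]))  # Different shape, same graph.\n",
   "print(relu_fixed.experimental_get_tracing_count())\n",
   "```"
  ]
 },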
{
"cell_type": "markdown",
"metadata": {
"id": "F4InDaTjwmBA"
},
"source": [
    "## When is a `Function` tracing?\n",
    "\n",
    "To figure out when your `Function` is tracing, add a `print` statement to its code. As a rule of thumb, `Function` will execute the `print` statement every time it traces."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "hXtwlbpofLgW"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing!\n",
"tf.Tensor(6, shape=(), dtype=int32)\n",
"tf.Tensor(11, shape=(), dtype=int32)\n"
]
}
],
"source": [
"@tf.function\n",
"def a_function_with_python_side_effect(x):\n",
" print(\"Tracing!\") # An eager-only side effect.\n",
" return x * x + tf.constant(2)\n",
"\n",
"# This is traced the first time.\n",
"print(a_function_with_python_side_effect(tf.constant(2)))\n",
"# The second time through, you won't see the side effect.\n",
"print(a_function_with_python_side_effect(tf.constant(3)))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "inzSg8yzfNjl"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing!\n",
"tf.Tensor(6, shape=(), dtype=int32)\n",
"Tracing!\n",
"tf.Tensor(11, shape=(), dtype=int32)\n"
]
}
],
"source": [
"# This retraces each time the Python argument changes,\n",
"# as a Python argument could be an epoch count or other\n",
"# hyperparameter.\n",
"print(a_function_with_python_side_effect(2))\n",
"print(a_function_with_python_side_effect(3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rtN8NW6AfKye"
},
"source": [
    "New Python arguments always trigger the creation of a new graph, hence the extra tracing.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D1kbr5ocpS6R"
},
"source": [
    "## Next steps\n",
    "\n",
    "You can learn more about `tf.function` on the API reference page and by following the [Better performance with `tf.function`](./function.ipynb) guide."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "intro_to_graphs.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "xxx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}