{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 自动微分机制\n", "\n", "{guilabel}`参考`:[notes/autograd](https://pytorch.org/docs/stable/notes/autograd.html)\n", "\n", "本文将概述 `autograd` 的工作方式并记录运算。完全没有必要理解所有这些,但建议您熟悉它,因为它将帮助您编写更有效、更干净的程序,并可以帮助您调试。\n", "\n", "## `autograd` 如何编码历史\n", "\n", "Autograd 是反向自动微分系统。从概念上讲,`autograd` 记录了图,该图记录了在执行运算时创建数据的所有运算,从而为您提供了一个有向无环图,其叶是输入张量,根是输出张量。通过从根到叶跟踪这个图,可以使用链式法则自动计算梯度。\n", "\n", "在内部,`autograd` 将这个图表示为 {class}`~torch.autograd.Function` 对象(实际上是表达式)的图,可以借助 {func}`apply` 来计算对图求值的结果。在计算 forward 传递时,`autograd` 同时申请计算,并构建表示计算梯度(每个 {class}`torch.Tensor` 的 `.grad_fn` 属性是这个图的入口点)。当前向传递完成时,在向后传递中计算这个图的梯度。\n", "\n", "```{important}\n", "图在每次迭代时都是从头开始重新创建的,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时改变图的总体形状和大小。在启动训练之前,您不必对所有可能的路径进行编码——您所运行的是您所区分的。\n", "```\n", "\n", "### 已保存的张量\n", "\n", "有些运算需要在向前传播期间保存中间结果,以便执行 backward 传播。例如,函数 $x\\mapsto x^2$ 保存输入的 $x$ 来计算梯度。\n", "\n", "当定义自定义 Python 函数时,你可以使用 {func}`~torch.autograd.Function.save_for_backward` 在正向传播时保存张量,在后向传播时使用 `saved_tensor` 来检索它们。有关更多信息,请参见 [扩展 PyTorch](extending)。\n", "\n", "对于 PyTorch 定义的运算(如 {func}`torch.pow`),张量会根据需要自动保存。通过查找以 `_saved` 前缀开头的属性,您可以研究(出于教育或调试目的)某个 `grad_fn` 保存了哪些张量。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n" ] } ], "source": [ "import torch\n", "\n", "x = torch.randn(5, requires_grad=True)\n", "y = x.pow(2)\n", "print(x.equal(y.grad_fn._saved_self)) # True\n", "print(x is y.grad_fn._saved_self) # True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在前面的代码中,`y.grad_fn._saved_self` 指的是与 `x` 相同的张量对象。但情况并非总是如此。例如" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "False\n" ] } ], "source": [ "x = torch.randn(5, requires_grad=True)\n", "y = x.exp()\n", "print(y.equal(y.grad_fn._saved_result)) # True\n", "print(y is y.grad_fn._saved_result) # False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在幕后,为了防止循环引用,PyTorch 在保存时将张量打包,并将其解压到不同的张量以供阅读。在这里,你通过访问 `y.grad_fn._saved_result` 得到的张量是不同于 `y` 的张量对象(但它们仍然共享相同的存储)。\n", "\n", "张量是否会被打包到不同的张量对象中取决于它是否是它自己的 `grad_fn` 的输出,这是可能会改变的实现细节,用户不应该依赖它。\n", "\n", "你可以控制 PyTorch 如何借助[保存张量的挂钩](https://pytorch.org/docs/stable/notes/autograd.html#saved-tensors-hooks-doc) 打包/拆包。\n", "\n", "## 局部禁用梯度计算\n", "\n", "Python 中有几种机制可以在局部禁用梯度计算:\n", "\n", "要在整个代码块中禁用梯度,有一些上下文管理器,如 `no-grad` 模式和 `inference` 模式。为了从梯度计算中更细粒度地排除子图,可以设置张量的 `requires_grad` 字段。\n", "\n", "下面,除了讨论上面的机制,还描述了求值模式({func}`~torch.nn.modules.module.Module.eval`),这个方法实际上并不用于禁用梯度计算,但由于它的名称,经常与这三种方法混在一起。\n", "\n", "### 设置 `requires_grad`\n", "\n", "`requires_grad` 是 flag,除非用 {class}`~torch.nn.parameter.Parameter` 包装,否则默认为 `False`,它允许从梯度计算中细粒度地排除子图。它在向前和向后的传播中都起作用:\n", "\n", "在前向传播过程中,如果运算至少有一个输入张量需要梯度,则该运算只记录在 backward 图中。在后向传播(`.backward()`)期间,只有 `requires_grad=True` 的叶张量才会在它们的 `.grad` 字段中累积梯度。\n", "\n", "值得注意的是,即使每个张量都有这个 flag,设置它只对叶张量有意义(没有 `grad_fn` 的张量,例如 `nn.Module` 参数)。非叶张量(确实有 `grad_fn` 的张量)是具有 `backward` 图关联的张量。因此,它们的梯度将需要作为中间结果来计算需要梯度的叶张量的梯度。从这个定义中,很明显,所有非叶张量都将自动具有 `require_grad=True`。\n", "\n", "设置 `requires_grad` 应该是您控制模型的哪些部分是梯度计算的一部分的主要方式,例如,如果您需要在模型微调期间冻结您的预训练模型的部分。\n", "\n", "要冻结模型的部分,只需将 `.requires_grad_(False)` 应用于不希望更新的参数。如上所述,由于使用这些参数作为输入的计算不会在正向传递中被记录,因此它们的 `.grad` 字段不会在向后传递中被更新,因为它们一开始就不是 `backward` 传递图的一部分,正如所希望的那样。\n", "\n", "因为这是非常常见的模式,所以也可以在模块级别使用 `nn.Module.requires_grad_()` 来设置 `requires_grad`。当应用到模块时,`.requires_grad_()` 将对模块的所有参数(默认为 `requires_grad=True`)生效。\n", "\n", "### 模式\n", "\n", "除了设置 `requires_grad` 之外,Python 还支持三种可能的模式,它们可以影响内部 `autograd` 处理 PyTorch 中的计算:default 模式(grad 模式)、no-grad 模式和 `inference` 模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。\n", "\n", "```{rubric} 默认模式(Grad 模式)\n", "```\n", "\n", "默认模式实际上就是当没有启用其他模式(如 no-grad 模式和 inference 模式)时,隐式地处于的模式。与 no-grad 模式相比,默认模式有时也称为 grad模式。\n", "\n", "关于默认模式最重要的一点是,它是 `requires_grad` 生效的唯一模式。在其他两种模式中,`requires_grad` 总是被重写为 `False`。\n", "\n", "```{rubric} No-grad 模式\n", "```\n", "\n", "在 no-grad 模式下的计算表现为没有任何输入需要 grad。换句话说,即使有 `require_grad=True` 的输入,在 no-grad 模式下的计算也不会记录在反向图中。\n", "\n", "当您需要执行不应由 `autograd` 记录的运算,但您仍然希望稍后在 `grad` 模式下使用这些计算的输出时,启用 `no-grad` 模式。这个上下文管理器可以方便地禁用代码块或函数的梯度,而不必临时将张量设置为 `requires_grad=False`,然后返回 `True`。\n", "\n", "例如,在编写优化器时,no-grad 模式可能很有用:在执行训练更新时,您希望在原地更新参数,而不需要由 autograd 记录更新。您还打算在下一个前向传递中在 grad 模式中使用更新的参数进行计算。\n", "\n", "{mod}`torch.nn.init` 中的实现在初始化参数时也依赖于 no-grad 模式,以避免在原地更新初始化参数时进行自 grad 跟踪。\n", "\n", "```{rubric} inference 模式\n", "```\n", "\n", "推理模式是 no-grad 模式的极端版本。就像在 no-grad 模式中一样,推理模式中的计算不会记录在反向图中,但是启用推理模式将允许 PyTorch 进一步加速您的模型。这种更好的运行时有一个缺点:在推理模式中创建的张量不能用于退出推理模式后由 autograd 记录的计算。\n", "\n", "当您执行不需要记录在反向图中的计算时,启用推理模式,并且您不打算在稍后由 autograd 记录的任何计算中使用推理模式中创建的张量。\n", "\n", "建议您在不需要自动跟踪(例如,数据处理和模型评估)的代码部分尝试推理模式。如果它在你的用例中是开箱即用的,这是免费的性能胜利。如果在启用推理模式后遇到错误,请检查是否在退出推理模式后由 autograd 记录的计算中使用了推理模式中创建的张量。如果您无法避免在您的情况下使用这种方法,您总是可以切换回 no-grad 模式。\n", "\n", "有关推理模式的详细信息,请参见 [推理模式](https://pytorch.org/cppdocs/notes/inference_mode.html)。\n", "\n", "有关推理模式的实现细节,请参见 [RFC-0011-InferenceMode](https://github.com/pytorch/rfcs/pull/17)。\n", "\n", "```{rubric} 评估模式\n", "```\n", "\n", "评估模式(`nn.Module.eval`)实际上并不是一种局部禁用梯度计算的机制。无论如何,这里包含了它,因为它有时会被混淆为这样的机制。\n", "\n", "在函数上,{func}`module.eval`(或等效于 {func}`module.train`)完全正交于 no-grade 模式和推理模式。e{func}`module.eval` 如何影响您的模型完全取决于您的模型中使用的特定模块,以及它们是否定义了任何训练模式特定的行为。\n", "\n", "如果你的模型依赖于诸如 {class}`torch.nn.Dropout` 和 {class}`torch.nn.BatchNorm2d ` 这样的模块,你需要负责调用 {func}`model.eval` 和 {func}`model.train`。根据训练模式的不同,`BatchNorm2d` 的行为可能会有所不同,例如,在验证数据上避免更新 `BatchNorm` 运行统计数据。\n", "\n", "```{admonition} 建议\n", "在训练时使用 {func}`model.train`,在评估模型(验证/测试)时使用 {func}`model.eval`,即使你不确定你的模型有特定的训练模式行为,因为你正在使用的模块可能会在 training 和 eval 模式中被更新为不同的行为。\n", "```\n", "\n", "## 使用 autograd 进行就地操作\n", "\n", "在 autograd 中支持就地操作(in-place)是一件困难的事情,不鼓励在大多数情况下使用它们。Autograd 积极的缓冲区释放和重用使其非常高效,而且很少有情况下,就地操作实际上显著降低了内存使用量。除非您在沉重的内存压力下操作,否则您可能永远都不需要使用它们。\n", "\n", "有两个主要原因限制了就地操作的适用性:\n", "\n", "1. 就地操作可能会覆盖计算梯度所需的值。\n", "2. 每个就地操作实际上都需要实现重写计算图。错位(out-of-place)版本只是分配新对象并保持对旧图形的引用,而就地操作则需要更改表示此操作的函数的所有输入的创建者。这可能很棘手,特别是当有很多张量引用了相同的存储(例如通过索引或转置创建),如果修改后的输入的存储被其他张量引用,原地函数实际上会引发错误。\n", "\n", "### 就地正确性检查\n", "\n", "每个张量都有版本计数器,每当它在任何操作中被标记为 dirty 时,这个计数器就会递增。当函数为 backward, 保存任何张量时,包含它们的张量的版本计数器也会被保存。一旦你访问 `self.saved_tensors` 会检查它,如果它大于保存的值,则会引发错误。这确保了如果您使用的是就地函数而没有看到任何错误,您可以确保计算的梯度是正确的。\n", "\n", "## 多线程 Autograd\n", "\n", "autograd 引擎负责运行计算 backward 遍历所需的所有 backward 运算。本节将描述所有可以帮助您在多线程环境中充分利用它的细节。(这只与 PyTorch 1.6+ 相关,因为之前版本的行为是不同的)。\n", "\n", "用户可以用多线程代码训练他们的模型(例如 Hogwild 训练),并且不会阻塞并发 backward 计算,示例代码:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "# Define a train function to be used in different threads\n", "def train_fn():\n", " x = torch.ones(5, 5, requires_grad=True)\n", " # forward\n", " y = (x + 3) * (x + 4) * 0.5\n", " # backward\n", " y.sum().backward()\n", " # potential optimizer update\n", "\n", "\n", "# User write their own threading code to drive the train_fn\n", "threads = []\n", "for _ in range(10):\n", " p = threading.Thread(target=train_fn, args=())\n", " p.start()\n", " threads.append(p)\n", "\n", "for p in threads:\n", " p.join()\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "请注意用户应该注意的一些行为:\n", "\n", "```{rubric} CPU 的并发性\n", "```\n", "\n", "当你通过 Python 或 C++ API 在 CPU 上的多个线程中运行 {func}`backward` 或 {func}`grad` 时,你希望看到额外的并发性,而不是在执行期间以特定的顺序序列化所有的 backward 调用(在 PyTorch 1.6 之前的行为)。\n", "\n", "```{rubric} 非确定性\n", "```\n", "\n", "如果你在多线程上并发地调用 {func}`backward`,但是有共享的输入(例如 Hogwild CPU 训练)。因为参数是在线程之间自动共享的,所以梯度累加在线程之间的 `backward` 调用上可能变得不确定,因为两个 `backward` 调用可能访问并试图累积相同的 `.grad` 属性。这在技术上是不安全的,它可能会导致竞争条件和结果可能是无效的使用。\n", "\n", "但是,如果您使用多线程方法来驱动整个训练过程,但是使用共享参数,那么这是预期的模式,使用多线程的用户应该记住线程模型,并预期会发生这种情况。用户可以使用函数式 API {func}`torch.autograd.grad` 来计算梯度,而不是 {func}`backward` 来避免不确定性。\n", "\n", "```{rubric} Graph retaining\n", "```\n", "\n", "如果 autograd 图的一部分在线程之间共享,也就是说,运行前向单线程的第一部分,然后在多个线程中运行第二部分,那么图的第一部分就是共享的。在这种情况下,在同一个图上执行 {func}`grad` 或 {func}`backward` 的不同线程可能会在一个线程上破坏图,而在这种情况下,另一个线程将崩溃。Autograd 将向用户输出错误信息,类似于在 `retain_graph=True` 的情况下调用 {func}`backward` 两次,并让用户知道他们应该使用 `retain_graph=True`。\n", "\n", "```{rubric} Autograd 节点上的线程安全\n", "```\n", "\n", "由于 Autograd 允许调用方线程驱动 backward 执行以获得潜在的并行性,因此确保 CPU 上的线程安全是很重要的,向后并行共享部分/全部的 GraphTask。\n", "\n", "自定义Python `autograd.function` 是自动线程安全的,因为 GIL。对于内置的 C++ Autograd 节点(例如 AccumulateGrad, CopySlices) 和自定义的 `autograd::Function`,Autograd Engine 使用线程互斥锁来保护可能有状态写/读的 autograd 节点上的线程安全。\n", "\n", "```{rubric} C++ 钩子上没有线程安全\n", "```\n", "\n", "Autograd 依赖于用户来编写线程安全的 C++ 钩子。如果你想在多线程环境中正确的应用钩子,你需要写正确的线程锁定代码来确保钩子是线程安全的。\n", "\n", "## 复数 Autograd\n", "\n", "简短的版本:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{eval-rst}\n", "- When you use PyTorch to differentiate any function :math:`f(z)` with complex domain and/or codomain,\n", " the gradients are computed under the assumption that the function is a part of a larger real-valued\n", " loss function :math:`g(input)=L`. The gradient computed is :math:`\\frac{\\partial L}{\\partial z^*}`\n", " (note the conjugation of z), the negative of which is precisely the direction of steepest descent\n", " used in Gradient Descent algorithm. Thus, all the existing optimizers work out of\n", " the box with complex parameters.\n", "- This convention matches TensorFlow's convention for complex\n", " differentiation, but is different from JAX (which computes\n", " :math:`\\frac{\\partial L}{\\partial z}`).\n", "- If you have a real-to-real function which internally uses complex\n", " operations, the convention here doesn't matter: you will always get\n", " the same result that you would have gotten if it had been implemented\n", " with only real operations.\n", "\n", "If you are curious about the mathematical details, or want to know how\n", "to define complex derivatives in PyTorch, read on.\n", "\n", ".. rubric:: What are complex derivatives?\n", "\n", "The mathematical definition of complex-differentiability takes the\n", "limit definition of a derivative and generalizes it to operate on\n", "complex numbers. Consider a function :math:`f: ℂ → ℂ`,\n", "\n", " .. math::\n", " `f(z=x+yj) = u(x, y) + v(x, y)j`\n", "\n", "where :math:`u` and :math:`v` are two variable real valued functions.\n", "\n", "Using the derivative definition, we can write:\n", "\n", " .. math::\n", " f'(z) = \\lim_{h \\to 0, h \\in C} \\frac{f(z+h) - f(z)}{h}\n", "\n", "In order for this limit to exist, not only must :math:`u` and :math:`v` must be\n", "real differentiable, but :math:`f` must also satisfy the Cauchy-Riemann `equations\n", "`_. In\n", "other words: the limit computed for real and imaginary steps (:math:`h`)\n", "must be equal. This is a more restrictive condition.\n", "\n", "The complex differentiable functions are commonly known as holomorphic\n", "functions. They are well behaved, have all the nice properties that\n", "you've seen from real differentiable functions, but are practically of no\n", "use in the optimization world. For optimization problems, only real valued objective\n", "functions are used in the research community since complex numbers are not part of any\n", "ordered field and so having complex valued loss does not make much sense.\n", "\n", "It also turns out that no interesting real-valued objective fulfill the\n", "Cauchy-Riemann equations. So the theory with homomorphic function cannot be\n", "used for optimization and most people therefore use the Wirtinger calculus.\n", "\n", ".. rubric:: Wirtinger Calculus comes in picture ...\n", "\n", "\n", "So, we have this great theory of complex differentiability and\n", "holomorphic functions, and we can’t use any of it at all, because many\n", "of the commonly used functions are not holomorphic. What’s a poor\n", "mathematician to do? Well, Wirtinger observed that even if :math:`f(z)`\n", "isn’t holomorphic, one could rewrite it as a two variable function\n", ":math:`f(z, z*)` which is always holomorphic. This is because real and\n", "imaginary of the components of :math:`z` can be expressed in terms of\n", ":math:`z` and :math:`z^*` as:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " Re(z) &= \\frac {z + z^*}{2} \\\\\n", " Im(z) &= \\frac {z - z^*}{2j}\n", " \\end{aligned}\n", "\n", "Wirtinger calculus suggests to study :math:`f(z, z^*)` instead, which is\n", "guaranteed to be holomorphic if :math:`f` was real differentiable (another\n", "way to think of it is as a change of coordinate system, from :math:`f(x, y)`\n", "to :math:`f(z, z^*)`.) This function has partial derivatives\n", ":math:`\\frac{\\partial }{\\partial z}` and :math:`\\frac{\\partial}{\\partial z^{*}}`.\n", "We can use the chain rule to establish a\n", "relationship between these partial derivatives and the partial\n", "derivatives w.r.t., the real and imaginary components of :math:`z`.\n", "\n", " .. math::\n", " \\begin{aligned}\n", " \\frac{\\partial }{\\partial x} &= \\frac{\\partial z}{\\partial x} * \\frac{\\partial }{\\partial z} + \\frac{\\partial z^*}{\\partial x} * \\frac{\\partial }{\\partial z^*} \\\\\n", " &= \\frac{\\partial }{\\partial z} + \\frac{\\partial }{\\partial z^*} \\\\\n", " \\\\\n", " \\frac{\\partial }{\\partial y} &= \\frac{\\partial z}{\\partial y} * \\frac{\\partial }{\\partial z} + \\frac{\\partial z^*}{\\partial y} * \\frac{\\partial }{\\partial z^*} \\\\\n", " &= 1j * (\\frac{\\partial }{\\partial z} - \\frac{\\partial }{\\partial z^*})\n", " \\end{aligned}\n", "\n", "From the above equations, we get:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " \\frac{\\partial }{\\partial z} &= 1/2 * (\\frac{\\partial }{\\partial x} - 1j * \\frac{\\partial }{\\partial y}) \\\\\n", " \\frac{\\partial }{\\partial z^*} &= 1/2 * (\\frac{\\partial }{\\partial x} + 1j * \\frac{\\partial }{\\partial y})\n", " \\end{aligned}\n", "\n", "which is the classic definition of Wirtinger calculus that you would find on `Wikipedia `_.\n", "\n", "There are a lot of beautiful consequences of this change.\n", "\n", "- For one, the Cauchy-Riemann equations translate into simply saying that :math:`\\frac{\\partial f}{\\partial z^*} = 0` (that is to say, the function :math:`f` can be written\n", " entirely in terms of :math:`z`, without making reference to :math:`z^*`).\n", "- Another important (and somewhat counterintuitive) result, as we’ll see later, is that when we do optimization on a real-valued loss, the step we should\n", " take while making variable update is given by :math:`\\frac{\\partial Loss}{\\partial z^*}` (not :math:`\\frac{\\partial Loss}{\\partial z}`).\n", "\n", "For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf\n", "\n", ".. rubric:: How is Wirtinger Calculus useful in optimization?\n", "\n", "Researchers in audio and other fields, more commonly, use gradient\n", "descent to optimize real valued loss functions with complex variables.\n", "Typically, these people treat the real and imaginary values as separate\n", "channels that can be updated. For a step size :math:`\\alpha/2` and loss\n", ":math:`L`, we can write the following equations in :math:`ℝ^2`:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " x_{n+1} &= x_n - (\\alpha/2) * \\frac{\\partial L}{\\partial x} \\\\\n", " y_{n+1} &= y_n - (\\alpha/2) * \\frac{\\partial L}{\\partial y}\n", " \\end{aligned}\n", "\n", "How do these equations translate into complex space :math:`ℂ`?\n", "\n", " .. math::\n", " \\begin{aligned}\n", " z_{n+1} &= x_n - (\\alpha/2) * \\frac{\\partial L}{\\partial x} + 1j * (y_n - (\\alpha/2) * \\frac{\\partial L}{\\partial y}) \\\\\n", " &= z_n - \\alpha * 1/2 * (\\frac{\\partial L}{\\partial x} + j \\frac{\\partial L}{\\partial y}) \\\\\n", " &= z_n - \\alpha * \\frac{\\partial L}{\\partial z^*}\n", " \\end{aligned}\n", "\n", "Something very interesting has happened: Wirtinger calculus tells us\n", "that we can simplify the complex variable update formula above to only\n", "refer to the conjugate Wirtinger derivative\n", ":math:`\\frac{\\partial L}{\\partial z^*}`, giving us exactly the step we take in optimization.\n", "\n", "Because the conjugate Wirtinger derivative gives us exactly the correct step for a real valued loss function, PyTorch gives you this derivative\n", "when you differentiate a function with a real valued loss.\n", "\n", ".. rubric:: How does PyTorch compute the conjugate Wirtinger derivative?\n", "\n", "Typically, our derivative formulas take in `grad_output` as an input,\n", "representing the incoming Vector-Jacobian product that we’ve already\n", "computed, aka, :math:`\\frac{\\partial L}{\\partial s^*}`, where :math:`L`\n", "is the loss of the entire computation (producing a real loss) and\n", ":math:`s` is the output of our function. The goal here is to compute\n", ":math:`\\frac{\\partial L}{\\partial z^*}`, where :math:`z` is the input of\n", "the function. It turns out that in the case of real loss, we can\n", "get away with *only* calculating :math:`\\frac{\\partial L}{\\partial z^*}`,\n", "even though the chain rule implies that we also need to\n", "have access to :math:`\\frac{\\partial L}{\\partial z^*}`. If you want\n", "to skip this derivation, look at the last equation in this section\n", "and then skip to the next section.\n", "\n", "Let’s continue working with :math:`f: ℂ → ℂ` defined as\n", ":math:`f(z) = f(x+yj) = u(x, y) + v(x, y)j`. As discussed above,\n", "autograd’s gradient convention is centered around optimization for real\n", "valued loss functions, so let’s assume :math:`f` is a part of larger\n", "real valued loss function :math:`g`. Using chain rule, we can write:\n", "\n", " .. math::\n", " \\frac{\\partial L}{\\partial z^*} = \\frac{\\partial L}{\\partial u} * \\frac{\\partial u}{\\partial z^*} + \\frac{\\partial L}{\\partial v} * \\frac{\\partial v}{\\partial z^*}\n", " :label: [1]\n", "\n", "Now using Wirtinger derivative definition, we can write:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " \\frac{\\partial L}{\\partial s} = 1/2 * (\\frac{\\partial L}{\\partial u} - \\frac{\\partial L}{\\partial v} j) \\\\\n", " \\frac{\\partial L}{\\partial s^*} = 1/2 * (\\frac{\\partial L}{\\partial u} + \\frac{\\partial L}{\\partial v} j)\n", " \\end{aligned}\n", "\n", "It should be noted here that since :math:`u` and :math:`v` are real\n", "functions, and :math:`L` is real by our assumption that :math:`f` is a\n", "part of a real valued function, we have:\n", "\n", " .. math::\n", " (\\frac{\\partial L}{\\partial s})^* = \\frac{\\partial L}{\\partial s^*}\n", " :label: [2]\n", "\n", "i.e., :math:`\\frac{\\partial L}{\\partial s}` equals to :math:`grad\\_output^*`.\n", "\n", "Solving the above equations for :math:`\\frac{\\partial L}{\\partial u}` and :math:`\\frac{\\partial L}{\\partial v}`, we get:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " \\frac{\\partial L}{\\partial u} = \\frac{\\partial L}{\\partial s} + \\frac{\\partial L}{\\partial s^*} \\\\\n", " \\frac{\\partial L}{\\partial v} = -1j * (\\frac{\\partial L}{\\partial s} - \\frac{\\partial L}{\\partial s^*})\n", " \\end{aligned}\n", " :label: [3]\n", "\n", "Substituting :eq:`[3]` in :eq:`[1]`, we get:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " \\frac{\\partial L}{\\partial z^*} &= (\\frac{\\partial L}{\\partial s} + \\frac{\\partial L}{\\partial s^*}) * \\frac{\\partial u}{\\partial z^*} - 1j * (\\frac{\\partial L}{\\partial s} - \\frac{\\partial L}{\\partial s^*}) * \\frac{\\partial v}{\\partial z^*} \\\\\n", " &= \\frac{\\partial L}{\\partial s} * (\\frac{\\partial u}{\\partial z^*} + \\frac{\\partial v}{\\partial z^*} j) + \\frac{\\partial L}{\\partial s^*} * (\\frac{\\partial u}{\\partial z^*} - \\frac{\\partial v}{\\partial z^*} j) \\\\\n", " &= \\frac{\\partial L}{\\partial s^*} * \\frac{\\partial (u + vj)}{\\partial z^*} + \\frac{\\partial L}{\\partial s} * \\frac{\\partial (u + vj)^*}{\\partial z^*} \\\\\n", " &= \\frac{\\partial L}{\\partial s} * \\frac{\\partial s}{\\partial z^*} + \\frac{\\partial L}{\\partial s^*} * \\frac{\\partial s^*}{\\partial z^*} \\\\\n", " \\end{aligned}\n", "\n", "Using :eq:`[2]`, we get:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " \\frac{\\partial L}{\\partial z^*} &= (\\frac{\\partial L}{\\partial s^*})^* * \\frac{\\partial s}{\\partial z^*} + \\frac{\\partial L}{\\partial s^*} * (\\frac{\\partial s}{\\partial z})^* \\\\\n", " &= \\boxed{ (grad\\_output)^* * \\frac{\\partial s}{\\partial z^*} + grad\\_output * {(\\frac{\\partial s}{\\partial z})}^* } \\\\\n", " \\end{aligned}\n", " :label: [4]\n", "\n", "This last equation is the important one for writing your own gradients,\n", "as it decomposes our derivative formula into a simpler one that is easy\n", "to compute by hand.\n", "\n", ".. rubric:: How can I write my own derivative formula for a complex function?\n", "\n", "The above boxed equation gives us the general formula for all\n", "derivatives on complex functions. However, we still need to\n", "compute :math:`\\frac{\\partial s}{\\partial z}` and :math:`\\frac{\\partial s}{\\partial z^*}`.\n", "There are two ways you could do this:\n", "\n", " - The first way is to just use the definition of Wirtinger derivatives directly and calculate :math:`\\frac{\\partial s}{\\partial z}` and :math:`\\frac{\\partial s}{\\partial z^*}` by\n", " using :math:`\\frac{\\partial s}{\\partial x}` and :math:`\\frac{\\partial s}{\\partial y}`\n", " (which you can compute in the normal way).\n", " - The second way is to use the change of variables trick and rewrite :math:`f(z)` as a two variable function :math:`f(z, z^*)`, and compute\n", " the conjugate Wirtinger derivatives by treating :math:`z` and :math:`z^*` as independent variables. This is often easier; for example, if the function in question is holomorphic, only :math:`z` will be used (and :math:`\\frac{\\partial s}{\\partial z^*}` will be zero).\n", "\n", "Let's consider the function :math:`f(z = x + yj) = c * z = c * (x+yj)` as an example, where :math:`c \\in ℝ`.\n", "\n", "Using the first way to compute the Wirtinger derivatives, we have.\n", "\n", ".. math::\n", " \\begin{aligned}\n", " \\frac{\\partial s}{\\partial z} &= 1/2 * (\\frac{\\partial s}{\\partial x} - \\frac{\\partial s}{\\partial y} j) \\\\\n", " &= 1/2 * (c - (c * 1j) * 1j) \\\\\n", " &= c \\\\\n", " \\\\\n", " \\\\\n", " \\frac{\\partial s}{\\partial z^*} &= 1/2 * (\\frac{\\partial s}{\\partial x} + \\frac{\\partial s}{\\partial y} j) \\\\\n", " &= 1/2 * (c + (c * 1j) * 1j) \\\\\n", " &= 0 \\\\\n", " \\end{aligned}\n", "\n", "Using :eq:`[4]`, and `grad\\_output = 1.0` (which is the default grad output value used when :func:`backward` is called on a scalar output in PyTorch), we get:\n", "\n", " .. math::\n", " \\frac{\\partial L}{\\partial z^*} = 1 * 0 + 1 * c = c\n", "\n", "Using the second way to compute Wirtinger derivatives, we directly get:\n", "\n", " .. math::\n", " \\begin{aligned}\n", " \\frac{\\partial s}{\\partial z} &= \\frac{\\partial (c*z)}{\\partial z} \\\\\n", " &= c \\\\\n", " \\frac{\\partial s}{\\partial z^*} &= \\frac{\\partial (c*z)}{\\partial z^*} \\\\\n", " &= 0\n", " \\end{aligned}\n", "\n", "And using :eq:`[4]` again, we get :math:`\\frac{\\partial L}{\\partial z^*} = c`. As you can see, the second way involves lesser calculations, and comes\n", "in more handy for faster calculations.\n", "\n", ".. rubric:: What about cross-domain functions?\n", "\n", "Some functions map from complex inputs to real outputs, or vice versa.\n", "These functions form a special case of :eq:`[4]`, which we can derive using the\n", "chain rule:\n", "\n", " - For :math:`f: ℂ → ℝ`, we get:\n", "\n", " .. math::\n", " \\frac{\\partial L}{\\partial z^*} = 2 * grad\\_output * \\frac{\\partial s}{\\partial z^{*}}\n", "\n", " - For :math:`f: ℝ → ℂ`, we get:\n", "\n", " .. math::\n", " \\frac{\\partial L}{\\partial z^*} = 2 * Re(grad\\_out^* * \\frac{\\partial s}{\\partial z^{*}})\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 保存张量的挂钩\n", "\n", "您可以通过定义一对 `pack_hook / unpack_hook` 钩子来[控制保存的张量如何打包/解包](https://pytorch.org/docs/stable/notes/autograd.html#saved-tensors-doc)。`pack_hook` 函数应该接受张量作为它的单个参数,但是可以返回任何 Python 对象(例如另一个张量,元组,甚至包含文件名的字符串)。`unpack_hook` 函数的唯一参数是 `pack_hook` 的输出,它应该返回一个张量,以便向后传递时使用。`unpack_hook` 返回的张量只需要与作为输入传递给 `pack_hook` 的张量具有相同的内容。特别是,任何与自动加载相关的元数据都可以被忽略,因为它们将在解包期间被覆盖。\n", "\n", "这对组合的一个例子是:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import os\n", "import uuid\n", "\n", "\n", "class SelfDeletingTempFile:\n", " def __init__(self, tmp_dir):\n", " self.name = os.path.join(tmp_dir, str(uuid.uuid4()))\n", "\n", " def __del__(self):\n", " os.remove(self.name)\n", "\n", "\n", "def pack_hook(tensor, tmp_dir='data'):\n", " temp_file = SelfDeletingTempFile(tmp_dir)\n", " torch.save(tensor, temp_file.name)\n", " return temp_file\n", "\n", "\n", "def unpack_hook(temp_file):\n", " return torch.load(temp_file.name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "注意,`unpack_hook` 不应该删除临时文件,因为它可能会被多次调用:只要返回的 `SelfDeletingTempFile` 对象处于活动状态,该临时文件就应该处于活动状态。在上面的例子中,通过在不再需要临时文件时关闭它来防止泄漏(在删除 `SelfDeletingTempFile` 对象时)。\n", "\n", "```{note}\n", "我们保证 `pack_hook` 只被调用一次,但 `unpack_hook` 可以被调用多次,只要向后传递需要,并且我们希望它每次都返回相同的数据。\n", "```\n", "\n", "```{warning}\n", "禁止对任何函数的输入执行就地操作,因为它们可能会导致意想不到的副作用。如果 pack 钩子的输入被就地修改,PyTorch 将抛出错误,但是没有捕捉到 unpack 钩子的输入被就地修改的情况。\n", "```\n", "\n", "### 注册已保存张量的钩子\n", "\n", "通过调用已保存张量对象上的 {meth}`register_hooks` 方法,可以在已保存的张量上注册一对钩子。这些对象作为 `grad_fn` 的属性公开,并以 `_raw_saved_` 前缀开头。" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(5, requires_grad=True)\n", "y = x.pow(2)\n", "y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "一旦注册了该对,就会调用 `pack_hook` 方法。每次需要访问保存的张量时,都会调用 `unpack_hook` 方法,或者通过 `y.grad_fn._saved_self` 或在向后传递时。\n", "\n", "```{warning}\n", "如果你在保存的张量被释放之后(即在向后调用之后)保持对 `SavedTensor` 的引用,那么调用它的 {func}`register_hooks` 是被禁止的。PyTorch 在大多数情况下会抛出错误,但在某些情况下可能会失败,并可能出现未定义的行为。\n", "```\n", "\n", "### 为保存的张量注册默认钩子\n", "\n", "或者,您可以使用上下文管理器 {func}`torch.autograd.graph.saved_tensors_hooks` 来注册一对钩子,这对钩子将应用于在该上下文中创建的所有保存的张量。\n", "\n", "示例:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# Only save on disk tensors that have size >= 1000\n", "SAVE_ON_DISK_THRESHOLD = 1000\n", "\n", "def pack_hook(tensor):\n", " if tensor.numel() < SAVE_ON_DISK_THRESHOLD:\n", " return tensor\n", " temp_file = SelfDeletingTempFile(temp_file='data')\n", " torch.save(tensor, temp_file.name)\n", " return temp_file\n", "\n", "def unpack_hook(tensor_or_sctf):\n", " if isinstance(tensor_or_sctf, torch.Tensor):\n", " return tensor_or_sctf\n", " return torch.load(tensor_or_sctf.name)\n", "\n", "class Model(torch.nn.Module):\n", " def forward(self, x):\n", " with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):\n", " # ... compute output\n", " output = x\n", " return output\n", "\n", "model = Model()\n", "net = torch.nn.DataParallel(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "用这个上下文管理器定义的钩子是线程局部的。因此,以下代码不会产生所需的效果,因为钩子没有经过 `DataParallel`。\n", "\n", "```python\n", "# Example what NOT to do\n", "\n", "net = torch.nn.DataParallel(model)\n", "with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):\n", " output = net(input)\n", "```\n", "\n", "注意,使用这些钩子将禁用所有的优化以减少 Tensor 对象的创建。例如:\n", "\n", "```python\n", "with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):\n", " x = torch.randn(5, requires_grad=True)\n", " y = x * x\n", "```\n", "\n", "如果没有钩子,`x`, `y.grad_fn._saved_self` 和 `y.grad_fn._saved_other` 都指向同一个张量对象。通过钩子,PyTorch 将把 `x` 打包和解压到两个新的张量对象中,它们与原始的 `x` 共享相同的存储空间(没有执行拷贝)。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "interpreter": { "hash": "78526419bf48930935ba7e23437b2460cb231485716b036ebb8701887a294fa8" }, "kernelspec": { "display_name": "Python 3.10.0 ('torchx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }