Autograd 机制#
本文将概述 autograd 的工作原理并记录运算。理解所有这些并不是严格必要的,但我们建议熟悉它们,因为它将帮助您编写更高效、更清晰的程序,并有助于您进行调试。
autograd 如何编码历史#
Autograd 是一种反向自动微分系统(automatic differentiation system)。从概念上讲,autograd 记录了一个计算图(graph),在执行运算时记录了创建数据的所有运算,给出了一个有向无环图(directed acyclic graph),其叶是输入张量,根是输出张量。通过从根到叶跟踪这个计算图图,您可以使用链式法则(chain rule)自动计算梯度。
在内部,autograd 将此计算图表示为 Function
对象(真正的表达式)的计算图,可以使用 apply()
计算计算图的求值结果。在计算前向传递时,autograd 同时执行请求的计算并构建表示计算梯度的函数的计算图(每个 Tensor
的 grad_fn
属性是此计算图的入口点)。当前向传递完成时,我们在后向传递中计算此计算图的梯度。
需要注意的一件重要的事情是,计算图是在每次迭代中从头重新创建的,这正是允许使用任意 Python 控制流语句的原因,它可以在每次迭代中更改计算图的整体形状和大小。在开始训练之前,你不必对所有可能的路径进行编码——你所运行的就是你所区分的。
保存张量#
有些运算需要在前向传递期间保存中间结果,以便执行后向传递。例如,函数 \(x\mapsto x^2\) 保存输入 \(x\) 来计算梯度。
在定义自定义 Python Function
时,可以使用 save_for_backward()
在正向传递期间保存张量,并使用 saved_tensors
在向后传递期间检索它们。有关更多信息,请参见 扩展 PyTorch。
对于 PyTorch 定义的运算(例如 torch.pow()
),张量会根据需要自动保存。您可以通过查找以前缀 _saved
开头的属性来探索(出于教育或调试目的)某个 grad_fn
保存了哪些张量。
import torch
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
True
True
在前面的代码中,y.grad_fn._saved_self
指的是与 x
相同的张量对象。但情况不一定总是这样。例如:
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
True
False
在内部,为了防止引用循环,PyTorch 在保存时打包了张量,并将其解压缩到不同的张量中以供读取。这里,你通过访问 y.grad_fn._saved_result
得到不同于的 y
的张量对象(但它们仍然共享相同的存储空间)。
张量是否会被打包到不同的张量对象中,取决于它是否是它自己的 grad_fn
的输出,这是可能发生变化的实现细节,用户不应该依赖它。
你可以控制 PyTorch 如何用 hook 对保存的张量进行打包/解包。
不可微函数的梯度#
使用自动微分法(Automatic Differentiation)的梯度计算只有在使用的每个初等函数都是可微的情况下才有效。不幸的是,在实践中使用的许多函数没有这个特性(例如,relu
或 sqrt
在 0
处)。为了尽量减少不可微函数的影响,通过应用以下规则来定义初等运算的梯度:
#. 如果函数是可微的,因此在当前点存在梯度,就使用它。
#. 如果函数是凸的(至少在局部),则使用最小范数的次梯度(sub-gradient,它是最陡的下降方向)。
#. 如果函数是凹的(至少在局部),则使用最小范数的超梯度(super-gradient,考虑 -f(x)
并应用前一个点)。
#. 如果定义了函数,则通过连续性定义当前点的梯度(注意,这里可能是 inf
,例如对于 sqrt(0)
)。如果可能有多个值,则任意选择一个。
#. 如果函数没有定义(例如,sqrt(-1)
,log(-1)
或输入为 NaN
时的大多数函数),那么用作梯度的值是任意的(也可能引发错误,但不保证)。大多数函数将使用 NaN
作为梯度,但出于性能原因,一些函数将使用其他值(例如 log(-1)
)。
#. 如果函数不是一个确定性映射(deterministic mapping,即它不是数学函数),它将被标记为不可微。如果用在 no_grad
环境之外的需要 grad 的张量上,这将使它在向后出错。
局部禁用梯度计算#
Python 中有几种机制可以在局部禁用梯度计算:
为了在整个代码块中禁用梯度,可以使用像 no-grad 模式和 inference 模式这样的上下文管理器。为了从梯度计算中更细粒度地排除子图,需要设置张量的 requires_grad
场(field)。
下面,除了讨论上面的机制之外,还将描述求值模式(eval()
),此种情况实际上并不用于禁用梯度计算的方法,但因为它的名字,它经常与这三个方法混在一起。
requires_grad
#
requires_grad
是一个标志,默认为f alse,除非包装在 Parameter
中,允许从梯度计算中细粒度地排除子图。它在向前和向后传递时都有效:
在正向传递过程中,只有当至少一个输入张量需要梯度时,运算才会被记录在后向计算图中。在向后传递(.backward()
)期间,只有 requires_grad=True
的叶张量才会将梯度累积到它们的 .grad
字段中。
重要的是要注意,即使每个张量都有这个标志,设置它只对叶张量(没有 grad_fn
的张量,例如 nn.Module
的参数)有意义。非叶张量(有 grad_fn
的张量)是有反向计算图的张量。因此,它们的梯度将需要作为中间结果来计算需要梯度的叶张量的梯度。从这个定义可以清楚地看出,所有的非叶张量都会自动具有 require_grad=True
。
设置 requires_grad
应该是控制模型的哪些部分是梯度计算的一部分的主要方法,例如,如果您需要在模型微调期间冻结预先训练的模型的部分。
要冻结模型的部分,只需对不想更新的参数应用 .requires_grad_(False)
。正如上面所描述的,由于使用这些参数作为输入的计算不会被记录在正向传递中,因此它们的 .grad
字段不会在向后传递中更新,因为它们不会像预期的那样首先成为后向计算图的一部分。
因为这是非常常见的模式,所以 requires_grad
也可以用 requires_grad_()
在模块级别设置。当应用到模块时,.requires_grad_()
对模块的所有参数生效(默认情况下 requires_grad=True
)。
grad 模式#
除了设置 requires_grad
之外,Python 中还有三种可能的模式可以影响 PyTorch 内部 autograd 处理计算的方式:默认模式(grad 模式)、 no-grad 模式和 inference 模式,所有这些都可以通过上下文管理器和装饰器进行切换。
默认模式(grad 模式)#
“默认模式”实际上是在没有启用其他模式(如 no-grad 和 inference 模式)时隐含的模式。与“no-grad 模式”相比,默认模式有时也被称为“grad 模式”。
关于默认模式,需要知道的最重要的事情是,它是 requires_grad
生效的唯一模式。requires_grad
在其他两种模式中总是被重写为 False
。
no-grad 模式#
在 no-grad 模式下的计算行为就好像没有任何输入需要梯度。换句话说,即使有 require_grad=True
的输入,在 no-grad 模式下的计算也不会记录在后向计算图中。
当您需要执行不应由 autograd 记录的运算,但您仍希望稍后在 grad 模式中使用这些计算的输出时,请启用 no-grad 模式。这个上下文管理器可以方便地禁用代码块或函数的梯度,而不必临时设置张量 requires_grad=False
,然后返回 True
。
例如,在编写优化器时,no-grad 模式可能很有用:在执行训练更新时,您希望在不被 autograd 记录更新的情况下就地更新参数。您还打算在下一个向前传递中使用更新的参数进行 grad 模式下的计算。
torch.nn.init
中的实现在初始化参数时也依赖于 no-grad 模式,以避免在原地更新初始化参数时 autograd 跟踪。
inference 模式#
inference 模式 是 no-grad 模式的极端版本。就像在 no-grad 模式下一样,inference 模式下的计算不会记录在后向计算图中,但是启用 inference 模式将允许 PyTorch 进一步加快模型的速度。这种更好的运行时也有一个缺点:在 inference 模式中创建的张量将不能用于退出 inference 模式后由 autograd 记录的计算。
当您正在执行不需要在后向计算图中记录的计算时,并且您不打算在稍后由 autograd 记录的任何计算中使用在 inference 模式中创建的张量时,启用 inference 模式。
建议您在代码中不需要 autograd 跟踪的部分尝试 inference 模式(例如,数据处理和模型评估)。如果它对您的用例来说是开箱即用的,那么它就获得了自由的性能胜利。如果在启用 inference 模式后遇到错误,请检查是否在退出 inference 模式后由 autograd 记录的计算中使用在 inference 模式中创建的张量。如果您无法避免在您的情况下使用这种方法,您可以总是切换回 no-grad 模式。
eval 模式(torch.nn.Module.eval()
)#
求值模式实际上并不是局部禁用梯度计算的机制。在这里包含它,因为它有时被混淆为这样一种机制。
从函数上讲,module.eval()
(或等价的 module.train(False)
)与 no-grad 模式和 inference 模式完全正交。module.eval()
如何影响模型完全取决于模型中使用的特定模块,以及它们是否定义了任何特定于 training-mode 的行为。
如果你的模型依赖于诸如 torch.nn.Dropout
和 torch.nn.BatchNorm2d
这样的模块,你就有责任调用 model.eval()
和 model.train()
。例如,为了避免在验证数据上更新您的 BatchNorm 运行统计信息,它可能根据训练模式的不同而表现不同。
建议您在训练时使用 model.train()
,在评估模型(验证/测试)时使用 model.eval()
,即使您不确定您的模型是否具有特定于训练模式的行为,因为您正在使用的模块可能会被更新为在训练和评估模式中表现不同。
具有 autograd 的 in-place 运算#
在 autograd 中支持 in-place 运算是一件困难的事情,不鼓励在大多数情况下使用它们。Autograd 积极的缓冲区释放和重用使得它非常高效,而且很少有 in-place 运算真正大幅降低内存使用量的情况。除非您在沉重的内存压力下运行,否则您可能永远都不需要使用它们。
有两个主要原因限制了 in-place 运算的适用性:
in-place 运算可能会覆盖计算梯度所需的值。
每个 in-place 运算实际上都需要实现重写计算图。Out-of-place 版本只需分配新对象并保持对旧图的引用,而 in-place 运算则需要更改表示此运算的
Function
的所有输入的创建者。这可能是棘手的,特别是如果有许多Tensor
引用相同的存储(例如,通过索引或转置创建),如果修改的输入的存储被任何其他张量引用,就位函数实际上会引发错误。
in-place 正确性检查#
每个张量都保留版本计数器,在任何运算中,每当它被标记为 dirty 时,该计数器就会递增。当函数为向后保存任何张量时,包含它们的张量的版本计数器也会被保存。一旦你进入 self.saved_tensors
进行检查,如果它大于保存的值,则会引发错误。这可以确保如果您使用 in-place 函数且没有看到任何错误,则可以确保计算的梯度是正确的。
多线程 Autograd#
autograd 引擎负责运行计算向后传递所需的所有向后运算。本节将描述可以帮助您在多线程环境中最好地利用它的所有细节。
用户可以用多线程代码训练他们的模型(例如 Hogwild training),并且不会阻塞并发的向后计算,示例代码可以是:
import threading
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
用户应该注意的一些行为:
CPU 并发性#
当你在 CPU 的多个线程中通过 python 或 C++ API backward()
或 grad()
运行时,你期望看到额外的并发性,而不是在执行期间以特定的顺序序列化所有 backward 调用(PyTorch 1.6之前的行为)。
Non-determinism#
如果你在多线程上并发调用 backward()
,但是有共享的输入(例如 Hogwild CPU training)。由于参数是跨线程自动共享的,所以在跨线程的向后调用中梯度积累可能变得不确定,因为两个向后调用可能访问并试图积累相同的 .grad
属性。这在技术上是不安全的,它可能会导致竞争状态,结果可能是无效的使用。
但如果您使用多线程方法来驱动整个训练过程,但使用共享参数,那么这是预期的模式,使用多线程的用户应该在心里有线程模型,并应该预期会发生这种情况。用户可以使用函数 API torch.autograd.grad()
来计算梯度,而不是 backward()
,以避免不确定性。
Graph retaining#
如果 autograd 图的一部分是在线程之间共享的,即运行单线程的第一部分,然后在多线程中运行第二部分,那么图的第一部分是共享的。在这种情况下,不同的线程在同一个图上执行 grad()
或 backward()
可能会在一个线程运行时破坏图,而另一个线程在这种情况下会崩溃。Autograd 将错误传递给用户,类似于调用 backward()
两次,不带 retain_graph=True
,并让用户知道他们应该使用 retain_graph=True
。
Autograd 节点上的线程安全#
由于 Autograd 允许调用者线程为了潜在的并行性而驱动它的向后执行,因此我们必须通过共享 GraphTask 部分/全部的向后并行来确保 CPU 上的线程安全。
自定义 Python torch.autograd.Function
是自动线程安全的,因为 GIL。对于内置的 C++ Autograd Node(例如 AccumulateGrad, CopySlices)和自定义 autograd::Function
,Autograd Engine 使用线程互斥锁来确保可能具有写/读状态的自动升级节点上的线程安全。
C++ 钩子上没有线程安全#
Autograd 依赖于用户编写线程安全的 C++钩子。如果希望在多线程环境中正确应用钩子,则需要编写适当的线程锁定代码,以确保钩子是线程安全的。
复数的 Autograd#
简短的版本:
当您使用 PyTorch 对任何具有复数域和/或 codomain 的函数 \(f(z)\) 求导时,梯度是在假设该函数是更大的实值损失函数 \(g(input)=L\) 的一部分的情况下计算的。计算的梯度是 \(\frac{\partial L}{\partial z^*}\)(注意 \(z\) 的共轭),其负数恰恰是梯度下降算法中使用的最陡下降方向。因此,所有现有的优化器都可以使用复数的参数开箱即用。
该约定与 TensorFlow 关于复数微分的约定相匹配,但与 JAX(JAX 计算 \(\frac{\partial L}{\partial z}\))不同。
如果您有一个在内部使用复数运算的实对实函数,那么这里的约定并不重要:您将始终得到与仅使用实操作实现时相同的结果。
如果您对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复杂的导数,请继续阅读。
什么是复导数?#
复微分的数学定义取导数的极限定义,并将其推广到复数上。考虑函数 \(f: ℂ → ℂ\),
其中 \(u\) 和 \(v\) 是两个变量实值函数。
用导数的定义,可以写
In order for this limit to exist, not only must \(u\) and \(v\) must be real differentiable, but \(f\) must also satisfy the Cauchy-Riemann equations. In other words: the limit computed for real and imaginary steps (\(h\)) must be equal. This is a more restrictive condition.
The complex differentiable functions are commonly known as holomorphic functions. They are well behaved, have all the nice properties that you’ve seen from real differentiable functions, but are practically of no use in the optimization world. For optimization problems, only real valued objective functions are used in the research community since complex numbers are not part of any ordered field and so having complex valued loss does not make much sense.
It also turns out that no interesting real-valued objective fulfill the Cauchy-Riemann equations. So the theory with homomorphic function cannot be used for optimization and most people therefore use the Wirtinger calculus.
Wirtinger Calculus comes in picture …#
So, we have this great theory of complex differentiability and holomorphic functions, and we can’t use any of it at all, because many of the commonly used functions are not holomorphic. What’s a poor mathematician to do? Well, Wirtinger observed that even if \(f(z)\) isn’t holomorphic, one could rewrite it as a two variable function \(f(z, z*)\) which is always holomorphic. This is because real and imaginary of the components of \(z\) can be expressed in terms of \(z\) and \(z^*\) as:
\[\begin{split}\begin{aligned} Re(z) &= \frac {z + z^*}{2} \\ Im(z) &= \frac {z - z^*}{2j} \end{aligned}\end{split}\]
Wirtinger calculus suggests to study \(f(z, z^*)\) instead, which is guaranteed to be holomorphic if \(f\) was real differentiable (another way to think of it is as a change of coordinate system, from \(f(x, y)\) to \(f(z, z^*)\).) This function has partial derivatives \(\frac{\partial }{\partial z}\) and \(\frac{\partial}{\partial z^{*}}\). We can use the chain rule to establish a relationship between these partial derivatives and the partial derivatives w.r.t., the real and imaginary components of \(z\).
\[\begin{split}\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * (\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}) \end{aligned}\end{split}\]
From the above equations, we get:
\[\begin{split}\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * (\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}) \\ \frac{\partial }{\partial z^*} &= 1/2 * (\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}) \end{aligned}\end{split}\]
which is the classic definition of Wirtinger calculus that you would find on Wikipedia.
There are a lot of beautiful consequences of this change.
For one, the Cauchy-Riemann equations translate into simply saying that \(\frac{\partial f}{\partial z^*} = 0\) (that is to say, the function \(f\) can be written entirely in terms of \(z\), without making reference to \(z^*\)).
Another important (and somewhat counterintuitive) result, as we’ll see later, is that when we do optimization on a real-valued loss, the step we should take while making variable update is given by \(\frac{\partial Loss}{\partial z^*}\) (not \(\frac{\partial Loss}{\partial z}\)).
For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf
How is Wirtinger Calculus useful in optimization?#
Researchers in audio and other fields, more commonly, use gradient descent to optimize real valued loss functions with complex variables. Typically, these people treat the real and imaginary values as separate channels that can be updated. For a step size \(\alpha/2\) and loss \(L\), we can write the following equations in \(ℝ^2\):
\[\begin{split}\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}\end{split}\]
How do these equations translate into complex space \(ℂ\)?
\[\begin{split}\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * (\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}\end{split}\]
Something very interesting has happened: Wirtinger calculus tells us that we can simplify the complex variable update formula above to only refer to the conjugate Wirtinger derivative \(\frac{\partial L}{\partial z^*}\), giving us exactly the step we take in optimization.
Because the conjugate Wirtinger derivative gives us exactly the correct step for a real valued loss function, PyTorch gives you this derivative when you differentiate a function with a real valued loss.
How does PyTorch compute the conjugate Wirtinger derivative?#
Typically, our derivative formulas take in grad_output as an input, representing the incoming Vector-Jacobian product that we’ve already computed, aka, \(\frac{\partial L}{\partial s^*}\), where \(L\) is the loss of the entire computation (producing a real loss) and \(s\) is the output of our function. The goal here is to compute \(\frac{\partial L}{\partial z^*}\), where \(z\) is the input of the function. It turns out that in the case of real loss, we can get away with only calculating \(\frac{\partial L}{\partial z^*}\), even though the chain rule implies that we also need to have access to \(\frac{\partial L}{\partial z^*}\). If you want to skip this derivation, look at the last equation in this section and then skip to the next section.
Let’s continue working with \(f: ℂ → ℂ\) defined as \(f(z) = f(x+yj) = u(x, y) + v(x, y)j\). As discussed above, autograd’s gradient convention is centered around optimization for real valued loss functions, so let’s assume \(f\) is a part of larger real valued loss function \(g\). Using chain rule, we can write:
(9)#\[\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}\]
Now using Wirtinger derivative definition, we can write:
\[\begin{split}\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * (\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j) \\ \frac{\partial L}{\partial s^*} = 1/2 * (\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j) \end{aligned}\end{split}\]
It should be noted here that since \(u\) and \(v\) are real functions, and \(L\) is real by our assumption that \(f\) is a part of a real valued function, we have:
(10)#\[(\frac{\partial L}{\partial s})^* = \frac{\partial L}{\partial s^*}\]
i.e., \(\frac{\partial L}{\partial s}\) equals to \(grad\_output^*\).
Solving the above equations for \(\frac{\partial L}{\partial u}\) and \(\frac{\partial L}{\partial v}\), we get:
(11)#\[\begin{split}\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = -1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) \end{aligned}\end{split}\]
Substituting (11) in (9), we get:
\[\begin{split}\begin{aligned} \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) * \frac{\partial u}{\partial z^*} - 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * (\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j) + \frac{\partial L}{\partial s^*} * (\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j) \\ &= \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}\end{split}\]
Using (10), we get:
(12)#\[\begin{split}\begin{aligned} \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s^*})^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * (\frac{\partial s}{\partial z})^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * {(\frac{\partial s}{\partial z})}^* } \\ \end{aligned}\end{split}\]
This last equation is the important one for writing your own gradients, as it decomposes our derivative formula into a simpler one that is easy to compute by hand.
How can I write my own derivative formula for a complex function?#
The above boxed equation gives us the general formula for all derivatives on complex functions. However, we still need to compute \(\frac{\partial s}{\partial z}\) and \(\frac{\partial s}{\partial z^*}\). There are two ways you could do this:
The first way is to just use the definition of Wirtinger derivatives directly and calculate \(\frac{\partial s}{\partial z}\) and \(\frac{\partial s}{\partial z^*}\) by using \(\frac{\partial s}{\partial x}\) and \(\frac{\partial s}{\partial y}\) (which you can compute in the normal way).
The second way is to use the change of variables trick and rewrite \(f(z)\) as a two variable function \(f(z, z^*)\), and compute the conjugate Wirtinger derivatives by treating \(z\) and \(z^*\) as independent variables. This is often easier; for example, if the function in question is holomorphic, only \(z\) will be used (and \(\frac{\partial s}{\partial z^*}\) will be zero).
Let’s consider the function \(f(z = x + yj) = c * z = c * (x+yj)\) as an example, where \(c \in ℝ\).
Using the first way to compute the Wirtinger derivatives, we have.
Using (12), and grad_output = 1.0 (which is the default grad output value used when backward()
is called on a scalar output in PyTorch), we get:
\[\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c\]
Using the second way to compute Wirtinger derivatives, we directly get:
\[\begin{split}\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}\end{split}\]
And using (12) again, we get \(\frac{\partial L}{\partial z^*} = c\). As you can see, the second way involves lesser calculations, and comes in more handy for faster calculations.
What about cross-domain functions?#
Some functions map from complex inputs to real outputs, or vice versa. These functions form a special case of (12), which we can derive using the chain rule:
For \(f: ℂ → ℝ\), we get:
\[\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}\]For \(f: ℝ → ℂ\), we get:
\[\frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}})\]
钩子用于保存张量#
通过定义一对 pack_hook
/ unpack_hook
钩子,可以控制保存的张量如何打包/解包。pack_hook
函数应该接受张量作为它的单个参数,但可以返回任何 Python 对象(例如,另一个张量,一个元组,甚至一个包含文件名的字符串)。unpack_hook
函数将 pack_hook
的输出作为它的单个参数,并且应该返回一个在向后传递中使用的张量。unpack_hook
返回的张量只需要具有与作为输入传递给 pack_hook
的张量相同的内容。特别是,任何与 autograd 相关的元数据都可以被忽略,因为它们将在解包期间被覆盖。
这类配对的一个例子是:
import os
import uuid
class SelfDeletingTempFile:
def __init__(self, tmp_dir=".temp"):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
注意 unpack_hook
不应该删除临时文件,因为它可能会被调用多次:只要返回的 SelfDeletingTempFile
对象是活的,临时文件就应该是活的。在上面的例子中,通过在不再需要临时文件时关闭它(在删除 SelfDeletingTempFile
对象时)来防止临时文件泄漏。
备注
保证 pack_hook
只会被调用一次,但是 unpack_hook
可以被调用多次,只要向后传递需要,期望它每次都返回相同的数据。
警告
禁止对任何函数的输入执行 inplace 运算,因为它们可能导致意想不到的副作用。如果打包钩子的输入被 inplace 修改,PyTorch 将抛出错误,但不会捕获解包钩子的输入被 inplace 修改的情况。
为保存的张量注册钩子#
通过调用 SavedTensor
对象上的 register_hooks()
方法,可以在保存的张量上注册一对钩子。这些对象作为 grad_fn
的属性公开,并以 _raw_saved_
前缀开始。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
一旦注册了 pair, pack_hook
方法就会被调用。每次需要访问保存的张量时,都会调用 unpack_hook
方法,可以通过 y.grad_fn._saved_self
来访问或在向后传递期间。
警告
如果你在保存的张量被释放后(即在向后被调用后)维持对 SavedTensor
的引用,调用它的 register_hooks()
是被禁止的。PyTorch 在大多数情况下都会抛出错误,但在某些情况下它可能不会这样做,并可能出现未定义的行为。
为保存的张量注册默认挂钩#
或者,您可以使用上下文管理器 :class:~torch.autograd.graph.saved_tensors_hooks
来注册一对钩子,这对钩子将应用于在该上下文中创建的所有已保存的张量。
from torch import nn
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
用这个上下文管理器定义的钩子是 thread-local。因此,下面的代码不会产生预期的效果,因为钩子没有经过 torch.nn.DataParallel
。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
注意,使用这些钩子禁用了所有的优化,以减少张量对象的创建。例如:
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
如果没有钩子,则 x
, y.grad_fn._saved_self
和 y.grad_fn._saved_other
都指向同一个张量对象。有了钩子,PyTorch 将把 x
打包和解包到两个新的张量对象中,它们与原始的 x
共享相同的存储空间(没有执行拷贝)。