A Library of Primitives#
In this example we define a library of "composite" operations. Composite operations are defined as callable functions whose implementations are composed of several other operations.
Composite operations let you choose the level of abstraction at which to interpret/transform the code. We demonstrate both a function that inlines these calls after tracing and a custom Tracer that inlines them automatically during tracing.
Composite operations are useful for exposing higher-level context to a backend/transform while still retaining the ability to examine things at a more fine-grained level.
import torch
from torch import fx

def sigmoid_lowp(x: torch.Tensor):
    x = x.float()
    x = x.sigmoid()
    return x.half()
wrap() indicates that the function passed to it should always be recorded as a call_function node rather than being traced through. Later, we will see how to:
a. inline the implementation of such a function; b. define a tracer that automatically traces through such a function
# primitive_library.py
fx.wrap(sigmoid_lowp)
Likewise:
# primitive_library.py
def add_lowp(a: torch.Tensor, b: torch.Tensor):
    a, b = a.float(), b.float()
    c = a + b
    return c.half()

torch.fx.wrap(add_lowp)
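As a quick eager-mode sanity check (a sketch, not part of the library itself), both primitives compute in float32 and return float16:

```python
import torch

def sigmoid_lowp(x: torch.Tensor):
    # Upcast to float32, compute sigmoid, downcast the result to float16
    return x.float().sigmoid().half()

def add_lowp(a: torch.Tensor, b: torch.Tensor):
    # Upcast both operands, add in float32, downcast the sum to float16
    return (a.float() + b.float()).half()

x, y = torch.randn(4), torch.randn(4)
out = add_lowp(sigmoid_lowp(x), sigmoid_lowp(y))
print(out.dtype)  # torch.float16
```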
Let's see what happens when we symbolically trace through code that uses these functions:
from primitive_library import sigmoid_lowp, add_lowp

class Foo(torch.nn.Module):
    def forward(self, x, y):
        x = sigmoid_lowp(x)
        y = sigmoid_lowp(y)
        return add_lowp(x, y)

traced = fx.symbolic_trace(Foo())
print(traced.code)
def forward(self, x, y):
    sigmoid_lowp = primitive_library_sigmoid_lowp(x);  x = None
    sigmoid_lowp_1 = primitive_library_sigmoid_lowp(y);  y = None
    add_lowp = primitive_library_add_lowp(sigmoid_lowp, sigmoid_lowp_1);  sigmoid_lowp = sigmoid_lowp_1 = None
    return add_lowp
Note that the calls to sigmoid_lowp and add_lowp appear in the trace; they themselves were not traced through.
Inlining Calls#
Let's define a function that allows these calls to be inlined during graph manipulation.
def inline_lowp_func(n: fx.Node):
    # If we find a call to a function in our "lowp" module, inline it
    if n.op == 'call_function' and n.target.__module__ == inline_lowp_func.__module__:
        # We want to insert the operations comprising the implementation of the
        # function before the function itself. Then, we can swap the output value
        # of the function call with the output value for its implementation nodes
        tracer = fx.proxy.GraphAppendingTracer(n.graph)
        with n.graph.inserting_before(n):
            # We can inline code by using `fx.Proxy` instances.
            # map_arg traverses all aggregate types and applies the given function
            # to Node instances in the data structure. In this case, we are applying
            # the fx.Proxy constructor.
            proxy_args = torch.fx.node.map_arg(n.args, lambda x: torch.fx.Proxy(x, tracer))
            proxy_kwargs = torch.fx.node.map_arg(n.kwargs, lambda x: torch.fx.Proxy(x, tracer))
            # Call the function itself with proxy arguments. This will emit
            # nodes in the graph corresponding to the operations in the
            # implementation of the function
            output_proxy = n.target(*proxy_args, **proxy_kwargs)
            # Now replace the original node's uses with the output node of
            # the implementation.
            n.replace_all_uses_with(output_proxy.node)
            # Delete the old node
            n.graph.erase_node(n)
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is sigmoid_lowp:
        inline_lowp_func(node)

# Don't forget to recompile after manipulating the graph
new_code = traced.recompile()
print(traced.code)
def forward(self, x, y):
    float_1 = x.float();  x = None
    sigmoid = float_1.sigmoid();  float_1 = None
    half = sigmoid.half();  sigmoid = None
    float_2 = y.float();  y = None
    sigmoid_1 = float_2.sigmoid();  float_2 = None
    half_1 = sigmoid_1.half();  sigmoid_1 = None
    add_lowp = primitive_library_add_lowp(half, half_1);  half = half_1 = None
    return add_lowp
At this point, all of the calls to sigmoid_lowp have been replaced with its implementation.
Inlining Calls During Tracing#
Now let's define a custom tracer that can selectively inline calls to certain composite operations on the fly.
f = Foo()

class InliningTracer(fx.Tracer):
    FNS_TO_INLINE = [add_lowp]

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        if kind == 'call_function' and target in self.FNS_TO_INLINE:
            tracer = fx.proxy.GraphAppendingTracer(self.graph)
            # Trace through the implementation of the function rather than
            # creating a node
            proxy_args = fx.node.map_arg(args, lambda x: torch.fx.Proxy(x, tracer))
            proxy_kwargs = fx.node.map_arg(kwargs, lambda x: torch.fx.Proxy(x, tracer))
            return target(*proxy_args, **proxy_kwargs).node
        else:
            return super().create_node(kind, target, args, kwargs, name, type_expr)

tracer = InliningTracer()
graph = tracer.trace(f)
module = torch.fx.GraphModule(f, graph)
print(module.code)
def forward(self, x, y):
    sigmoid_lowp = primitive_library_sigmoid_lowp(x);  x = None
    sigmoid_lowp_1 = primitive_library_sigmoid_lowp(y);  y = None
    float_1 = sigmoid_lowp.float();  sigmoid_lowp = None
    float_2 = sigmoid_lowp_1.float();  sigmoid_lowp_1 = None
    add = float_1 + float_2;  float_1 = float_2 = None
    half = add.half();  add = None
    return half
As you can see, the implementation of add_lowp was inlined in the course of tracing with our InliningTracer, while sigmoid_lowp remained an opaque call. Such functionality can be used, for example, to implement a backend that wants to see the lowered form of some operations but the higher-level form of others.