FX API 参考#
symbolic_trace()
#
给定 Module
或函数实例 root
,此函数将返回 GraphModule
,该 GraphModule
是通过记录跟踪 root
时看到的运算构造的。
concrete_args
允许您对函数进行部分专门化,无论是删除控制流还是数据结构。
例如:
def f(a, b):
if b == True:
return a
else:
return a*2
由于控制流的存在,FX 通常无法进行跟踪。但是,可以使用 concrete_args
专门化 b
的值来跟踪它。
from torch import fx
f = fx.symbolic_trace(f,
concrete_args={"b": False})
assert f(3, False) == 6
注意,尽管您仍然可以传入 b
的不同值,但它们将被忽略。
还可以使用 concrete_args
从函数中消除数据结构处理。这将使用 pytrees 来平展您的输入。为了避免过度专门化,传入 fx.PH
值不应该特化。例如:
def f(x):
out = 0
for v in x.values():
out += v
return out
f = fx.symbolic_trace(f,
concrete_args={'x': {'a': fx.PH,
'b': fx.PH,
'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
torch.fx.wrap()
#
wrap()
函数可以在模块级范围内调用,将 fn_or_name
注册为“叶函数”。“叶函数”将被保留为 FX 跟踪中的 CallFunction 节点,而不是被跟踪:
# foo/bar/baz.py
def my_custom_function(x, y):
return x * x + y * y
torch.fx.wrap('my_custom_function')
def fn_to_be_traced(x, y):
# When symbolic tracing, the below call to my_custom_function will be inserted into
# the graph rather than tracing it.
return my_custom_function(x, y)
这个函数也可以等价地用作装饰器:
# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
return x * x + y * y