FX API 参考

FX API 参考#

symbolic_trace()#

给定 Module 或函数实例 root,此函数将返回 GraphModule,该 GraphModule 是通过记录跟踪 root 时看到的运算构造的。

concrete_args 允许您对函数进行部分专门化,无论是删除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于控制流的存在,FX 通常无法进行跟踪。但是,可以使用 concrete_args 专门化 b 的值来跟踪它。

from torch import fx

f = fx.symbolic_trace(f, 
                      concrete_args={"b": False}) 
assert f(3, False) == 6

注意,尽管您仍然可以传入 b 的不同值,但它们将被忽略。

还可以使用 concrete_args 从函数中消除数据结构处理。这将使用 pytrees 来平展您的输入。为了避免过度专门化,传入 fx.PH 值不应该特化。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, 
                      concrete_args={'x': {'a': fx.PH, 
                                           'b': fx.PH, 
                                           'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7

torch.fx.wrap()#

wrap() 函数可以在模块级范围内调用,将 fn_or_name 注册为“叶函数”。“叶函数”将被保留为 FX 跟踪中的 CallFunction 节点,而不是被跟踪:

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

这个函数也可以等价地用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y