Rewriting subgraphs with replace_pattern()

For simple transformations that consist only of substitutions, you can also use subgraph_rewriter.

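FX provides another level of automation on top of direct Graph manipulation. The replace_pattern() API is essentially a "find/replace" tool for editing a Graph: you specify a pattern and a replacement, it traces both functions, finds instances of the group of operations from the pattern graph in the target graph, and replaces each instance with a copy of the replacement graph. This greatly automates tedious graph-manipulation code, which tends to become unwieldy as transformations grow more complex.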

GraphModulegm)的 graph 中匹配所有可能不重叠的算子集及其数据依赖关系(pattern),然后用另一个子图替换每个匹配的子图(replacement)。

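The return value is a list of Match objects representing the places in the original graph where pattern was matched. The list is empty if there are no matches. Match is defined as: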

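from typing import Dict, NamedTuple
from torch.fx import Node

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]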
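To see what the returned Match objects contain, the sketch below (with an illustrative function f) rewrites two independent neg calls and prints, for each match, its anchor and the graph node that each pattern placeholder was bound to. Only nodes_map is relied on in the check:

```python
import torch
from torch import fx

def f(x, y):
    return torch.neg(x) + torch.neg(y)

def pattern(a):
    return torch.neg(a)

def replacement(a):
    return torch.relu(a)

gm = fx.symbolic_trace(f)
matches = fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)

for m in matches:
    # nodes_map: pattern Node -> Node of the original graph
    bound = {p.name: g.name for p, g in m.nodes_map.items() if p.op == "placeholder"}
    print(m.anchor, bound)
```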

Note

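The return statement in pattern is matched based on its value only; it may or may not match the return statement in the larger graph. In other words, the pattern does not have to extend to the end of the larger graph.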

For example:

from torch import nn, fx
import torch

class M(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = fx.symbolic_trace(M())
fx.subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
[Match(anchor=max_1, nodes_map={output: max_1, sum_1: sum_1, cat: cat, w1: w1, w2: w2}),
 Match(anchor=max_2, nodes_map={output: max_2, sum_1: sum_2, cat: cat_1, w1: w1, w2: w2})]

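The code above first matches pattern in the forward method of traced_module. Matching is based on use-def relationships, not on node names: for example, if pattern contained p = torch.cat([a, b]), it could match m = torch.cat([a, b]) in the original forward function even though the variable names differ (p vs m).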

traced_module.graph.lint()
print(traced_module.code)
def forward(self, x, w1, w2):
    stack = torch.stack([w1, w2])
    max_1 = torch.max(stack);  stack = None
    add = x + max_1;  x = max_1 = None
    stack_1 = torch.stack([w1, w2]);  w1 = w2 = None
    max_2 = torch.max(stack_1);  stack_1 = None
    add_1 = add + max_2;  add = max_2 = None
    return add_1
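The claim that matching follows use-def relationships rather than names can be checked directly. In this sketch (f, pattern and replacement are illustrative), the pattern's placeholders p and q are bound to the graph's a and b purely by their position in the data flow:

```python
import torch
from torch import fx

def f(a, b):
    m = torch.sub(a, b)   # graph-side names: a, b, m
    return m * 3.0

def pattern(p, q):        # pattern-side names: p, q
    return torch.sub(p, q)

def replacement(p, q):    # swap the operands
    return torch.sub(q, p)

gm = fx.symbolic_trace(f)
matches = fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)

# Pattern placeholders are bound by their role in the data flow, not by name
binding = {pn.name: gn.name for pn, gn in matches[0].nodes_map.items()
           if pn.op == "placeholder"}
print(binding)
```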
    

The following sections walk through several use cases.

First, define a helper that applies the rewrite and checks the result against a reference function:

def eval_result(traced, pattern, replacement, comparison, x):
    comparison_fn = fx.symbolic_trace(comparison)
    fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
    traced.graph.lint()
    ref_output = comparison_fn(x)
    test_output = traced.forward(x)
    torch.testing.assert_close(ref_output, test_output)

Preserving the underlying logic

Replacing pattern with the same pattern should not change the underlying logic:

class M(nn.Module):
    def forward(self, x):
        val = torch.neg(x) + torch.relu(x)
        return torch.add(val, val)

def pattern(x):
    return torch.neg(x) + torch.relu(x)

def comparison(x):
    val = torch.neg(x) + torch.relu(x)
    return torch.add(val, val)


x = torch.rand(1, 3)
traced = fx.symbolic_trace(M())
eval_result(traced, pattern, pattern, comparison, x)

Replacing a single node

Replace the single neg node with relu:

class M(nn.Module):
    def forward(self, x):
        val = torch.neg(x)
        return torch.add(val, val)

def pattern(x):
    return torch.neg(x)

def replacement(x):
    return torch.relu(x)

def comparison(x):
    val = torch.relu(x)
    return torch.add(val, val)


x = torch.rand(1, 3)
traced = fx.symbolic_trace(M())
eval_result(traced, pattern, replacement, comparison, x)

Removing nodes

pattern 被匹配时,它将从更大的函数中删除,并被 replacement 替换。

class M(nn.Module):
    def forward(self, x):
        val = torch.neg(x) + torch.relu(x)
        return torch.add(val, val)

def pattern(x):
    return torch.neg(x) + torch.relu(x)

def replacement(x):
    return torch.relu(x)

def comparison(x):
    val = torch.relu(x)
    return torch.add(val, val)
    
x = torch.rand(1, 3)
traced = fx.symbolic_trace(M())
eval_result(traced, pattern, replacement, comparison, x)
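One way to confirm that nodes were actually removed is to count call_function nodes before and after the rewrite. A small sketch, where count_calls is a hypothetical helper and f mirrors M.forward without the final add:

```python
import torch
from torch import fx

def f(x):
    return torch.neg(x) + torch.relu(x)

def pattern(x):
    return torch.neg(x) + torch.relu(x)

def replacement(x):
    return torch.relu(x)

def count_calls(gm):
    # number of call_function nodes (neg, relu, add, ...) in the graph
    return sum(1 for n in gm.graph.nodes if n.op == "call_function")

gm = fx.symbolic_trace(f)
before = count_calls(gm)   # neg, relu, add
fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
after = count_calls(gm)    # only the relu from replacement remains
print(before, after)
```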

Multiple pattern matches

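If there are multiple matches of pattern in the larger function, each non-overlapping match is replaced. When matches overlap, the first match found in the set of overlapping matches is replaced. ("First" here is defined as the first in a topological ordering of the Nodes' use-def relationships. In most cases, the first Node is the parameter that appears directly after self, while the last Node is whatever the function returns.)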

class M(nn.Module):
    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

def comparison(x, w1, w2):
    m1 = torch.stack([w1, w2])
    m2 = torch.stack([w1, w2])
    return x + torch.max(m1) + torch.max(m2)

traced = fx.symbolic_trace(M())
comparison_fn = fx.symbolic_trace(comparison)

x = torch.rand(1, 3)
w1 = torch.rand(1, 3)
w2 = torch.rand(1, 3)

fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint()

ref_outs = comparison_fn(x, w1, w2)
test_outs = traced.forward(x, w1, w2)
torch.testing.assert_close(ref_outs, test_outs)

Note

One important thing to note is that the parameters of the pattern Callable must be used in the Callable itself, and the parameters of the replacement Callable must match the pattern.

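The first rule is why, in the code block above, the forward function has parameters x, w1, w2 while the pattern function only has parameters w1, w2: pattern does not use x, so it should not declare x as a parameter.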

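As an example of the second rule, consider using

def replacement(x, y):
    return torch.relu(x)

to replace

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

In this case, replacement needs the same number of parameters as pattern (both x and y), even though the parameter y is not used in replacement.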
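A sketch of the second rule in action, using the pattern and replacement just shown: replacement declares y without using it. The function f is illustrative; it also uses y after the matched subgraph, which keeps the example independent of whether a given PyTorch version prunes unused inputs:

```python
import torch
from torch import fx

def f(x, y):
    # y is also used outside the matched subgraph (the final multiply)
    return (torch.neg(x) + torch.relu(y)) * y

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

def replacement(x, y):
    # y is unused here but still declared, so the arity matches pattern
    return torch.relu(x)

gm = fx.symbolic_trace(f)
matches = fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(gm.code)
```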

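Parameters are identified by data flow: in forward, x feeds relu and y feeds neg, so only relu(x) is rewritten: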

class M(nn.Module):
    def forward(self, x, y):
        val = torch.neg(y) + torch.relu(x)
        return torch.add(val, val)

def pattern(x):
    return torch.relu(x)

def replacement(x):
    return torch.neg(x)

def comparison(x, y):
    val = torch.neg(y) + torch.neg(x)
    return torch.add(val, val)

traced = fx.symbolic_trace(M())
comparison_fn = fx.symbolic_trace(comparison)

x = torch.randn(4, 4)
y = torch.randn(4, 4)

fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()

ref_outs = comparison_fn(x, y)
test_outs = traced.forward(x, y)
torch.testing.assert_close(ref_outs, test_outs)

Using traced modules as pattern and replacement

class M(nn.Module):
    def forward(self, x):
        val = torch.neg(x) + torch.relu(x)
        return torch.add(val, val)

class Pattern(nn.Module):
    def forward(self, x):
        return torch.neg(x) + torch.relu(x)

class Replacement(nn.Module):
    def forward(self, x):
        return torch.sigmoid(x)

def comparison(x):
    val = torch.sigmoid(x)
    return torch.add(val, val)

traced = fx.symbolic_trace(M())
traced_pattern = fx.symbolic_trace(Pattern())
traced_replacement = fx.symbolic_trace(Replacement())
comparison_fn = fx.symbolic_trace(comparison)

x = torch.randn(3, 4)

fx.subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement)

traced.graph.lint()

ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
torch.testing.assert_close(ref_outs, test_outs)

Replacing the entire graph

class M(torch.nn.Module):
    def forward(self, x):
        a = torch.neg(x)
        return torch.add(a, a)

def pattern(x):
    a = torch.neg(x)
    return torch.add(a, a)

def replacement(x):
    a = torch.sigmoid(x)
    return torch.cat([a, a])

traced = fx.symbolic_trace(M())
comparison_fn = fx.symbolic_trace(replacement)

x = torch.randn(3, 4)

fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint()

ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
torch.testing.assert_close(ref_outs, test_outs)

The node that produces the pattern's output may have users that are not part of the match:

class M(nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return torch.neg(y) - y

def pattern(x):
    return torch.relu(x)

def replacement(x):
    return torch.sigmoid(x)

def comparison(x):
    y = torch.sigmoid(x)
    return torch.neg(y) - y

traced = fx.symbolic_trace(M())
comparison_fn = fx.symbolic_trace(comparison)

x = torch.randn(3, 4)

fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint()

ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
torch.testing.assert_close(ref_outs, test_outs)

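A case with no match: internal (non-output) pattern nodes cannot have users outside the matched subgraph. Here m1 and m2 each feed both addmm calls, so nothing is replaced: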

class M(nn.Module):
    def forward(self, x, w1, w2, b1, b2):
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        t0 = torch.addmm(b1, m1, m2.t())
        t1 = torch.sum(w1, 1)
        t2 = torch.addmm(b1, m1, m2.t())
        return torch.sum(t1), torch.sum(t2)

def pattern(x, w1, w2, b1, b2):
    m1 = torch.cat([w1, w2])
    m2 = torch.cat([x, b2])
    return torch.addmm(b1, m1, m2.t())

def replacement(x, w1, w2, b1, b2):
    return torch.cat([x, w1, w2])

traced = fx.symbolic_trace(M())

# Result should be [] since no matches can be found
res = fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
res
[]
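The same rule in a smaller sketch (f is illustrative): neg is an internal node of the pattern relu(neg(x)), but in f its result is also consumed by the final add, so the match would leak out of the subgraph and replace_pattern returns an empty list:

```python
import torch
from torch import fx

def f(x):
    a = torch.neg(x)
    b = torch.relu(a)
    return b + a          # a is also used outside relu(neg(x))

def pattern(x):
    return torch.relu(torch.neg(x))

def replacement(x):
    return torch.sigmoid(x)

gm = fx.symbolic_trace(f)
matches = fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(matches)  # the graph is left unchanged
```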

Matching placeholders

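This tests that a placeholder node in the pattern can be matched to a node with a different number of input nodes. Below, the pattern's placeholder x is matched to the add node produced by x += 3.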

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.dtype = torch.float16

    def forward(self, x):
        x += 3
        x = x.dequantize()
        x = torch.sigmoid(x)
        dtype = self.dtype
        x = x.to(dtype)
        return x

def pattern(x):
    x = x.dequantize()
    x = torch.sigmoid(x)
    x = x.to(torch.float16)
    return x

def replacement(x):
    return x

def comparison(x):
    return x + 3

The original traced module looks like this:

traced = fx.symbolic_trace(M())
traced.graph.print_tabular()
opcode         name        target                                                      args                      kwargs
-------------  ----------  ----------------------------------------------------------  ------------------------  --------
placeholder    x           x                                                           ()                        {}
call_function  add         <built-in function add>                                     (x, 3)                    {}
call_method    dequantize  dequantize                                                  (add,)                    {}
call_function  sigmoid     <built-in method sigmoid of type object at 0x7f25e4da7200>  (dequantize,)             {}
call_method    to          to                                                          (sigmoid, torch.float16)  {}
output         output      output                                                      (to,)                     {}

The comparison function, i.e. the graph expected after the rewrite, looks like this:

comparison_fn = fx.symbolic_trace(comparison)
comparison_fn.graph.print_tabular()
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
call_function  add     <built-in function add>  (x, 3)  {}
output         output  output                   (add,)  {}
fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()

x = torch.randn(3, 4)
ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
torch.testing.assert_close(ref_outs, test_outs)

Replacing referenced submodules

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.submod = nn.ReLU()

    def forward(self, x):
        x = x + 1
        return self.submod(self.sigmoid(x))

class Pattern(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.submod = nn.ReLU()

    def forward(self, x):
        return self.submod(self.sigmoid(x))

class Replacement(nn.Module):
    def __init__(self):
        super().__init__()
        self.tanh = nn.Tanh()
        self.submod = nn.ReLU()

    def forward(self, x):
        return self.submod(self.tanh(x))

class Comparison(nn.Module):
    def __init__(self):
        super().__init__()
        self.tanh = nn.Tanh()
        self.submod = nn.ReLU()

    def forward(self, x):
        x = x + 1
        return self.submod(self.tanh(x))

traced = fx.symbolic_trace(M())
comparison = Comparison()

x = torch.randn(3, 4)

fx.subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement())

traced.graph.lint()

ref_outs = comparison(x)
test_outs = traced.forward(x)
torch.testing.assert_close(ref_outs, test_outs)
traced.get_submodule("tanh")
Tanh()
traced.get_submodule("sigmoid")
AttributeError: M has no attribute `sigmoid`
submod = traced.get_submodule("submod")
submod
ReLU()

Annotating integers

from torch.fx.annotate import annotate
from torch.fx.experimental.rewriter import RewritingTracer

class M1(nn.Module):
    def forward(self, x):
        y: int = x
        return torch.add(x, y)

class M2(nn.Module):
    def forward(self, x):
        y = annotate(x, int)
        return torch.add(x, y)

ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(M1())

module = M2()
symbolic_traced = fx.symbolic_trace(module)
for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
    if n.op == 'placeholder':
        assert n.type == int
        assert m.type == int

Rewriting consecutive operations

def f(x):
    x = torch.sigmoid(x)
    x = torch.sigmoid(x)
    return torch.sigmoid(x)

def pattern(x):
    return torch.sigmoid(x)

def replacement(x):
    return torch.exp(x)

def comparison(x):
    x = torch.exp(x)
    x = torch.exp(x)
    return torch.exp(x)

traced = fx.symbolic_trace(f)
comparison_fn = fx.symbolic_trace(comparison)

x = torch.randn(3, 4)

fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint()

ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
torch.testing.assert_close(ref_outs, test_outs)

Overlapping matches

def f(x):
    x = torch.sigmoid(x)
    x = torch.sigmoid(x)
    x = torch.sigmoid(x)
    return torch.sigmoid(x)

def pattern(x):
    x = torch.sigmoid(x)
    x = torch.sigmoid(x)
    return x

def replacement(x):
    return torch.neg(x)

def comparison(x):
    x = torch.neg(x)
    return torch.neg(x)

traced = fx.symbolic_trace(f)
comparison_fn = fx.symbolic_trace(comparison)

x = torch.randn(3, 4)

fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint()

ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
torch.testing.assert_close(ref_outs, test_outs)
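With an odd number of sigmoids the candidates cannot all be disjoint. In this sketch of three chained calls, the two candidate matches share the middle sigmoid, so only one is replaced and one sigmoid survives:

```python
import torch
from torch import fx

def f(x):
    x = torch.sigmoid(x)
    x = torch.sigmoid(x)
    return torch.sigmoid(x)

def pattern(x):
    return torch.sigmoid(torch.sigmoid(x))

def replacement(x):
    return torch.neg(x)

gm = fx.symbolic_trace(f)
matches = fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)

# Remaining call_function targets: one neg from the replacement, one untouched sigmoid
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(len(matches))
print(gm.code)
```

Per the ordering rule in the multiple-matches section, the first candidate in topological order is expected to be the one kept; print(gm.code) shows which one your version chose.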

Removing unused args

class M(nn.Module):
    def forward(self, x, y, z):
        return x + y

def pattern(x, y):
    return x + y

def replacement(x, y):
    return x - y

def comparison(x1, x2, x3):
    return x1 - x2

traced = fx.symbolic_trace(M())
comparison_fn = fx.symbolic_trace(comparison)

fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint()

placeholder_nodes = [n for n in traced.graph.nodes if n.op == "placeholder"]
placeholder_nodes
[x, y]
x1 = torch.randn(3, 4)
x2 = torch.randn(3, 4)
x3 = torch.randn(3, 4)
ref_outs = comparison_fn(x1, x2, x3)
test_outs = traced.forward(x1, x2)
torch.testing.assert_close(ref_outs, test_outs)

Rewriting method calls (call_method)

class M(nn.Module):
    def forward(self, x):
        x = x.dequantize()
        x = x.sigmoid()
        x = x.to(torch.float16)
        return x

def pattern(x):
    x = x.dequantize()
    x = x.sigmoid()
    x = x.to(torch.float16)
    return x

def replacement(x):
    return x

traced = fx.symbolic_trace(M())
comparison_fn = fx.symbolic_trace(replacement)
fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()

x1 = torch.randn(3, 4)
ref_outs = comparison_fn(x1)
test_outs = traced.forward(x1)
torch.testing.assert_close(ref_outs, test_outs)
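A related detail: the method form x.sigmoid() is traced as a call_method node, while torch.sigmoid(x) becomes a call_function node, and the matcher compares both the op kind and the target. The sketch below (illustrative f) shows a method-form pattern rewriting only the method call:

```python
import torch
from torch import fx

def f(x):
    a = x.sigmoid()        # traced as call_method with target "sigmoid"
    b = torch.sigmoid(x)   # traced as call_function with target torch.sigmoid
    return a + b

def pattern(x):
    return x.sigmoid()

def replacement(x):
    return x.tanh()

gm = fx.symbolic_trace(f)
matches = fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)

ops = [(n.op, n.target) for n in gm.graph.nodes
       if n.op in ("call_method", "call_function")]
print(len(matches), ops)
```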

Rewriting subgraphs that use kwargs

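The wrapped functions must be defined at module level (fx.wrap can only be called at the top level of a module), so they go in a separate file: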

# custom_rewriter.py
from torch import fx, nn

@fx.wrap
def wrapped_gemm_bias_mul(a, b, bias):
    lin_res = nn.functional.linear(a, b, bias=bias)
    mul_res = lin_res * a
    return lin_res, mul_res

@fx.wrap
def wrapped_gemm_bias_mul_with_c(a, b, bias, c):
    lin_res = nn.functional.linear(a, b, bias=bias)
    mul_res = lin_res * c
    return lin_res, mul_res

# In the notebook / main script:
from custom_rewriter import wrapped_gemm_bias_mul, wrapped_gemm_bias_mul_with_c

class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w0 = nn.Parameter(torch.empty([128, 128]))
        self.b0 = nn.Parameter(torch.empty([128]))

    def forward(self, in0):
        lin_res = nn.functional.linear(in0, self.w0, bias=self.b0)
        mul_res = in0 * lin_res
        sum_res = mul_res + in0
        return sum_res

def pattern(a, b, bias):
    lin_res = nn.functional.linear(a, b, bias=bias)
    mul_res = a * lin_res
    return (lin_res, mul_res)

def replacement(a, b, bias):
    lin_res, mul_res = wrapped_gemm_bias_mul(a, b, bias)
    return (lin_res, mul_res)

traced = fx.symbolic_trace(M())
matches = fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
len(matches)
found_replacement_node = False
for node in traced.graph.nodes:
    if node.target == wrapped_gemm_bias_mul:
        found_replacement_node = True
        break

found_replacement_node

Reverting local changes during matching

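The following model has 3 anchors as match candidates: anchors 1 and 3 are real matches, but anchor 2 is not. The subgraph rewriter should be able to revert the changes made while trying to match anchor 2, and the final match at anchor 3 should succeed.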

# The following model has 3 anchors as matching candidates for the given pattern.
# Anchors 1 and 3 are real matches, but anchor 2 is not.
# The subgraph rewriter should be able to revert the changes made while matching anchor 2.
# The final match at anchor 3 should succeed.
class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w0 = nn.Parameter(torch.empty([128, 128]))
        self.b0 = nn.Parameter(torch.empty([128]))
        self.w1 = nn.Parameter(torch.empty([128, 128]))
        self.b1 = nn.Parameter(torch.empty([128]))
        self.w2 = nn.Parameter(torch.empty([128, 128]))
        self.b2 = nn.Parameter(torch.empty([128]))
        self.w3 = nn.Parameter(torch.empty([128, 128]))
        self.b3 = nn.Parameter(torch.empty([128]))
        self.w4 = nn.Parameter(torch.empty([128, 128]))
        self.b4 = nn.Parameter(torch.empty([128]))

    def forward(self, in0, in1):
        lin_res_1 = nn.functional.linear(in1, self.w0, bias=self.b0)
        lin_res_2 = nn.functional.linear(lin_res_1, self.w1, bias=self.b1)
        # potential match at anchor 1
        mul_res_1 = in1 * lin_res_2
        sum_res_1 = mul_res_1 + in1
        lin_res_3 = nn.functional.linear(
            sum_res_1, self.w2, bias=self.b2
        )
        sigmoid_res_1 = torch.sigmoid(lin_res_3)
        # potential match at anchor 2
        mul_res_2 = lin_res_3 * sigmoid_res_1
        lin_res_4 = nn.functional.linear(in0, self.w3, bias=self.b3)
        lin_res_5 = nn.functional.linear(lin_res_4, self.w4, bias=self.b4)
        # potential match at anchor 3
        mul_res_3 = in0 * lin_res_5
        sum_res_2 = mul_res_3 + in0
        cat_res = torch.cat(
            [mul_res_2, sum_res_2],
            dim=1,
        )
        return cat_res

def gemm_bias_mul_pattern_with_c(a, b, bias, c):
    lin_res = nn.functional.linear(a, b, bias=bias)
    mul_res = c * lin_res
    return lin_res, mul_res

def gemm_bias_mul_replacement_with_c(a, b, bias, c):
    lin_res, mul_res = wrapped_gemm_bias_mul_with_c(a, b, bias, c)
    return lin_res, mul_res

traced = fx.symbolic_trace(M())
matches = fx.subgraph_rewriter.replace_pattern(
    traced,
    gemm_bias_mul_pattern_with_c,
    gemm_bias_mul_replacement_with_c)

len(matches)
replacement_nodes_found = 0
for node in traced.graph.nodes:
    if node.target == wrapped_gemm_bias_mul_with_c:
        replacement_nodes_found += 1

replacement_nodes_found

Rewriting subgraphs with match filters

class M(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, scale, zero_point):
        # Match, second input to add is a scalar
        x = x.dequantize()
        x = torch.add(x, 2)
        x = x.relu()
        x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

        y = x + 1
        # NOT a match, second input to add is NOT a scalar
        x = x.dequantize()
        x = torch.add(x, y)
        x = x.relu()
        x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

        return x

def BinaryOpScalarReLUPattern(x, num, scale, zero_point):
    x = x.dequantize()
    x = torch.add(x, num)
    x = x.relu()
    x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    return x

def BinaryOpScalarReLUReplacement(x, num, scale, zero_point):
    x = torch.mul(x, num)
    return x

def second_input_is_scalar(match, original_graph, pattern_graph):
    """ check the node that's matched to the second input of the pattern graph
    is a scalar number
    """
    input_idx = 0
    for node in pattern_graph.nodes:
        if node.op == "placeholder":
            if input_idx == 1:
                num_node = node
            input_idx += 1
    if not isinstance(match.nodes_map[num_node], (int, float)):
        return False
    return True

def num_replacement_nodes_found(traced):
    return sum(1 for node in traced.graph.nodes if node.target == torch.mul)

# Match without a filter: both blocks fit the pattern, so 2 matches are expected,
# although the PyTorch version used for this output found only 1
traced = fx.symbolic_trace(M())
matches = fx.subgraph_rewriter.replace_pattern(
    traced,
    BinaryOpScalarReLUPattern,
    BinaryOpScalarReLUReplacement)
len(matches)
1
num_replacement_nodes_found(traced)
# Match with a filter: only the block whose second add input is a scalar should match
traced = fx.symbolic_trace(M())
matches = fx.subgraph_rewriter.replace_pattern_with_filters(
    traced,
    BinaryOpScalarReLUPattern,
    BinaryOpScalarReLUReplacement,
    [second_input_is_scalar])

len(matches), num_replacement_nodes_found(traced)
AttributeError: module 'torch.fx.subgraph_rewriter' has no attribute 'replace_pattern_with_filters'

The PyTorch version used to produce this output does not provide replace_pattern_with_filters; the function is only available in newer releases.
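For reference, a sketch of a match filter with replace_pattern_with_filters, assuming a newer PyTorch release that provides it. The filter input_is_placeholder and the function f are hypothetical; the filter uses the same (match, original_graph, pattern_graph) signature as second_input_is_scalar and keeps only relu calls applied directly to a graph input:

```python
import torch
from torch import fx
# Assumption: requires a PyTorch version that provides replace_pattern_with_filters
from torch.fx.subgraph_rewriter import replace_pattern_with_filters

def f(x):
    a = torch.relu(x)          # relu applied directly to an input
    b = torch.relu(a + 1.0)    # relu applied to an intermediate result
    return a + b

def pattern(x):
    return torch.relu(x)

def replacement(x):
    return torch.sigmoid(x)

def input_is_placeholder(match, original_graph, pattern_graph):
    # Accept the match only if the pattern's placeholder is bound to a graph input
    for pn, gn in match.nodes_map.items():
        if pn.op == "placeholder":
            return isinstance(gn, fx.Node) and gn.op == "placeholder"
    return False

gm = fx.symbolic_trace(f)
replaced = replace_pattern_with_filters(
    gm, pattern, replacement, match_filters=[input_is_placeholder])
print(len(replaced))
print(gm.code)
```

Without the filter both relu calls would be rewritten; the filter rejects the second candidate because its placeholder is bound to the add node.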

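Matching patterns with list-type arguments

Here the pattern's placeholders arg0 and arg1 are meant to match the list arguments [1, 2] and [3, 4], and the replacement swaps them. In the PyTorch version used to produce the output below, no match was found, so the graph is left unchanged.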

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten._reshape_alias_copy.default(x, [1, 2], [3, 4])

def pattern(x, arg0, arg1):
    return torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)

def replacement(x, arg0, arg1):
    return torch.ops.aten._reshape_alias_copy.default(x, arg1, arg0)

traced = fx.symbolic_trace(M())
matches = fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
len(matches)
0
print(traced.code.strip())
def forward(self, x):
    _reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(x, [1, 2], [3, 4]);  x = None
    return _reshape_alias_copy_default