DefuseOps

The source file tvm/src/relay/transforms/defuse_ops.cc implements the Relay pass that undoes operator fusion (relay::transform::FuseOps), restoring an expression to its pre-fusion form, i.e. x == DefuseOps(FuseOps(x)).
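
To illustrate this round-trip property, here is a minimal sketch (the toy conv2d + relu network, its shapes, and fuse_opt_level=2 are assumptions made for this example, not taken from defuse_ops.cc); it should print True:

import tvm
from tvm import relay

# toy network: conv2d -> relu (shapes and channel counts are arbitrary)
x = relay.var("x", shape=(1, 8, 56, 56))
w = relay.var("w", shape=(8, 8, 3, 3))
y = relay.nn.relu(relay.nn.conv2d(x, w, padding=(1, 1), channels=8, kernel_size=(3, 3)))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x, w], y)))

fused = relay.transform.FuseOps(fuse_opt_level=2)(mod)  # conv2d and relu become one fused function
defused = relay.transform.InferType()(relay.transform.DefuseOps()(fused))

# the defused module should be structurally equal to the original one
print(tvm.ir.structural_equal(defused, mod))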

The pass is built around a class named DefuseOpsMutator, which inherits from ExprMutator; it contains a nested class FuncBodyMutator, which also inherits from ExprMutator.

Two functions do the work: the member function VisitExpr_(const CallNode* n) of DefuseOpsMutator, and the entry point DefuseOps(const Expr& expr).

VisitExpr_(const CallNode* n) takes a pointer to a CallNode and returns an Expr. It first calls the parent class ExprMutator's VisitExpr_ to rewrite n. If the result is still a CallNode and its operator is a FunctionNode (i.e. a fused function), it builds an unordered hash map name_to_args by walking the function's parameter list and mapping each parameter name to the corresponding call argument. It then constructs a FuncBodyMutator from name_to_args and calls its Mutate function on the function body, so that every parameter reference inside the body is replaced by the actual argument, and returns the resulting expression. Otherwise the rewritten call is returned unchanged.

DefuseOps(const Expr& expr) takes an Expr by reference and returns an Expr. It simply constructs a DefuseOpsMutator and calls its Mutate function on the input expression.
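
The substitution performed by FuncBodyMutator can be mimicked in a few lines of Python. The helper below (inline_call is a hypothetical name introduced here, not part of TVM) inlines a call whose operator is a relay.Function by binding the call arguments to the function parameters with relay.bind:

import tvm
from tvm import relay

def inline_call(call: relay.Call):
    """Rewrite `fn(p0, p1, ...){ body }(a0, a1, ...)` into `body[p_i := a_i]`."""
    assert isinstance(call.op, relay.Function)
    # map each function parameter to the corresponding call argument,
    # then substitute them inside the body (what FuncBodyMutator does in C++)
    binds = dict(zip(call.op.params, call.args))
    return relay.bind(call.op.body, binds)

With such a helper, the whole pass amounts to visiting every Call node bottom-up and inlining the ones whose operator is a Function, which is exactly what the Python mimic later in this section does.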

import numpy as np
from tvm import relay
import tvm
from tvm_book.tvm_utils.llvm_utils import run_llvm_graph
from tvm_book.tvm_utils.split_graph import graph_split
from tvm.relay.dataflow_pattern import is_op, wildcard

def make_conv_add_relu_pattern():
    """创建如下模式

     conv2d
        |
      (add)
        |
      (relu)
    """
    x = wildcard()
    w = wildcard()
    bias = wildcard()
    r = is_op("nn.conv2d")(x, w)
    r = is_op("add")(r, bias) | r
    # activation function (optional)
    r = r.optional(lambda x: is_op("nn.relu")(x))
    return r

def load_model(input_shape=[1, 3, 224, 224]):
    """加载前端模型"""
    import torch
    from torchvision.models import resnet18
    from torchvision.models.resnet import ResNet18_Weights
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    data = torch.randn(*input_shape)
    return torch.jit.trace(model.eval(), data)

size = 224, 224
input_shape = (1, 3, *size)
input_name = "data"
traced_model = load_model(input_shape).eval()
# translate the frontend model into a Relay module
origin_mod, origin_params = relay.frontend.from_pytorch(traced_model, [(input_name, input_shape)])
# extract a subgraph
split_conf = [{"op_name": "add", "op_index": 0}]
mod = graph_split(origin_mod["main"], split_conf)[0]
compiler_name = "ccompiler"
pattern_table = [
    (f"{compiler_name}.conv_add_relu", make_conv_add_relu_pattern()),
]
merge_passes = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.MergeComposite(pattern_table),
    # relay.transform.AnnotateTarget([compiler_name]),
    relay.transform.PartitionGraph(),
])
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        calibrate_mode="kl_divergence",
        weight_scale="max",
        skip_conv_layers=[],
        skip_dense_layer=False
    ):
        # pre-quantization preparation
        run_mod = relay.quantize.prerequisite_optimize(mod, origin_params)
        run_mod = merge_passes(run_mod)  # operator fusion (merge + partition)
print(run_mod["main"])
fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
  %5 = fn (%FunctionVar_2_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %FunctionVar_2_1: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %FunctionVar_2_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 112, 112), float32] {
    %3 = nn.conv2d(%FunctionVar_2_0, %FunctionVar_2_1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    %4 = add(%3, %FunctionVar_2_2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] */
  } /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(64, 3, 7, 7), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 112, 112), float32] */;
  %6 = %5(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %7 = nn.max_pool2d(%6, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %8 = fn (%FunctionVar_1_0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %FunctionVar_1_1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %FunctionVar_1_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 56, 56), float32] {
    %1 = nn.conv2d(%FunctionVar_1_0, %FunctionVar_1_1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
    %2 = add(%1, %FunctionVar_1_2) /* ty=Tensor[(1, 64, 56, 56), float32] */;
    nn.relu(%2) /* ty=Tensor[(1, 64, 56, 56), float32] */
  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 56, 56), float32] */;
  %9 = %8(%7, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %10 = fn (%FunctionVar_0_0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %FunctionVar_0_1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %FunctionVar_0_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 56, 56), float32] {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
    add(%0, %FunctionVar_0_2) /* ty=Tensor[(1, 64, 56, 56), float32] */
  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 56, 56), float32] */;
  %11 = %10(%9, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  add(%11, %7) /* ty=Tensor[(1, 64, 56, 56), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 64, 56, 56), float32] */

Calling tvm.relay.transform.DefuseOps directly:

relay.transform.DefuseOps()(run_mod)
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %7 = nn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %8 = add(%7, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  add(%8, %3) /* ty=Tensor[(1, 64, 56, 56), float32] */
}

To better understand what tvm.relay.transform.DefuseOps does, we can mimic its behavior in Python:

from dataclasses import dataclass, field
from tvm.relay import Call
from tvm.relay.function import Function


@tvm.relay.transform.function_pass(opt_level=1)
class DefuseTransform:
    """A Python mimic of the C++ DefuseOps pass."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.name_to_args_ = {}

    def transform_function(self, func, mod, ctx):
        @dataclass
        class FuncBodyMutator(tvm.relay.ExprMutator):
            """Mimics FuncBodyMutator: replaces every parameter reference in a
            fused function body with the argument passed at the call site."""
            name_to_args_: dict
            memo_map: dict = field(default_factory=dict)  # required by ExprMutator

            def visit_var(self, var):
                return self.name_to_args_[var.name_hint]

        class Replace(tvm.relay.ExprMutator):
            """Mimics DefuseOpsMutator::VisitExpr_(const CallNode*)."""

            def visit_call(self, call):
                # rewrite the operator and the arguments first (post-order)
                new_fn = self.visit(call.op)
                new_args = [self.visit(arg) for arg in call.args]
                call = Call(new_fn, new_args, call.attrs, call.type_args, call.span)
                if isinstance(call.op, Function):
                    # the call targets a fused relay.Function: map each
                    # parameter name to the actual argument ...
                    name_to_args = {}
                    for param, arg in zip(new_fn.params, new_args):
                        name_to_args[param.name_hint] = arg
                    # ... and inline the function body with parameters substituted
                    call = FuncBodyMutator(name_to_args).visit(new_fn.body)
                return call

        return Replace().visit(func)


transform = DefuseTransform()
_mod = transform(run_mod)
print(_mod)
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %7 = nn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %8 = add(%7, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  add(%8, %3) /* ty=Tensor[(1, 64, 56, 56), float32] */
}
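
As a sanity check (a sketch that assumes run_mod and _mod from the cells above are still in scope), the mimic's result can be compared against the built-in pass; the two main functions should be structurally equal:

# check that the Python mimic agrees with the built-in DefuseOps pass
builtin = relay.transform.InferType()(relay.transform.DefuseOps()(run_mod))
mimic = relay.transform.InferType()(_mod)
print(tvm.ir.structural_equal(builtin["main"], mimic["main"]))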

Related to span tracking, tvm.relay.transform.suffixes.tag_suffixes appends suffixes to span names so that multiple Relay expressions produced from the same frontend operator can be told apart; here it is applied to the unfused subgraph mod:

from tvm.relay.transform.suffixes import tag_suffixes

print(tag_suffixes(mod))
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) {
  %0 = nn.conv2d(%data, meta[relay.Constant][0], strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = nn.batch_norm(%0, meta[relay.Constant][1], meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4]) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %5 = nn.conv2d(%4, meta[relay.Constant][5], padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %6 = nn.batch_norm(%5, meta[relay.Constant][6], meta[relay.Constant][7], meta[relay.Constant][8], meta[relay.Constant][9]) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %7 = %6.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %8 = nn.relu(%7) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %9 = nn.conv2d(%8, meta[relay.Constant][10], padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %10 = nn.batch_norm(%9, meta[relay.Constant][11], meta[relay.Constant][12], meta[relay.Constant][13], meta[relay.Constant][14]) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %11 = %10.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
  add(%11, %4) /* ty=Tensor[(1, 64, 56, 56), float32] */
}