DefuseOps
The source file tvm/src/relay/transforms/defuse_ops.cc implements a Relay pass that restores the result of operator fusion (relay::transform::FuseOps) to its pre-fusion form, i.e. x == DefuseOps(FuseOps(x)).
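As a quick sanity check of this round-trip property, a small hand-built module can be fused and then defused again. The following is a minimal sketch; the op choice, shapes, and fuse_opt_level are arbitrary:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8, 32, 32))
w = relay.var("w", shape=(8, 8, 3, 3))
y = relay.nn.relu(relay.nn.conv2d(x, w, padding=(1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
mod = relay.transform.InferType()(mod)
fused = relay.transform.FuseOps(fuse_opt_level=2)(mod)
defused = relay.transform.DefuseOps()(fused)
print(tvm.ir.structural_equal(mod["main"], defused["main"]))  # expected: True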
The core of the file is a class named DefuseOpsMutator, which inherits from ExprMutator. It contains a nested class, FuncBodyMutator, which also inherits from ExprMutator.
DefuseOpsMutator overrides the member function VisitExpr_(const CallNode* n); the file also defines the function DefuseOps(const Expr& expr), which drives the rewrite.
VisitExpr_(const CallNode* n) takes a pointer to a CallNode and returns an Expr. It first calls the parent class ExprMutator's VisitExpr_ to rewrite n. If the result is still a CallNode, it checks whether the callee is a FunctionNode, i.e. a fused function. If so, it builds an unordered hash map name_to_args from parameter names to the corresponding call arguments by walking the function's parameter list. Finally, it constructs a FuncBodyMutator from name_to_args and calls its Mutate function on the function body, returning the inlined expression; otherwise the rewritten call is returned unchanged.
DefuseOps(const Expr& expr) takes a reference to an Expr and returns an Expr. It simply creates a DefuseOpsMutator and calls its Mutate function on the input expression.
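On the Python side, the substitution that FuncBodyMutator performs (replacing each parameter variable in the fused function body with the matching call argument) can be illustrated with relay.bind. The following is a small sketch with made-up shapes and a hand-built "fused" call:

from tvm import relay

# A hand-built fused call: fn (%p0, %p1) { nn.relu(add(%p0, %p1)) } (%a, %b)
p0 = relay.var("p0", shape=(4,))
p1 = relay.var("p1", shape=(4,))
fn = relay.Function([p0, p1], relay.nn.relu(relay.add(p0, p1)))
a = relay.var("a", shape=(4,))
b = relay.var("b", shape=(4,))
call = relay.Call(fn, [a, b])

# Inline the function body by substituting parameters with the call arguments,
# which is the rewrite FuncBodyMutator applies to each fused function it meets.
inlined = relay.bind(call.op.body, dict(zip(call.op.params, call.args)))
print(inlined)  # nn.relu(add(%a, %b)), with %a and %b as free variables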
import numpy as np
from tvm import relay
import tvm
from tvm_book.tvm_utils.llvm_utils import run_llvm_graph
from tvm_book.tvm_utils.split_graph import graph_split
from tvm.relay.dataflow_pattern import is_op, wildcard
def make_conv_add_relu_pattern():
    """Create the following pattern:

        conv2d
          |
        (add)
          |
        (relu)
    """
    x = wildcard()
    w = wildcard()
    bias = wildcard()
    r = is_op("nn.conv2d")(x, w)
    r = is_op("add")(r, bias) | r
    # optional activation
    r = r.optional(lambda x: is_op("nn.relu")(x))
    return r
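As a quick sketch, the pattern can be checked against a hand-built conv2d -> add -> relu chain (the shapes below are arbitrary):

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(64, 3, 7, 7))
bias = relay.var("bias", shape=(64, 1, 1))
expr = relay.nn.relu(relay.add(relay.nn.conv2d(data, weight), bias))
print(make_conv_add_relu_pattern().match(expr))  # expected: True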
def load_model(input_shape=[1, 3, 224, 224]):
    """Load the frontend (PyTorch) model."""
    import torch
    from torchvision.models import resnet18
    from torchvision.models.resnet import ResNet18_Weights

    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    data = torch.randn(*input_shape)
    return torch.jit.trace(model.eval(), data)
size = 224, 224
input_shape = (1, 3, *size)
input_name = "data"
traced_model = load_model(input_shape).eval()
# Translate the frontend model into a Relay module
origin_mod, origin_params = relay.frontend.from_pytorch(traced_model, [(input_name, input_shape)])
# Extract the subgraph
split_conf = [{"op_name": "add", "op_index": 0}]
mod = graph_split(origin_mod["main"], split_conf)[0]
compiler_name = "ccompiler"
pattern_table = [
    (f"{compiler_name}.conv_add_relu", make_conv_add_relu_pattern()),
]
merge_passes = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.MergeComposite(pattern_table),
    # relay.transform.AnnotateTarget([compiler_name]),
    relay.transform.PartitionGraph(),
])
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        calibrate_mode="kl_divergence",
        weight_scale="max",
        skip_conv_layers=[],
        skip_dense_layer=False
    ):
        # Pre-quantization preparation
        run_mod = relay.quantize.prerequisite_optimize(mod, origin_params)
        run_mod = merge_passes(run_mod)  # merge composite patterns (operator "fusion")
        print(run_mod["main"])
fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
%5 = fn (%FunctionVar_2_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %FunctionVar_2_1: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %FunctionVar_2_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 112, 112), float32] {
%3 = nn.conv2d(%FunctionVar_2_0, %FunctionVar_2_1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%4 = add(%3, %FunctionVar_2_2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(64, 3, 7, 7), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 112, 112), float32] */;
%6 = %5(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%7 = nn.max_pool2d(%6, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%8 = fn (%FunctionVar_1_0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %FunctionVar_1_1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %FunctionVar_1_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 56, 56), float32] {
%1 = nn.conv2d(%FunctionVar_1_0, %FunctionVar_1_1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%2 = add(%1, %FunctionVar_1_2) /* ty=Tensor[(1, 64, 56, 56), float32] */;
nn.relu(%2) /* ty=Tensor[(1, 64, 56, 56), float32] */
} /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 56, 56), float32] */;
%9 = %8(%7, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%10 = fn (%FunctionVar_0_0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %FunctionVar_0_1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %FunctionVar_0_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 56, 56), float32] {
%0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
add(%0, %FunctionVar_0_2) /* ty=Tensor[(1, 64, 56, 56), float32] */
} /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 56, 56), float32] */;
%11 = %10(%9, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
add(%11, %7) /* ty=Tensor[(1, 64, 56, 56), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 64, 56, 56), float32] */
Calling tvm.relay.transform.DefuseOps directly:
relay.transform.DefuseOps()(run_mod)
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%4 = nn.conv2d(%3, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%7 = nn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%8 = add(%7, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
add(%8, %3) /* ty=Tensor[(1, 64, 56, 56), float32] */
}
To better understand what tvm.relay.transform.DefuseOps does, its behavior can be mimicked in Python:
from dataclasses import dataclass, field
from tvm.relay import Call
from tvm.relay.function import Function
@tvm.relay.transform.function_pass(opt_level=1)
class DefuseTransform:
    def __init__(self):
        self.reset()

    def reset(self):
        self.name_to_args_ = {}

    def transform_function(self, func, mod, ctx):
        @dataclass
        class FuncBodyMutator(tvm.relay.ExprMutator):
            """Replace each parameter variable in a fused function body
            with the corresponding call argument."""
            name_to_args_: dict
            memo_map: dict = field(default_factory=dict)

            def visit_var(self, var):
                return self.name_to_args_[var.name_hint]

        class Replace(tvm.relay.ExprMutator):
            def visit_call(self, call):
                new_fn = self.visit(call.op)
                new_args = [self.visit(arg) for arg in call.args]
                call = Call(new_fn, new_args, call.attrs, call.type_args, call.span)
                if isinstance(call.op, Function):
                    # The callee is a fused function: map parameter names to
                    # call arguments and inline its body.
                    name_to_args = {}
                    for param, arg in zip(new_fn.params, new_args):
                        name_to_args[param.name_hint] = arg
                    call = FuncBodyMutator(name_to_args).visit(new_fn.body)
                return call

        return Replace().visit(func)
transform = DefuseTransform()
_mod = transform(run_mod)
print(_mod)
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%3 = nn.max_pool2d(%2, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%4 = nn.conv2d(%3, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%7 = nn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%8 = add(%7, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
add(%8, %3) /* ty=Tensor[(1, 64, 56, 56), float32] */
}
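The mock is expected to produce the same graph as the built-in pass; a quick structural comparison (a sketch, hedged since span and type details could in principle differ):

official_mod = relay.transform.DefuseOps()(run_mod)
print(tvm.ir.structural_equal(_mod["main"], official_mod["main"]))  # expected: True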
For comparison, the original subgraph mod (before MergeComposite and PartitionGraph) can also be printed; here tag_suffixes is applied first, which tags expression spans with unique suffixes:
from tvm.relay.transform.suffixes import tag_suffixes
print(tag_suffixes(mod))
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) {
%0 = nn.conv2d(%data, meta[relay.Constant][0], strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%1 = nn.batch_norm(%0, meta[relay.Constant][1], meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4]) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
%2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] */;
%3 = nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%5 = nn.conv2d(%4, meta[relay.Constant][5], padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%6 = nn.batch_norm(%5, meta[relay.Constant][6], meta[relay.Constant][7], meta[relay.Constant][8], meta[relay.Constant][9]) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
%7 = %6.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
%8 = nn.relu(%7) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%9 = nn.conv2d(%8, meta[relay.Constant][10], padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%10 = nn.batch_norm(%9, meta[relay.Constant][11], meta[relay.Constant][12], meta[relay.Constant][13], meta[relay.Constant][14]) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
%11 = %10.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
add(%11, %4) /* ty=Tensor[(1, 64, 56, 56), float32] */
}