Conv2D Test

Define the frontend network

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.relu = torch.nn.ReLU()
        self.conv2d2 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.relu2 = torch.nn.ReLU()
        self.conv2d3 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.conv2d4 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)
    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.conv2d2(x)
        x = self.relu2(x)
        x = self.conv2d3(x)
        x = self.conv2d4(x)
        return x

input_shape = [1, 3, 32, 32]
torch_model = M().eval()
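
Running the eager PyTorch model once confirms the expected output shape before export (a quick sanity check; not required for the conversion below):

with torch.no_grad():
    out = torch_model(torch.randn(*input_shape, dtype=torch.float32))
print(out.shape)  # torch.Size([1, 16, 32, 32])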
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
from torch.export import export

# Give an example argument to torch.export
example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)

# Convert the model to IRModule
with torch.no_grad():
    exported_program = export(torch_model, example_args)
    run_mod = from_exported_program(exported_program, keep_params_as_input=False)
# run_mod, params = relax.frontend.detach_params(run_mod)
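
Printing the freshly imported module is an easy way to verify the import before any passes run (output omitted here):

run_mod.show()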
from tvm.ir import IRModule
from tvm.relax import transform
from tvm.relax.dpl.pattern import (
    DFPattern,
    wildcard,
    is_op,
)
from tvm.relax.transform import PatternCheckContext
from tvm.relax.backend.patterns import (
    make_conv2d_pattern, 
    # make_attention_pattern,
    # make_stacked_attention_pattern,
    # make_layer_norm_pattern,
    # make_rms_norm_pattern,
    # make_matmul_dequantize_pattern,
    # make_matmul_multiply_pattern,
    # make_attention_rewrite_pattern,
    # make_fused_bias_activation_pattern,
    # make_residual_block_pattern,
)
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix, register_patterns
from tvm.relax import expr as _expr
def _check_conv2d(context: PatternCheckContext) -> bool:
    # Accept every conv2d match; the commented lines below show how the
    # matched sub-expressions can be inspected via context.annotated_expr.
    # lhs = context.annotated_expr["lhs"]
    # rhs = context.annotated_expr["rhs"]
    # expr = context.annotated_expr["root"]
    # assert isinstance(lhs, relax.expr.Var) and lhs.name_hint == "data"
    # assert isinstance(rhs, relax.expr.Var) and rhs.name_hint == "weight1"
    # assert isinstance(expr, relax.expr.Call) and expr.op.name == "relax.nn.conv2d"
    # return False
    return True
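
Each entry registered below is a (name, pattern, annotation_map, check_function) tuple; make_conv2d_pattern returns the (pattern, annotation_map) pair, which is why its result is unpacked with *.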

register_patterns(
    [
        (
            "vta.conv2d",
            *make_conv2d_pattern(
                with_bias=False,
            ),
            _check_conv2d,
        ),
        (
            "vta.conv2d_bias",
            *make_conv2d_pattern(
                with_bias=True,
            ),
            _check_conv2d,
        ),
        (
            "vta.conv2d_bias_relu",
            *make_conv2d_pattern(
                with_bias=True,
                activation="relax.nn.relu",
            ),
            _check_conv2d,
        ),
        (
            "vta.conv2d_relu",
            *make_conv2d_pattern(
                with_bias=False,
                activation="relax.nn.relu",
            ),
            _check_conv2d,
        ),
        # (
        #     "vta.attention.BS3NH",
        #     *make_stacked_attention_pattern(start_op="split", layout="BS3NH"),
        #     partial(_check_stacked_attention, layout="BS3NH"),
        # ),
        # (
        #     "vta.attention.SBN3H",
        #     *make_stacked_attention_pattern(start_op="split", layout="SBN3H"),
        #     partial(_check_stacked_attention, layout="SBN3H"),
        # ),
    ]
)
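
The names registered above can be listed back to confirm the registration took effect (assuming the returned FusionPattern entries expose a name field, as recent TVM releases do):

for entry in get_patterns_with_prefix("vta"):
    print(entry.name)  # vta.conv2d, vta.conv2d_bias, vta.conv2d_bias_relu, vta.conv2d_relu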
patterns = get_patterns_with_prefix("vta")
seq = tvm.transform.Sequential(
    [
        # transform.LegalizeOps(enable_warning=enable_warning),
        # transform.RewriteDataflowReshape(),
        transform.AnnotateTIROpPattern(),
        transform.FoldConstant(),
        transform.FuseOps(),
        transform.FuseTIR(),
        transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=False),
    ]
)
mod = seq(run_mod)
# mod = MergeCompositeFunctions()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_add(lv: R.Tensor((1, 16, 32, 32), dtype="float32"), param_0: R.Tensor((16, 16, 3, 3), dtype="float32"), param_1: R.Tensor((1, 16, 1, 1), dtype="float32")) -> R.Tensor((1, 16, 32, 32), dtype="float32"):
        R.func_attr({"Composite": "vta.conv2d_bias", "Primitive": True})
        with R.dataflow():
            lv6: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv, param_0, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            gv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv6, param_1)
            R.output(gv)
        return gv

    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu(lv: R.Tensor((1, 16, 32, 32), dtype="float32"), param_0: R.Tensor((16, 16, 3, 3), dtype="float32"), param_1: R.Tensor((1, 16, 1, 1), dtype="float32")) -> R.Tensor((1, 16, 32, 32), dtype="float32"):
        R.func_attr({"Composite": "vta.conv2d_bias_relu", "Primitive": True})
        with R.dataflow():
            lv2: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv, param_0, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv4: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv2, param_1)
            gv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv4)
            R.output(gv)
        return gv

    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_nn_relu(x: R.Tensor((1, 3, 32, 32), dtype="float32"), param_0: R.Tensor((16, 3, 3, 3), dtype="float32")) -> R.Tensor((1, 16, 32, 32), dtype="float32"):
        R.func_attr({"Composite": "vta.conv2d_relu", "Primitive": True})
        with R.dataflow():
            lv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(x, param_0, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            gv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv)
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor((1, 3, 32, 32), dtype="float32")) -> R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu(x, metadata["relax.expr.Constant"][0])
            lv_1: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu(lv, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2])
            lv_2: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add(lv_1, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4])
            lv1: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add(lv_2, metadata["relax.expr.Constant"][5], metadata["relax.expr.Constant"][6])
            gv: R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")) = (lv1,)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
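
Each fused function above carries a Composite attribute naming the matched pattern. A minimal sketch to enumerate them from the module (assuming the attribute keys shown in the printed script):

for gv, func in mod.functions.items():
    composite = func.attrs.get("Composite") if func.attrs is not None else None
    if composite is not None:
        print(gv.name_hint, composite)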
def fvisit(expr: tvm.relax.Expr):
    # print(type(expr), expr)
    if isinstance(expr, _expr.GlobalVar):
        func = mod[expr]
        func = func.with_attr("op_pattern", 2)
        print(func.attrs, isinstance(func, tvm.tir.PrimExpr))
    # if isinstance(expr, _expr.Call):
    #     # print(expr.attrs)
    #     print(expr.op)
expr = mod["main"]
relax.analysis.post_order_visit(expr, fvisit)
{"Composite": "vta.conv2d_relu", "Primitive": True, "op_pattern": 2} False
{"Composite": "vta.conv2d_bias_relu", "Primitive": True, "op_pattern": 2} False
{"Composite": "vta.conv2d_bias", "Primitive": True, "op_pattern": 2} False
{"Composite": "vta.conv2d_bias", "Primitive": True, "op_pattern": 2} False
# op = tvm.ir.Op.get("relax.add")
# op.set_attr("op_pattern", OpPatternKind.kElemWise)