PyTorch resnet18 Relax

import torch
from torchvision import models
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

fold_pipeline = tvm.transform.Sequential([
    # Fold BatchNorm statistics directly into the preceding conv2d weights
    relax.transform.FoldBatchnormToConv2D(),
    # Pre-compute constant subexpressions
    relax.transform.FoldConstant(),
    relax.transform.RemoveRedundantReshape(),
    # Decompose any remaining BatchNorm into simpler ops for fusion
    relax.transform.DecomposeOpsForInference(),
    # Canonicalize bindings
    relax.transform.CanonicalizeBindings(),
])
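
FoldBatchnormToConv2D rewrites conv2d + batch_norm into a single conv2d by rescaling the weights. A minimal numpy sketch of the underlying arithmetic (illustrative names, not a TVM API):

import numpy as np

def fold_bn_into_conv(weight, gamma, beta, mean, var, eps=1e-5):
    # BN(conv(x, W)) == conv(x, W') + b', with a per-output-channel rescale
    scale = gamma / np.sqrt(var + eps)
    folded_weight = weight * scale[:, None, None, None]  # weight in OIHW layout
    folded_bias = beta - mean * scale
    return folded_weight, folded_bias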

# Create the pretrained PyTorch model in eval mode
torch_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()
shape = [1, 3, 224, 224]
input_info = [(shape, "float32")]
# Trace with torch.fx and convert the graph to a Relax module
with torch.no_grad():
    graph_model = torch.fx.symbolic_trace(torch_model)
    mod = from_fx(graph_model, input_info)
# Apply the initial folding/simplification pipeline
run_mod = fold_pipeline(mod)
run_mod.show()
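
A quick sanity check (a sketch): the folded module should still be well formed, and its entry function main should take the single (1, 3, 224, 224) float32 input:

assert relax.analysis.well_formed(run_mod)
print(run_mod["main"].params)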
from tvm.relax.dpl.pattern import wildcard, is_op, make_fused_bias_activation_pattern

def make_fused_op_bias_activation_pattern(op_name="relax.nn.conv2d", activation="relax.nn.relu"):
    """Match the op+bias+activation, op+activation, and op+bias variants of ``op_name``."""
    op_bias_relu_pat = make_fused_bias_activation_pattern(
        op_name,
        with_bias=True,
        activation=activation
    )
    op_relu_pat = make_fused_bias_activation_pattern(
        op_name,
        with_bias=False,
        activation=activation
    )
    op_bias_pat = make_fused_bias_activation_pattern(
        op_name,
        with_bias=True,
    )
    return op_bias_relu_pat | op_relu_pat | op_bias_pat
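
For reference, make_fused_bias_activation_pattern chains primitive patterns; a rough hand-written equivalent of the conv2d + bias + ReLU case (a sketch):

conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())  # conv2d(data, weight)
conv_bias = is_op("relax.add")(conv, wildcard())         # add the bias
conv_bias_relu = is_op("relax.nn.relu")(conv_bias)       # apply the activation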

compiler = "ccompiler"
patterns = [
    (f"{compiler}.conv2d_bias_relu", make_fused_op_bias_activation_pattern("relax.nn.conv2d")),
    (f"{compiler}.matmul_bias_relu", make_fused_op_bias_activation_pattern("relax.matmul")),
    (f"{compiler}.add_activation", make_fused_bias_activation_pattern("relax.add", with_bias=False, activation="relax.nn.relu")),
    (f"{compiler}.max_pool2d", is_op("relax.nn.max_pool2d")(wildcard())),
    (f"{compiler}.adaptive_avg_pool2d", is_op("relax.nn.adaptive_avg_pool2d")(wildcard())),
    (f"{compiler}.reshape", is_op("relax.reshape")(wildcard(), wildcard())),
]
fuse_pipeline = tvm.transform.Sequential([
    relax.transform.FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True),
    # relax.transform.MergeCompositeFunctions(),
    # relax.transform.FuseOps(),
    # relax.transform.DeadCodeElimination(),
])
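
With annotate_codegen=True, each match is immediately wrapped in an outer function carrying the Codegen attribute. The commented-out MergeCompositeFunctions above is the alternative two-step route, which first groups neighboring composite functions into larger offload regions (a sketch, not applied here):

alt_fuse_pipeline = tvm.transform.Sequential([
    relax.transform.FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=False),
    relax.transform.MergeCompositeFunctions(),
])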
run_mod2 = fuse_pipeline(run_mod)
run_mod2.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_add_relax_nn_relu1_ccompiler(lv6: R.Tensor((1, 128, 28, 28), dtype="float32"), lv7: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv6_1: R.Tensor((1, 128, 28, 28), dtype="float32"), lv7_1: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.add_activation"})
            with R.dataflow():
                lv33: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv6_1, lv7_1)
                gv: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv33)
                R.output(gv)
            return gv

        output: R.Tensor((1, 128, 28, 28), dtype="float32") = local_func(lv6, lv7)
        return output

    @R.function
    def fused_relax_add_relax_nn_relu2_ccompiler(lv11: R.Tensor((1, 256, 14, 14), dtype="float32"), lv12: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv11_1: R.Tensor((1, 256, 14, 14), dtype="float32"), lv12_1: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.add_activation"})
            with R.dataflow():
                lv54: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv11_1, lv12_1)
                gv: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv54)
                R.output(gv)
            return gv

        output: R.Tensor((1, 256, 14, 14), dtype="float32") = local_func(lv11, lv12)
        return output

    @R.function
    def fused_relax_add_relax_nn_relu3_ccompiler(lv16: R.Tensor((1, 512, 7, 7), dtype="float32"), lv17: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv16_1: R.Tensor((1, 512, 7, 7), dtype="float32"), lv17_1: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.add_activation"})
            with R.dataflow():
                lv75: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv16_1, lv17_1)
                gv: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv75)
                R.output(gv)
            return gv

        output: R.Tensor((1, 512, 7, 7), dtype="float32") = local_func(lv16, lv17)
        return output

    @R.function
    def fused_relax_add_relax_nn_relu_ccompiler(lv2: R.Tensor((1, 64, 56, 56), dtype="float32"), lv4: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv2_1: R.Tensor((1, 64, 56, 56), dtype="float32"), lv4_1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.add_activation"})
            with R.dataflow():
                lv12: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv2_1, lv4_1)
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv12)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(lv2, lv4)
        return output

    @R.function
    def fused_relax_matmul_relax_add_ccompiler(lv87: R.Tensor((1, 512), dtype="float32")) -> R.Tensor((1, 1000), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv87_1: R.Tensor((1, 512), dtype="float32")) -> R.Tensor((1, 1000), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.matmul_bias_relu"})
            with R.dataflow():
                lv89: R.Tensor((1, 1000), dtype="float32") = R.matmul(lv87_1, metadata["relax.expr.Constant"][0], out_dtype="float32")
                gv: R.Tensor((1, 1000), dtype="float32") = R.add(lv89, metadata["relax.expr.Constant"][1])
                R.output(gv)
            return gv

        output: R.Tensor((1, 1000), dtype="float32") = local_func(lv87)
        return output

    @R.function
    def fused_relax_nn_adaptive_avg_pool2d_ccompiler(lv7: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 1, 1), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv7_1: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 1, 1), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.adaptive_avg_pool2d"})
            with R.dataflow():
                gv: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(lv7_1, output_size=[1, 1], layout="NCHW", out_layout="NCHW")
                R.output(gv)
            return gv

        output: R.Tensor((1, 512, 1, 1), dtype="float32") = local_func(lv7)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add10_ccompiler(lv80: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv80_1: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv192: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv80_1, metadata["relax.expr.Constant"][2], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv192, metadata["relax.expr.Constant"][3])
                R.output(gv)
            return gv

        output: R.Tensor((1, 512, 7, 7), dtype="float32") = local_func(lv80)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add1_ccompiler(lv17: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv17_1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv45: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv17_1, metadata["relax.expr.Constant"][4], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv45, metadata["relax.expr.Constant"][5])
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(lv17)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add2_ccompiler(lv26: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv26_1: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv65: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv26_1, metadata["relax.expr.Constant"][6], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv65, metadata["relax.expr.Constant"][7])
                R.output(gv)
            return gv

        output: R.Tensor((1, 128, 28, 28), dtype="float32") = local_func(lv26)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add3_ccompiler(lv22: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv22_1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv74: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22_1, metadata["relax.expr.Constant"][8], strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv74, metadata["relax.expr.Constant"][9])
                R.output(gv)
            return gv

        output: R.Tensor((1, 128, 28, 28), dtype="float32") = local_func(lv22)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add4_ccompiler(lv38: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv38_1: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv94: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv38_1, metadata["relax.expr.Constant"][10], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv94, metadata["relax.expr.Constant"][11])
                R.output(gv)
            return gv

        output: R.Tensor((1, 128, 28, 28), dtype="float32") = local_func(lv38)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add5_ccompiler(lv47: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv47_1: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv114: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv47_1, metadata["relax.expr.Constant"][12], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv114, metadata["relax.expr.Constant"][13])
                R.output(gv)
            return gv

        output: R.Tensor((1, 256, 14, 14), dtype="float32") = local_func(lv47)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add6_ccompiler(lv43: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv43_1: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv123: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43_1, metadata["relax.expr.Constant"][14], strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv123, metadata["relax.expr.Constant"][15])
                R.output(gv)
            return gv

        output: R.Tensor((1, 256, 14, 14), dtype="float32") = local_func(lv43)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add7_ccompiler(lv59: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv59_1: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv143: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv59_1, metadata["relax.expr.Constant"][16], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv143, metadata["relax.expr.Constant"][17])
                R.output(gv)
            return gv

        output: R.Tensor((1, 256, 14, 14), dtype="float32") = local_func(lv59)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add8_ccompiler(lv68: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv68_1: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv163: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv68_1, metadata["relax.expr.Constant"][18], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv163, metadata["relax.expr.Constant"][19])
                R.output(gv)
            return gv

        output: R.Tensor((1, 512, 7, 7), dtype="float32") = local_func(lv68)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add9_ccompiler(lv64: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv64_1: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv172: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64_1, metadata["relax.expr.Constant"][20], strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv172, metadata["relax.expr.Constant"][21])
                R.output(gv)
            return gv

        output: R.Tensor((1, 512, 7, 7), dtype="float32") = local_func(lv64)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_ccompiler(lv8: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv8_1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv25: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv8_1, metadata["relax.expr.Constant"][22], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv25, metadata["relax.expr.Constant"][23])
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(lv8)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu1_ccompiler(lv4: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv4_1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv15: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv4_1, metadata["relax.expr.Constant"][24], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv7: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv15, metadata["relax.expr.Constant"][25])
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv7)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(lv4)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu2_ccompiler(lv13: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv13_1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv35: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv13_1, metadata["relax.expr.Constant"][26], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv16: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv35, metadata["relax.expr.Constant"][27])
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv16)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(lv13)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu3_ccompiler(lv22: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv22_1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv55: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22_1, metadata["relax.expr.Constant"][28], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv25: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv55, metadata["relax.expr.Constant"][29])
                gv: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv25)
                R.output(gv)
            return gv

        output: R.Tensor((1, 128, 28, 28), dtype="float32") = local_func(lv22)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu4_ccompiler(lv34: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv34_1: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 128, 28, 28), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv84: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv34_1, metadata["relax.expr.Constant"][30], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv37: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv84, metadata["relax.expr.Constant"][31])
                gv: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv37)
                R.output(gv)
            return gv

        output: R.Tensor((1, 128, 28, 28), dtype="float32") = local_func(lv34)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu5_ccompiler(lv43: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv43_1: R.Tensor((1, 128, 28, 28), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv104: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43_1, metadata["relax.expr.Constant"][32], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv46: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv104, metadata["relax.expr.Constant"][33])
                gv: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv46)
                R.output(gv)
            return gv

        output: R.Tensor((1, 256, 14, 14), dtype="float32") = local_func(lv43)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu6_ccompiler(lv55: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv55_1: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 256, 14, 14), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv133: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv55_1, metadata["relax.expr.Constant"][34], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv58: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv133, metadata["relax.expr.Constant"][35])
                gv: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv58)
                R.output(gv)
            return gv

        output: R.Tensor((1, 256, 14, 14), dtype="float32") = local_func(lv55)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu7_ccompiler(lv64: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv64_1: R.Tensor((1, 256, 14, 14), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv153: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64_1, metadata["relax.expr.Constant"][36], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv67: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv153, metadata["relax.expr.Constant"][37])
                gv: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv67)
                R.output(gv)
            return gv

        output: R.Tensor((1, 512, 7, 7), dtype="float32") = local_func(lv64)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu8_ccompiler(lv76: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv76_1: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512, 7, 7), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv182: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv76_1, metadata["relax.expr.Constant"][38], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv79: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv182, metadata["relax.expr.Constant"][39])
                gv: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv79)
                R.output(gv)
            return gv

        output: R.Tensor((1, 512, 7, 7), dtype="float32") = local_func(lv76)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu_ccompiler(inp_0: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 64, 112, 112), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(inp_0_1: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 64, 112, 112), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.conv2d_bias_relu"})
            with R.dataflow():
                lv5: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.conv2d(inp_0_1, metadata["relax.expr.Constant"][40], strides=[2, 2], padding=[3, 3, 3, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
                lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][41])
                gv: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.relu(lv2)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 112, 112), dtype="float32") = local_func(inp_0)
        return output

    @R.function
    def fused_relax_nn_max_pool2d_ccompiler(lv: R.Tensor((1, 64, 112, 112), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv_1: R.Tensor((1, 64, 112, 112), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.max_pool2d"})
            with R.dataflow():
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.max_pool2d(lv_1, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(lv)
        return output

    @R.function
    def fused_relax_reshape_ccompiler(lv: R.Tensor((1, 512, 1, 1), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
        R.func_attr({"Codegen": "ccompiler"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(lv_1: R.Tensor((1, 512, 1, 1), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
            R.func_attr({"Composite": "ccompiler.reshape"})
            with R.dataflow():
                gv: R.Tensor((1, 512), dtype="float32") = R.reshape(lv_1, R.shape([1, 512]))
                R.output(gv)
            return gv

        output: R.Tensor((1, 512), dtype="float32") = local_func(lv)
        return output

    @R.function
    def main(inp_0: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 1000), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 64, 112, 112), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu_ccompiler(inp_0)
            lv_1: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_max_pool2d_ccompiler(lv)
            lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu1_ccompiler(lv_1)
            lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_ccompiler(lv1)
            lv_2: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_add_relax_nn_relu_ccompiler(lv2, lv_1)
            lv3: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu2_ccompiler(lv_2)
            lv4: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add1_ccompiler(lv3)
            lv1_1: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_add_relax_nn_relu_ccompiler(lv4, lv_2)
            lv5: R.Tensor((1, 128, 28, 28), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu3_ccompiler(lv1_1)
            lv6: R.Tensor((1, 128, 28, 28), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add2_ccompiler(lv5)
            lv7: R.Tensor((1, 128, 28, 28), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add3_ccompiler(lv1_1)
            lv2_1: R.Tensor((1, 128, 28, 28), dtype="float32") = cls.fused_relax_add_relax_nn_relu1_ccompiler(lv6, lv7)
            lv8: R.Tensor((1, 128, 28, 28), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu4_ccompiler(lv2_1)
            lv9: R.Tensor((1, 128, 28, 28), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add4_ccompiler(lv8)
            lv3_1: R.Tensor((1, 128, 28, 28), dtype="float32") = cls.fused_relax_add_relax_nn_relu1_ccompiler(lv9, lv2_1)
            lv10: R.Tensor((1, 256, 14, 14), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu5_ccompiler(lv3_1)
            lv11: R.Tensor((1, 256, 14, 14), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add5_ccompiler(lv10)
            lv12: R.Tensor((1, 256, 14, 14), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add6_ccompiler(lv3_1)
            lv4_1: R.Tensor((1, 256, 14, 14), dtype="float32") = cls.fused_relax_add_relax_nn_relu2_ccompiler(lv11, lv12)
            lv13: R.Tensor((1, 256, 14, 14), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu6_ccompiler(lv4_1)
            lv14: R.Tensor((1, 256, 14, 14), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add7_ccompiler(lv13)
            lv5_1: R.Tensor((1, 256, 14, 14), dtype="float32") = cls.fused_relax_add_relax_nn_relu2_ccompiler(lv14, lv4_1)
            lv15: R.Tensor((1, 512, 7, 7), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu7_ccompiler(lv5_1)
            lv16: R.Tensor((1, 512, 7, 7), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add8_ccompiler(lv15)
            lv17: R.Tensor((1, 512, 7, 7), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add9_ccompiler(lv5_1)
            lv6_1: R.Tensor((1, 512, 7, 7), dtype="float32") = cls.fused_relax_add_relax_nn_relu3_ccompiler(lv16, lv17)
            lv18: R.Tensor((1, 512, 7, 7), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu8_ccompiler(lv6_1)
            lv19: R.Tensor((1, 512, 7, 7), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add10_ccompiler(lv18)
            lv7_1: R.Tensor((1, 512, 7, 7), dtype="float32") = cls.fused_relax_add_relax_nn_relu3_ccompiler(lv19, lv6_1)
            lv_3: R.Tensor((1, 512, 1, 1), dtype="float32") = cls.fused_relax_nn_adaptive_avg_pool2d_ccompiler(lv7_1)
            lv_4: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_reshape_ccompiler(lv_3)
            gv: R.Tensor((1, 1000), dtype="float32") = cls.fused_relax_matmul_relax_add_ccompiler(lv_4)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
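
Every offloaded region above is an outer function tagged Codegen="ccompiler" wrapping an inner function whose Composite attribute names the matched pattern. A small sketch to enumerate them:

for gvar, func in run_mod2.functions.items():
    if isinstance(func, relax.Function) and func.attrs and "Codegen" in func.attrs:
        print(gvar.name_hint, "->", func.attrs["Codegen"])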
with tvm.transform.PassContext(opt_level=3):
    ex = tvm.compile(run_mod2, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
import numpy as np
dev = tvm.cpu()
inputs_np = [np.random.rand(1, 3, 224, 224).astype("float32")]
inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
# Run model and check outputs.
vm.set_input("main", *inputs)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
with torch.no_grad():
    torch_output = torch_model(torch.from_numpy(inputs_np[0]))
np.testing.assert_allclose(tvm_output.numpy(), torch_output.numpy(), rtol=1e-7, atol=1e-5)
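
The VM also exposes compiled functions directly, and its time evaluator gives a rough latency figure (a sketch; results vary by machine):

# Direct call, equivalent to set_input + invoke_stateful for this pure function
direct_out = vm["main"](inputs[0])
np.testing.assert_allclose(direct_out.numpy(), tvm_output.numpy())

# Rough CPU timing of the end-to-end module
timer = vm.time_evaluator("main", dev, number=10)
print(timer(inputs[0]))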