BranchTupleOutput

This example shows how relax.transform.FuseOpsByPattern partitions a graph in
which an intermediate value escapes the matched region: conv1 feeds both the
relu inside the conv2d + relu pattern and a gelu outside it, so the fused
function must return a tuple.

import tvm
from tvm import relax
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
from tvm.script import ir as I
from tvm.script import relax as R

# A pattern matching bare conv2d, and one matching conv2d followed by relu;
# only the conv2d + relu pattern is registered with FuseOpsByPattern below.
conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None)
conv2d_relu_pat = make_fused_bias_activation_pattern(
    "relax.nn.conv2d", activation="relax.nn.relu"
)
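
For reference, conv2d_relu_pat is roughly what composing the DPL primitives by
hand would produce (a minimal sketch; the helper also supports an optional bias
add via with_bias=True):

from tvm.relax.dpl.pattern import is_op, wildcard

# conv2d taking any two inputs, followed by relu
conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
conv2d_relu_by_hand = is_op("relax.nn.relu")(conv)
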
@tvm.script.ir_module
class BranchTupleOutput:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            conv1 = R.nn.conv2d(data, weight)
            relu1 = R.nn.relu(conv1)
            gelu1 = R.nn.gelu(relu1)
            # conv1 is consumed both by relu1 (inside the pattern) and by
            # gelu2 (outside it), which forces the tuple output seen below.
            gelu2 = R.nn.gelu(conv1)
            out = R.add(gelu1, gelu2)
            R.output(out)
        return out

BranchTupleOutput.show()
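
show() prints the module with struct info made explicit; the 3x3 convolution
with no padding shrinks the 56x56 spatial dimensions to 54x54: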
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        with R.dataflow():
            conv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(data, weight, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            relu1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(conv1)
            gelu1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(relu1)
            gelu2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(conv1)
            out: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(gelu1, gelu2)
            R.output(out)
        return out
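
Next, register the conv2d + relu pattern under the name "dnnl.conv2d_relu" and
run FuseOpsByPattern on the module: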
patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]
# Keep any constants captured by a match bound inside the fused function.
bind_constants = True
# Wrap each composite function in an outer function carrying a "Codegen"
# attribute, ready for BYOC lowering.
annotate_codegen = True
partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(
    BranchTupleOutput
)
partitioned.show()
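
In the partitioned module, the fused conv2d + relu function returns a tuple
holding both the relu result and the raw conv2d result, so that main can still
route the convolution output to gelu2: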
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu_dnnl(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tuple(R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32")):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(data_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tuple(R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32")):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            with R.dataflow():
                gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(data_1, weight_1, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(gv)
                R.output(gv, gv1)
            return (gv1, gv)

        output: R.Tuple(R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32")) = local_func(data, weight)
        return output

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tuple(R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32")) = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight)
            lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[1]
            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[0]
            gelu1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv2)
            gelu2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv1)
            out: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(gelu1, gelu2)
            R.output(out)
        return out
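
From here, a typical next step is BYOC lowering: relax.transform.RunCodegen
replaces each function annotated with "Codegen" by a call to the corresponding
external runtime module. A minimal sketch, assuming a TVM build with DNNL
support enabled (USE_DNNL in the build configuration):

lowered = relax.transform.RunCodegen()(partitioned)
ex = relax.build(lowered, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())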