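# BranchTupleOutput

Partitioning a Relax module with `FuseOpsByPattern` when an intermediate of the matched `conv2d → relu` region (`conv1`) is also consumed outside that region (by `gelu2`).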

In [1]:
import numpy as np

import tvm
from tvm import relax
from tvm.relax.backend.cuda.cublas import partition_for_cublas
from tvm.relax.backend.cuda.cutlass import partition_for_cutlass
from tvm.relax.dpl.pattern import (
    is_op,
    is_tuple_get_item,
    make_fused_bias_activation_pattern,
    wildcard,
)
from tvm.relax.transform import PatternCheckContext
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

In [2]:
# Two conv2d patterns from the same helper, differing only in the trailing
# activation: none (bare conv2d) and relax.nn.relu (conv2d followed by relu).
# `with_bias` is not passed, so bias matching follows the helper's default —
# TODO confirm. Only conv2d_relu_pat is used in the cells shown here.
conv2d_pat, conv2d_relu_pat = (
    make_fused_bias_activation_pattern("relax.nn.conv2d", activation=act)
    for act in (None, "relax.nn.relu")
)

In [3]:
# Input module for the partitioning experiment. conv1 feeds two consumers:
# relu1 (so conv1 -> relu1 is a candidate for the conv2d+relu pattern) and
# gelu2 (outside that chain). Because conv1 is still needed after the matched
# region, the fused function presumably has to expose it as an extra output —
# likely the "tuple output" the module name refers to; verify in the
# partitioned.show() output.
@tvm.script.ir_module
class BranchTupleOutput:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            # conv2d -> relu: the chain the conv2d+relu pattern targets.
            conv1 = R.nn.conv2d(data, weight)
            relu1 = R.nn.relu(conv1)
            gelu1 = R.nn.gelu(relu1)
            # Second consumer of conv1, outside the conv2d -> relu chain.
            gelu2 = R.nn.gelu(conv1)
            # R.add keeps the spelling consistent with the other R.* ops here.
            out = R.add(gelu1, gelu2)
            R.output(out)

        return out

In [4]:
# Print the module's TVMScript before partitioning, for comparison with the
# partitioned.show() output below.
BranchTupleOutput.show()

In [5]:
# Each entry pairs a name with a pattern. The name "dnnl.conv2d_relu" is
# presumably what matched regions get labelled with, and its "dnnl" prefix is
# presumably what annotate_codegen uses to pick a backend — verify both in the
# partitioned.show() output.
patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]
bind_constants = True
annotate_codegen = True
# Pass the flags by keyword so they cannot be silently bound to the wrong
# positional parameter of FuseOpsByPattern.
partitioned = relax.transform.FuseOpsByPattern(
    patterns,
    bind_constants=bind_constants,
    annotate_codegen=annotate_codegen,
)(BranchTupleOutput)

In [6]:
# Print the module produced by FuseOpsByPattern, to compare with the
# BranchTupleOutput.show() output above.
partitioned.show()