Fusing Conv2dReLU#

This section uses FuseOpsByPattern to partition conv2d + relu pairs into composite functions intended for an external codegen (named "dnnl" here).

import numpy as np

import tvm
from tvm import relax
from tvm.relax.backend.cuda.cublas import partition_for_cublas
from tvm.relax.backend.cuda.cutlass import partition_for_cutlass
from tvm.relax.dpl.pattern import (
    is_op,
    is_tuple_get_item,
    make_fused_bias_activation_pattern,
    wildcard,
)
from tvm.relax.transform import PatternCheckContext
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

Build the model:

@tvm.script.ir_module
class Conv2dReLU:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
            R.output(conv1)

        return conv1
Conv2dReLU.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data, weight1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
            R.output(conv1)
        return conv1
conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation="relax.nn.relu")
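make_fused_bias_activation_pattern builds a dataflow pattern for conv2d followed by an optional bias add and the given activation. For this simple conv2d → relu case, a roughly equivalent pattern can also be written by hand with the is_op and wildcard helpers imported above (an illustrative sketch; the name manual_conv2d_relu_pat is not part of the original notebook, and this version omits the optional bias_add branch):

# Hand-built conv2d -> relu pattern, without the optional bias_add branch.
conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
manual_conv2d_relu_pat = is_op("relax.nn.relu")(conv)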

Binding Parameters#

Bind weight1 to a concrete weight array with BindParams, then fuse. With bind_constants=True, the bound constant is kept inside the fused function:

weight_np = np.random.randn(64, 64, 3, 3).astype("float32")
mod = tvm.transform.Sequential(
    [
        relax.transform.BindParams("main", {"weight1": weight_np}),
        relax.transform.FuseOpsByPattern(
            [("dnnl.conv2d_relu", conv2d_relu_pat)], bind_constants=True
        ),
    ]
)(Conv2dReLU)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_nn_relu(data: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
            R.output(gv)
        return gv

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu(data)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
assert "fused_relax_nn_conv2d_relax_nn_relu" in [var.name_hint for var in mod.functions.keys()]

for gvar, f in mod.functions.items():
    if gvar.name_hint == "fused_relax_nn_conv2d_relax_nn_relu":
        conv2d = f.body.blocks[0].bindings[0].value
        assert isinstance(conv2d.args[1], relax.Constant)
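For contrast, a minimal sketch of the same pipeline with bind_constants=False (an assumption for illustration): the bound weight is then expected to be lifted out as a parameter of the fused function instead of being embedded as a constant.

# Sketch: same module and pattern as above, but without keeping constants
# inside the fused function (bind_constants=False).
mod_no_bind = tvm.transform.Sequential(
    [
        relax.transform.BindParams("main", {"weight1": weight_np}),
        relax.transform.FuseOpsByPattern(
            [("dnnl.conv2d_relu", conv2d_relu_pat)], bind_constants=False
        ),
    ]
)(Conv2dReLU)
mod_no_bind.show()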

FuseOpsByPattern#

FuseOpsByPattern takes a list of (composite_name, pattern) pairs. With annotate_codegen=True, each matched region is wrapped in an outer function carrying the "Codegen" attribute, which calls an inner local function carrying the "Composite" attribute:

patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]
bind_constants = True
annotate_codegen = True
partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLU)
partitioned.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu_dnnl(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(data_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            with R.dataflow():
                lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data_1, weight1_1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(data, weight1)
        return output

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
            R.output(gv)
        return gv
An empty pattern list leaves the module untouched:

patterns = []
bind_constants = True
annotate_codegen = True
partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLU)
partitioned.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data, weight1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
            R.output(conv1)
        return conv1
The original Conv2dReLU module is likewise not mutated by the pass:

Conv2dReLU.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data, weight1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
            R.output(conv1)
        return conv1
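A quick check in the same style as the earlier assert (added here for illustration) confirms the original module still contains only main:

assert [gv.name_hint for gv in Conv2dReLU.functions.keys()] == ["main"]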

Fusing Conv2dReLUx2#

@tvm.script.ir_module
class Conv2dReLUx2:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
            conv2 = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
            R.output(conv2)

        return conv2
Conv2dReLUx2.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data, weight1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(conv1, weight2, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            conv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2)
            R.output(conv2)
        return conv2
patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]
bind_constants = True
annotate_codegen = True
partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLUx2)
partitioned.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu1_dnnl(conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(conv1_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight2_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            with R.dataflow():
                lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(conv1_1, weight2_1, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 54, 54), dtype="float32") = local_func(conv1, weight2)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu_dnnl(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(data_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            with R.dataflow():
                lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data_1, weight1_1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(data, weight1)
        return output

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu1_dnnl(lv, weight2)
            R.output(gv)
        return gv
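Each match gets its own fused function; mirroring the earlier assert, this can be verified by name (a check added for illustration):

fn_names = [gv.name_hint for gv in partitioned.functions.keys()]
assert "fused_relax_nn_conv2d_relax_nn_relu_dnnl" in fn_names
assert "fused_relax_nn_conv2d_relax_nn_relu1_dnnl" in fn_names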

Pattern matching depends on the order in which the patterns are listed:

conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None)
patterns = [("dnnl.conv2d", conv2d_pat), ("dnnl.conv2d_relu", conv2d_relu_pat)]
bind_constants = True
annotate_codegen = True
partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLUx2)
  • The pass first tries to match the first pattern in the list, conv2d_pat

  • Even though the richer conv2d_relu pattern is also supplied, the simpler conv2d fusion is applied first because of this ordering

  • In the resulting module only the conv2d operators are fused; the relu operators remain standalone in main

partitioned.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_nn_conv2d1_dnnl(conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(conv1_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight2_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
            R.func_attr({"Composite": "dnnl.conv2d"})
            with R.dataflow():
                gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(conv1_1, weight2_1, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 54, 54), dtype="float32") = local_func(conv1, weight2)
        return output

    @R.function
    def fused_relax_nn_conv2d_dnnl(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(data_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "dnnl.conv2d"})
            with R.dataflow():
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data_1, weight1_1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(data, weight1)
        return output

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_dnnl(data, weight1)
            conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
            lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_conv2d1_dnnl(conv1, weight2)
            conv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
            R.output(conv2)
        return conv2
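As expected from the pattern order, only bare conv2d composites were produced and the relu calls stay in main (a check added for illustration):

fn_names = [gv.name_hint for gv in partitioned.functions.keys()]
assert "fused_relax_nn_conv2d_dnnl" in fn_names
assert "fused_relax_nn_conv2d1_dnnl" in fn_names
assert "fused_relax_nn_conv2d_relax_nn_relu_dnnl" not in fn_names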

Fusing Conv2dConv2dReLU#

Define a model in which the first conv2d has no activation:

@tvm.script.ir_module
class Conv2dConv2dReLU:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
            conv2d = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
            R.output(conv2d)

        return conv2d
Conv2dConv2dReLU.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        with R.dataflow():
            conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data, weight1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(conv1, weight2, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            conv2d: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
            R.output(conv2d)
        return conv2d
conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None)

With conv2d_relu listed before the bare conv2d pattern, the more specific pattern is tried first, so both conv2d + relu pairs in Conv2dReLUx2 are fused again:

patterns = [("dnnl.conv2d_relu", conv2d_relu_pat), ("dnnl.conv2d", conv2d_pat)]
bind_constants = True
annotate_codegen = True
partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dReLUx2)
partitioned.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu1_dnnl(conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(conv1_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight2_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            with R.dataflow():
                lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(conv1_1, weight2_1, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 54, 54), dtype="float32") = local_func(conv1, weight2)
        return output

    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu_dnnl(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Codegen": "dnnl"})
        # from tvm.script import relax as R
        
        @R.function
        def local_func(data_1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1_1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            with R.dataflow():
                lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data_1, weight1_1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
                R.output(gv)
            return gv

        output: R.Tensor((1, 64, 56, 56), dtype="float32") = local_func(data, weight1)
        return output

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu1_dnnl(lv, weight2)
            R.output(gv)
        return gv
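The Conv2dConv2dReLU module defined at the start of this section can be partitioned with the same reversed pattern list. A sketch (output not shown; with conv2d_relu listed first, one would expect the standalone first conv2d to match dnnl.conv2d and the conv2d + relu pair to match dnnl.conv2d_relu):

partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(Conv2dConv2dReLU)
partitioned.show()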