# 融合 Conv2dReLU

In [1]:
import numpy as np

import tvm
from tvm import relax
from tvm.relax.backend.cuda.cublas import partition_for_cublas
from tvm.relax.backend.cuda.cutlass import partition_for_cutlass
from tvm.relax.dpl.pattern import (
    is_op,
    is_tuple_get_item,
    make_fused_bias_activation_pattern,
    wildcard,
)
from tvm.relax.transform import PatternCheckContext
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

构建模型：

In [2]:
# Relax IRModule holding a single conv2d -> relu chain. It is the input module
# for the BindParams / FuseOpsByPattern experiments in the following cells.
@tvm.script.ir_module
class Conv2dReLU:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        # conv2d (padding 1x1) and relu are written as one nested expression
        # inside a dataflow block; `conv1` is the block's only output.
        with R.dataflow():
            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
            R.output(conv1)

        return conv1

In [3]:
# Display the module as written, before any pass has been applied to it.
Conv2dReLU.show()

In [4]:
# Dataflow pattern built for the "relax.nn.conv2d" op with a
# "relax.nn.relu" activation; reused by every fusion cell below.
conv2d_relu_pat = make_fused_bias_activation_pattern(
    "relax.nn.conv2d",
    activation="relax.nn.relu",
)

## 绑定参数

In [5]:
# Random float32 weights with the same shape as the `weight1` parameter of main.
weight_np = np.random.randn(64, 64, 3, 3).astype("float32")

# Step 1: bind `weight1` of "main" to the concrete NumPy array.
bind_weight = relax.transform.BindParams("main", {"weight1": weight_np})
# Step 2: fuse conv2d+relu, with bind_constants=True.
fuse_conv2d_relu = relax.transform.FuseOpsByPattern(
    [("dnnl.conv2d_relu", conv2d_relu_pat)],
    bind_constants=True,
)

# Run both passes in order on the source module.
pipeline = tvm.transform.Sequential([bind_weight, fuse_conv2d_relu])
mod = pipeline(Conv2dReLU)

In [6]:
# Display the module produced by the BindParams + FuseOpsByPattern sequence.
mod.show()

In [7]:
# The module must now contain a function with the fused name.
fused_name = "fused_relax_nn_conv2d_relax_nn_relu"
function_names = [gvar.name_hint for gvar in mod.functions.keys()]
assert fused_name in function_names

# Inside that function, the second argument of the first bound call must be
# a relax.Constant (i.e. the bound weight rather than a function parameter).
for gvar, func in mod.functions.items():
    if gvar.name_hint != fused_name:
        continue
    conv2d_call = func.body.blocks[0].bindings[0].value
    weight_arg = conv2d_call.args[1]
    assert isinstance(weight_arg, relax.Constant)

## `FuseOpsByPattern`

In [8]:
# Partition Conv2dReLU with the conv2d+relu pattern, this time also passing
# annotate_codegen=True to the pass.
bind_constants = True
annotate_codegen = True
patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]

fuse_pass = relax.transform.FuseOpsByPattern(
    patterns,
    bind_constants=bind_constants,
    annotate_codegen=annotate_codegen,
)
partitioned = fuse_pass(Conv2dReLU)

In [9]:
# Display the module partitioned with annotate_codegen=True.
partitioned.show()

In [10]:
# Same pass configuration as before, but with an empty pattern list.
bind_constants = True
annotate_codegen = True
patterns = []

fuse_pass = relax.transform.FuseOpsByPattern(
    patterns,
    bind_constants=bind_constants,
    annotate_codegen=annotate_codegen,
)
partitioned = fuse_pass(Conv2dReLU)

In [11]:
# Display the result of running the pass with an empty pattern list.
partitioned.show()

In [12]:
# Display the source module again (presumably to compare it with the output above).
Conv2dReLU.show()

## 融合 `Conv2dReLUx2`

In [13]:
# Relax IRModule with two consecutive conv2d -> relu stages; the output of the
# first stage is the data input of the second.
@tvm.script.ir_module
class Conv2dReLUx2:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            # Stage 1: conv2d with padding (1, 1), then relu.
            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
            # Stage 2: conv2d with padding (0, 0) on the stage-1 result, then relu.
            conv2 = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
            # Only the final result leaves the dataflow block.
            R.output(conv2)

        return conv2

In [14]:
# Display the two-stage module before any pass has been applied to it.
Conv2dReLUx2.show()

In [15]:
# Partition the two-stage module using only the conv2d+relu pattern.
bind_constants = True
annotate_codegen = True
patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]

fuse_pass = relax.transform.FuseOpsByPattern(
    patterns,
    bind_constants=bind_constants,
    annotate_codegen=annotate_codegen,
)
partitioned = fuse_pass(Conv2dReLUx2)

In [16]:
# Display Conv2dReLUx2 after partitioning with the conv2d+relu pattern.
partitioned.show()

模式匹配的结果依赖于模式列表中的顺序（排在前面的模式优先级更高）：

In [17]:
# Pattern for a conv2d with no activation attached.
conv2d_pat = make_fused_bias_activation_pattern(
    "relax.nn.conv2d",
    activation=None,
)

# The conv2d-only pattern is deliberately listed before the conv2d+relu one.
bind_constants = True
annotate_codegen = True
patterns = [
    ("dnnl.conv2d", conv2d_pat),
    ("dnnl.conv2d_relu", conv2d_relu_pat),
]

fuse_pass = relax.transform.FuseOpsByPattern(
    patterns,
    bind_constants=bind_constants,
    annotate_codegen=annotate_codegen,
)
partitioned = fuse_pass(Conv2dReLUx2)

- 变换器按列表顺序匹配模式，因此会优先尝试匹配排在首位的 `conv2d_pat`
- 即使列表中还有覆盖范围更大的 `conv2d_relu_pat`，`conv2d` 也会先被较简单的 `conv2d_pat` 融合
- 最终结果中只有 `conv2d` 算子被融合，ReLU 算子保持独立；若希望融合 conv2d+ReLU，应把匹配子图更大的模式放在列表前面

In [18]:
# Display the module partitioned with conv2d_pat listed ahead of conv2d_relu_pat.
partitioned.show()

## 融合 `Conv2dConv2dReLU`

In [19]:
# Relax IRModule with a bare conv2d followed by a conv2d -> relu stage, i.e.
# only the second convolution has an activation attached.
@tvm.script.ir_module
class Conv2dConv2dReLU:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            # First conv2d (padding (1, 1)) has no relu after it.
            conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
            # Second conv2d (padding (0, 0)) consumes conv1 and is followed by relu.
            conv2d = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
            # Only the final result leaves the dataflow block.
            R.output(conv2d)

        return conv2d

In [20]:
# Display the conv2d -> conv2d -> relu module before any pass has been applied.
Conv2dConv2dReLU.show()

In [21]:
# Partition Conv2dConv2dReLU, the module defined in this section. The original
# cell passed Conv2dReLUx2 here by mistake, so this module was never
# partitioned.
conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None)
# The conv2d+relu pattern is listed first so that it has priority over the
# plain conv2d pattern; the latter is intended to pick up the bare conv2d.
patterns = [("dnnl.conv2d_relu", conv2d_relu_pat), ("dnnl.conv2d", conv2d_pat)]
bind_constants = True
annotate_codegen = True
partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants, annotate_codegen)(
    Conv2dConv2dReLU
)

In [22]:
# Display the module partitioned with conv2d_relu_pat listed ahead of conv2d_pat.
partitioned.show()