FoldBatchnormToConv2D

Reference: tvm/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py

import numpy as np

import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
def get_conv2d_batchnorm_sample():
    """Build a Relax module whose "main" runs conv2d followed by batch_norm."""
    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            data = R.arg("data", R.Tensor((1, 3, 224, 224), "float32"))
            weight = R.arg("weight", R.Tensor((32, 3, 3, 3), "float32"))
            with R.dataflow() as frame:
                # 3x3 convolution with stride 1 and padding 1, so the
                # spatial size stays 224x224 while channels go 3 -> 32.
                output = R.emit(
                    R.nn.conv2d(
                        data,
                        weight,
                        out_dtype="float32",
                        strides=(1, 1),
                        dilation=(1, 1),
                        padding=(1, 1),
                        data_layout="NCHW",
                        kernel_layout="OIHW",
                        groups=1,
                    )
                )
                # BatchNorm parameters, declared as extra function arguments:
                # one value per output channel.
                gamma = R.arg("gamma", R.Tensor((32,), "float32"))
                beta = R.arg("beta", R.Tensor((32,), "float32"))
                mean = R.arg("mean", R.Tensor((32,), "float32"))
                variance = R.arg("variance", R.Tensor((32,), "float32"))
                # batch_norm returns a tuple; element 0 is the normalized tensor.
                output = R.emit(
                    R.nn.batch_norm(output, gamma, beta, mean, variance, axis=1, epsilon=1e-5)[0]
                )
                R.output(output)

            R.func_ret_value(frame.output_vars[0])

    func = builder.get()

    return tvm.IRModule({"main": func})
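
As a quick sanity check, building and printing the module shows conv2d feeding a tuple-producing batch_norm (this uses only the helper defined above):

get_conv2d_batchnorm_sample().show()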

Verifying the correctness of the BatchNorm+Conv2D fusion

target = tvm.target.Target("llvm", host="llvm")

Create the test models:

mod = get_conv2d_batchnorm_sample()  # baseline model
mod_fold = get_conv2d_batchnorm_sample()  # copy that will be optimized

Prepare the test data:

# Random input data: one 3-channel 224x224 image
data_in = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype(np.float32))

# Random model parameters: the Conv2D weight and the BatchNorm parameters
weight_data = tvm.nd.array(np.random.rand(32, 3, 3, 3).astype(np.float32))
gamma_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
beta_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
mean_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
variance_data = tvm.nd.array(np.random.rand(32).astype(np.float32))

Bind the parameters:

# Bind the same randomly generated parameters into both modules
params_np = {
    "weight": weight_data,
    "gamma": gamma_data,
    "beta": beta_data,
    "mean": mean_data,
    "variance": variance_data,
}
mod = tvm.relax.transform.BindParams("main", params_np)(mod)
mod_fold = tvm.relax.transform.BindParams("main", params_np)(mod_fold)

Run the baseline model. After BindParams, "data" is the only remaining runtime argument:

# Baseline pipeline
mod = tvm.relax.transform.DecomposeOpsForInference()(mod)  # decompose batch_norm into elementwise ops
ex = tvm.compile(mod, target)  # compile
vm = relax.VirtualMachine(ex, tvm.cpu())  # create the Relax virtual machine
out = vm["main"](data_in)  # run inference
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 32, 224, 224), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 32, 224, 224), dtype="float32") = R.nn.conv2d(data, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv_1: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][1], axis=[0, 2, 3])
            lv1: R.Tensor((1, 32, 224, 224), dtype="float32") = R.subtract(lv, lv_1)
            lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][2], axis=[0, 2, 3])
            lv3: R.Tensor((1, 32, 1, 1), dtype="float32") = R.add(lv2, R.const(9.9999997473787516e-06, "float32"))
            lv4: R.Tensor((1, 32, 1, 1), dtype="float32") = R.sqrt(lv3)
            lv5: R.Tensor((1, 32, 224, 224), dtype="float32") = R.divide(lv1, lv4)
            lv6: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][3], axis=[0, 2, 3])
            lv7: R.Tensor((1, 32, 224, 224), dtype="float32") = R.multiply(lv5, lv6)
            lv8: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][4], axis=[0, 2, 3])
            lv9: R.Tensor((1, 32, 224, 224), dtype="float32") = R.add(lv7, lv8)
            lv1_1: R.Tuple(R.Tensor((1, 32, 224, 224), dtype="float32"), R.Tensor((32,), dtype="float32"), R.Tensor((32,), dtype="float32")) = lv9, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]
            lv2_1: R.Tensor((1, 32, 224, 224), dtype="float32") = lv1_1[0]
            R.output(lv2_1)
        return lv2_1

# Metadata omitted. Use show_meta=True in script() method to show it.
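
The decomposed graph above is inference-mode batch norm written out elementwise: the conv2d output is shifted by the running mean, scaled by 1/sqrt(variance + epsilon), multiplied by gamma, and offset by beta, all per channel. A minimal NumPy sketch of the same computation for an NCHW tensor (a hypothetical reference helper, not a TVM API):

def batch_norm_inference(x, gamma, beta, mean, var, eps=1e-5):
    # Broadcast the per-channel (C,) parameters over the NCHW layout.
    shape = (1, -1, 1, 1)
    inv_std = 1.0 / np.sqrt(var.reshape(shape) + eps)
    return (x - mean.reshape(shape)) * inv_std * gamma.reshape(shape) + beta.reshape(shape)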

Fold BatchNorm into Conv2D:

# Apply the BatchNorm folding pass; FoldConstant then pre-computes the rescaled weights
mod_fold = relax.transform.FoldBatchnormToConv2D()(mod_fold)
mod_fold = relax.transform.FoldConstant()(mod_fold)
ex_fold = tvm.compile(mod_fold, target)
vm_fold = relax.VirtualMachine(ex_fold, tvm.cpu())
out_fold = vm_fold["main"](data_in)
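
Mathematically, the pass rewrites batch_norm(conv2d(x, W)) into conv2d(x, W') + b': each output channel of the weight is rescaled by gamma / sqrt(variance + epsilon), and a per-channel bias absorbs the mean and beta terms. A minimal NumPy sketch of that rewrite, reusing the parameters generated above (it reproduces the math only, not the pass implementation):

eps = 1e-5
scale = gamma_data.numpy() / np.sqrt(variance_data.numpy() + eps)  # per-channel scale, shape (32,)
w_fold = weight_data.numpy() * scale[:, None, None, None]  # rescale each OIHW output channel
b_fold = beta_data.numpy() - mean_data.numpy() * scale     # bias added after the convolution

After FoldConstant runs, w_fold appears as the new conv2d weight constant and b_fold (broadcast over the channel axis) as the constant consumed by R.add in the folded module shown below.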

Verify the results:

# The outputs before and after the optimization must match (tolerance 1e-5)
tvm.testing.assert_allclose(out.numpy(), out_fold.numpy(), rtol=1e-5, atol=1e-5)
mod_fold.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 32, 224, 224), dtype="float32"):
        with R.dataflow():
            lv5: R.Tensor((1, 32, 224, 224), dtype="float32") = R.nn.conv2d(data, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv2: R.Tensor((1, 32, 224, 224), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][1])
            R.output(lv2)
        return lv2

# Metadata omitted. Use show_meta=True in script() method to show it.