# FoldBatchnormToConv2D

The `FoldBatchnormToConv2D` pass folds an inference-time `batch_norm` into the weights and bias of the `conv2d` that feeds it, so the normalization disappears from the compiled graph.

Reference: `tvm/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py`
```python
import numpy as np

import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script.ir_builder import IRBuilder
from tvm.ir.module import IRModule
from tvm.script.ir_builder import relax as relax_builder
from tvm.relax.expr_functor import PyExprVisitor, visitor
```

```python
def get_conv2d_batchnorm_sample():
    """Build a Relax module whose "main" is a conv2d followed by batch_norm."""
    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            data = R.arg("data", R.Tensor((1, 3, 224, 224), "float32"))
            weight = R.arg("weight", R.Tensor((32, 3, 3, 3), "float32"))
            with R.dataflow() as frame:
                output = R.emit(
                    R.nn.conv2d(
                        data,
                        weight,
                        out_dtype="float32",
                        strides=(1, 1),
                        dilation=(1, 1),
                        padding=(1, 1),
                        data_layout="NCHW",
                        kernel_layout="OIHW",
                        groups=1,
                    )
                )
                gamma = R.arg("gamma", R.Tensor((32,), "float32"))
                beta = R.arg("beta", R.Tensor((32,), "float32"))
                mean = R.arg("mean", R.Tensor((32,), "float32"))
                variance = R.arg("variance", R.Tensor((32,), "float32"))
                output = R.emit(
                    R.nn.batch_norm(output, gamma, beta, mean, variance, axis=1, epsilon=1e-5)[0]
                )
                R.output(output)
            R.func_ret_value(frame.output_vars[0])
    func = builder.get()
    return tvm.IRModule({"main": func})
```
## Verifying the correctness of the BatchNorm+Conv2D folding
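Folding is possible because, at inference time, `batch_norm` is a per-channel affine map. With per-channel mean $\mu$, variance $\sigma^2$, scale $\gamma$, shift $\beta$, and writing $s = \gamma / \sqrt{\sigma^2 + \epsilon}$:

$$
\mathrm{BN}\big(\mathrm{conv}(x, W)\big) = s \odot \mathrm{conv}(x, W) + \big(\beta - s \odot \mu\big) = \mathrm{conv}\big(x,\, s \odot W\big) + b',
$$

where $s$ broadcasts along the output-channel axis of $W$, $b' = \beta - s \odot \mu$, and the last equality uses the fact that convolution is linear in its weights. The test below checks exactly this identity numerically.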
```python
target = tvm.target.Target("llvm", host="llvm")
```
Create two copies of the test model:
```python
mod = get_conv2d_batchnorm_sample()       # baseline model
mod_fold = get_conv2d_batchnorm_sample()  # copy that will be optimized
```
Prepare the test data:
```python
# Random input data (one 3-channel 224x224 image)
data_in = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype(np.float32))
# Random model parameters (conv2d weight and batch_norm parameters)
weight_data = tvm.nd.array(np.random.rand(32, 3, 3, 3).astype(np.float32))
gamma_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
beta_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
mean_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
variance_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
```
Bind the parameters:
```python
# Bind the randomly generated parameters to both modules
params_np = {
    "weight": weight_data,
    "gamma": gamma_data,
    "beta": beta_data,
    "mean": mean_data,
    "variance": variance_data,
}
mod = tvm.relax.transform.BindParams("main", params_np)(mod)
mod_fold = tvm.relax.transform.BindParams("main", params_np)(mod_fold)
```
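As a quick sanity check (not part of the original test), binding should leave `data` as the only remaining function parameter, with the weights and statistics embedded in the IR as constants:

```python
# After BindParams, "main" should keep only the activation input.
print([p.name_hint for p in mod["main"].params])  # expected: ['data']
```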
Run the baseline model:
```python
# Baseline flow: decompose batch_norm, compile, and run
mod = tvm.relax.transform.DecomposeOpsForInference()(mod)  # expand batch_norm into elementwise ops
ex = tvm.compile(mod, target)                              # compile for the target
vm = relax.VirtualMachine(ex, tvm.cpu())                   # create the virtual machine
out = vm["main"](data_in)                                  # run inference
mod.show()
```
```python
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 32, 224, 224), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 32, 224, 224), dtype="float32") = R.nn.conv2d(data, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv_1: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][1], axis=[0, 2, 3])
            lv1: R.Tensor((1, 32, 224, 224), dtype="float32") = R.subtract(lv, lv_1)
            lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][2], axis=[0, 2, 3])
            lv3: R.Tensor((1, 32, 1, 1), dtype="float32") = R.add(lv2, R.const(9.9999997473787516e-06, "float32"))
            lv4: R.Tensor((1, 32, 1, 1), dtype="float32") = R.sqrt(lv3)
            lv5: R.Tensor((1, 32, 224, 224), dtype="float32") = R.divide(lv1, lv4)
            lv6: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][3], axis=[0, 2, 3])
            lv7: R.Tensor((1, 32, 224, 224), dtype="float32") = R.multiply(lv5, lv6)
            lv8: R.Tensor((1, 32, 1, 1), dtype="float32") = R.expand_dims(metadata["relax.expr.Constant"][4], axis=[0, 2, 3])
            lv9: R.Tensor((1, 32, 224, 224), dtype="float32") = R.add(lv7, lv8)
            lv1_1: R.Tuple(R.Tensor((1, 32, 224, 224), dtype="float32"), R.Tensor((32,), dtype="float32"), R.Tensor((32,), dtype="float32")) = lv9, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2]
            lv2_1: R.Tensor((1, 32, 224, 224), dtype="float32") = lv1_1[0]
            R.output(lv2_1)
        return lv2_1

# Metadata omitted. Use show_meta=True in script() method to show it.
```
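The printed IR makes the opportunity explicit: everything after the `conv2d` is a per-channel affine map. A minimal numpy sketch of the folding arithmetic (names such as `w_fold` and `b_fold` are illustrative, not the pass's internals):

```python
import numpy as np

eps = 1e-5
w = np.random.rand(32, 3, 3, 3).astype(np.float32)  # OIHW conv weight
gamma, beta, mean, var = (np.random.rand(32).astype(np.float32) for _ in range(4))

s = gamma / np.sqrt(var + eps)       # per-output-channel scale
w_fold = w * s.reshape(32, 1, 1, 1)  # folded weight: conv(x, w_fold) == s * conv(x, w)
b_fold = beta - mean * s             # folded bias, added after the conv

# For any conv output y (NCHW), batch_norm(y) equals s * y + b_fold.
y = np.random.rand(1, 32, 8, 8).astype(np.float32)
bn = (y - mean.reshape(1, -1, 1, 1)) / np.sqrt(var.reshape(1, -1, 1, 1) + eps)
bn = bn * gamma.reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1)
folded = y * s.reshape(1, -1, 1, 1) + b_fold.reshape(1, -1, 1, 1)
np.testing.assert_allclose(bn, folded, rtol=1e-5, atol=1e-5)
```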
Fold BatchNorm into Conv2D:
```python
# Apply the BatchNorm-folding pass, then constant-fold the rewritten weights
mod_fold = relax.transform.FoldBatchnormToConv2D()(mod_fold)
mod_fold = relax.transform.FoldConstant()(mod_fold)
ex_fold = tvm.compile(mod_fold, target)
vm_fold = relax.VirtualMachine(ex_fold, tvm.cpu())
out_fold = vm_fold["main"](data_in)
```
Verify the results:
```python
# The optimized model must match the baseline within a tolerance of 1e-5
tvm.testing.assert_allclose(out.numpy(), out_fold.numpy(), rtol=1e-5, atol=1e-5)
mod_fold.show()
```
```python
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 32, 224, 224), dtype="float32"):
        with R.dataflow():
            lv5: R.Tensor((1, 32, 224, 224), dtype="float32") = R.nn.conv2d(data, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv2: R.Tensor((1, 32, 224, 224), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][1])
            R.output(lv2)
        return lv2

# Metadata omitted. Use show_meta=True in script() method to show it.
```
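The folded module is just a `conv2d` followed by a single `add` of a precomputed bias tensor. The `PyExprVisitor` imported at the top can confirm this structurally; a minimal sketch (the `OpCounter` helper is illustrative, not part of TVM):

```python
@visitor
class OpCounter(PyExprVisitor):
    """Count how often each operator is called inside a Relax function."""

    def __init__(self):
        self.counts = {}

    def visit_call_(self, call) -> None:
        # Each operator call sits in its own binding, so this hook sees all of them.
        if isinstance(call.op, tvm.ir.Op):
            self.counts[call.op.name] = self.counts.get(call.op.name, 0) + 1

counter = OpCounter()
counter.visit_expr(mod_fold["main"])
print(counter.counts)  # expected: {'relax.nn.conv2d': 1, 'relax.add': 1}
```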