Conv2D Test
Define the frontend network:
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.relu = torch.nn.ReLU()
        self.conv2d2 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.relu2 = torch.nn.ReLU()
        self.conv2d3 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.conv2d4 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.conv2d2(x)
        x = self.relu2(x)
        x = self.conv2d3(x)
        x = self.conv2d4(x)
        return x
input_shape = [1, 3, 32, 32]
torch_model = M().eval()
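As a quick sanity check (a small sketch, not part of the original flow), the eager PyTorch model can be run once to confirm that an input of shape [1, 3, 32, 32] produces an output of shape [1, 16, 32, 32]:

# Optional sanity check: one eager forward pass on a random input.
with torch.no_grad():
    out = torch_model(torch.randn(*input_shape, dtype=torch.float32))
print(out.shape)  # expected: torch.Size([1, 16, 32, 32])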
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
from torch.export import export

# Give an example argument to torch.export
example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)

# Convert the model to IRModule
with torch.no_grad():
    exported_program = export(torch_model, example_args)
    run_mod = from_exported_program(exported_program, keep_params_as_input=False)
# run_mod, params = relax.frontend.detach_params(run_mod)
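The commented-out line above hints at an alternative import path: with keep_params_as_input=True, the weights become explicit function parameters and can then be split off with relax.frontend.detach_params. A sketch of that variant, kept commented so it does not replace the flow used below:

# Alternative (sketch): expose the weights as inputs, then detach them.
# with torch.no_grad():
#     exported_program = export(torch_model, example_args)
#     run_mod = from_exported_program(exported_program, keep_params_as_input=True)
# run_mod, params = relax.frontend.detach_params(run_mod)
# `params` maps a function name (e.g. "main") to its list of weight NDArrays.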
from tvm.ir import IRModule
from tvm.relax import transform
from tvm.relax.dpl.pattern import (
    DFPattern,
    wildcard,
    is_op,
)
from tvm.relax.transform import PatternCheckContext
from tvm.relax.backend.patterns import (
    make_conv2d_pattern,
    # make_attention_pattern,
    # make_stacked_attention_pattern,
    # make_layer_norm_pattern,
    # make_rms_norm_pattern,
    # make_matmul_dequantize_pattern,
    # make_matmul_multiply_pattern,
    # make_attention_rewrite_pattern,
    # make_fused_bias_activation_pattern,
    # make_residual_block_pattern,
)
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix, register_patterns
from tvm.relax import expr as _expr
def _check_conv2d(context: PatternCheckContext) -> bool:
    # lhs = context.annotated_expr["lhs"]
    # rhs = context.annotated_expr["rhs"]
    # expr = context.annotated_expr["root"]
    # assert isinstance(lhs, relax.expr.Var) and lhs.name_hint == "data"
    # assert isinstance(rhs, relax.expr.Var) and rhs.name_hint == "weight1"
    # assert isinstance(expr, relax.expr.Call) and expr.op.name == "relax.nn.conv2d"
    # return False
    # Accept every matched pattern unconditionally.
    return True
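The check above accepts every match. A stricter predicate could inspect the matched conv2d call through context.annotated_expr, e.g. only offloading NCHW/OIHW convolutions. A sketch, assuming the "root" annotation refers to the conv2d call as in the commented-out lines above:

def _check_conv2d_nchw(context: PatternCheckContext) -> bool:
    # Sketch of a stricter predicate (assumes "root" annotates the conv2d call).
    conv = context.annotated_expr["root"]
    if not isinstance(conv, relax.expr.Call):
        return False
    attrs = conv.attrs
    return attrs.data_layout == "NCHW" and attrs.kernel_layout == "OIHW"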
register_patterns(
    [
        (
            "vta.conv2d",
            *make_conv2d_pattern(
                with_bias=False,
            ),
            _check_conv2d,
        ),
        (
            "vta.conv2d_bias",
            *make_conv2d_pattern(
                with_bias=True,
            ),
            _check_conv2d,
        ),
        (
            "vta.conv2d_bias_relu",
            *make_conv2d_pattern(
                with_bias=True,
                activation="relax.nn.relu",
            ),
            _check_conv2d,
        ),
        (
            "vta.conv2d_relu",
            *make_conv2d_pattern(
                with_bias=False,
                activation="relax.nn.relu",
            ),
            _check_conv2d,
        ),
        # (
        #     "vta.attention.BS3NH",
        #     *make_stacked_attention_pattern(start_op="split", layout="BS3NH"),
        #     partial(_check_stacked_attention, layout="BS3NH"),
        # ),
        # (
        #     "vta.attention.SBN3H",
        #     *make_stacked_attention_pattern(start_op="split", layout="SBN3H"),
        #     partial(_check_stacked_attention, layout="SBN3H"),
        # ),
    ]
)
patterns = get_patterns_with_prefix("vta")
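Before building the pass pipeline, it is worth confirming which entries were picked up under the "vta" prefix (a small check; each registry entry exposes its name):

# List the pattern names registered above.
for entry in patterns:
    print(entry.name)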
seq = tvm.transform.Sequential(
    [
        # transform.LegalizeOps(enable_warning=enable_warning),
        # transform.RewriteDataflowReshape(),
        transform.AnnotateTIROpPattern(),  # annotate fusion pattern kinds on TIR PrimFuncs
        transform.FoldConstant(),          # fold constant expressions
        transform.FuseOps(),               # generic pattern-kind based fusion
        transform.FuseTIR(),               # fuse the grouped functions at the TIR level
        # group operators matching the registered "vta.*" patterns into composite functions
        transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=False),
    ]
)
mod = seq(run_mod)
# mod = transform.MergeCompositeFunctions()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_add(lv: R.Tensor((1, 16, 32, 32), dtype="float32"), param_0: R.Tensor((16, 16, 3, 3), dtype="float32"), param_1: R.Tensor((1, 16, 1, 1), dtype="float32")) -> R.Tensor((1, 16, 32, 32), dtype="float32"):
        R.func_attr({"Composite": "vta.conv2d_bias", "Primitive": True})
        with R.dataflow():
            lv6: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv, param_0, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            gv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv6, param_1)
            R.output(gv)
        return gv

    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_add_relax_nn_relu(lv: R.Tensor((1, 16, 32, 32), dtype="float32"), param_0: R.Tensor((16, 16, 3, 3), dtype="float32"), param_1: R.Tensor((1, 16, 1, 1), dtype="float32")) -> R.Tensor((1, 16, 32, 32), dtype="float32"):
        R.func_attr({"Composite": "vta.conv2d_bias_relu", "Primitive": True})
        with R.dataflow():
            lv2: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv, param_0, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv4: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv2, param_1)
            gv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv4)
            R.output(gv)
        return gv

    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_nn_relu(x: R.Tensor((1, 3, 32, 32), dtype="float32"), param_0: R.Tensor((16, 3, 3, 3), dtype="float32")) -> R.Tensor((1, 16, 32, 32), dtype="float32"):
        R.func_attr({"Composite": "vta.conv2d_relu", "Primitive": True})
        with R.dataflow():
            lv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(x, param_0, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            gv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv)
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor((1, 3, 32, 32), dtype="float32")) -> R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu(x, metadata["relax.expr.Constant"][0])
            lv_1: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_relu(lv, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2])
            lv_2: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add(lv_1, metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4])
            lv1: R.Tensor((1, 16, 32, 32), dtype="float32") = cls.fused_relax_nn_conv2d_relax_add(lv_2, metadata["relax.expr.Constant"][5], metadata["relax.expr.Constant"][6])
            gv: R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")) = (lv1,)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
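The fused module now contains one private Relax function per matched pattern, each tagged with a "Composite" attribute. A small sketch to list which composite kernels were formed:

# Inspect the composite functions created by FuseOpsByPattern.
for gvar, func in mod.functions.items():
    if func.attrs is not None and "Composite" in func.attrs:
        print(gvar.name_hint, "->", func.attrs["Composite"])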
def fvisit(expr: tvm.relax.Expr):
    # print(type(expr), expr)
    if isinstance(expr, _expr.GlobalVar):
        func = mod[expr]
        # with_attr returns a new function; `mod` itself is not modified here.
        func = func.with_attr("op_pattern", 2)
        print(func.attrs, isinstance(func, tvm.tir.PrimExpr))
    # if isinstance(expr, _expr.Call):
    #     # print(expr.attrs)
    #     print(expr.op)


expr = mod["main"]
relax.analysis.post_order_visit(expr, fvisit)
{"Composite": "vta.conv2d_relu", "Primitive": True, "op_pattern": 2} False
{"Composite": "vta.conv2d_bias_relu", "Primitive": True, "op_pattern": 2} False
{"Composite": "vta.conv2d_bias", "Primitive": True, "op_pattern": 2} False
{"Composite": "vta.conv2d_bias", "Primitive": True, "op_pattern": 2} False
# op = tvm.ir.Op.get("relax.add")
# op.set_attr("op_pattern", OpPatternKind.kElemWise)