Simulated Quantization Test

Define the frontend network:
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.relu = torch.nn.ReLU()
        self.conv2d2 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.relu2 = torch.nn.ReLU()
        self.conv2d3 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.conv2d4 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=True)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.conv2d2(x)
        x = self.relu2(x)
        x = self.conv2d3(x)
        x = self.conv2d4(x)
        return x


input_shape = [1, 3, 32, 32]
torch_model = M().eval()
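As a quick sanity check (a sketch added here, not part of the original trace), one eager forward pass confirms that the stride-1, padding-1 convolutions preserve the 32x32 spatial size:

with torch.no_grad():
    # 3 -> 16 channels in the first conv, 16 -> 16 afterwards; spatial dims unchanged.
    out = torch_model(torch.randn(*input_shape))
print(out.shape)  # torch.Size([1, 16, 32, 32])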
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
from torch.export import export

# Give an example argument to torch.export
example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)

# Convert the model to IRModule
with torch.no_grad():
    exported_program = export(torch_model, example_args)
    run_mod = from_exported_program(exported_program, keep_params_as_input=False)
    # run_mod, params = relax.frontend.detach_params(run_mod)

run_mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 3, 32, 32), dtype="float32")) -> R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(x, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv)
            lv2: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv1, metadata["relax.expr.Constant"][1], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv3: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][2], R.shape([1, 16, 1, 1]))
            lv4: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv2, lv3)
            lv5: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv4)
            lv6: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv5, metadata["relax.expr.Constant"][3], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv7: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][4], R.shape([1, 16, 1, 1]))
            lv8: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv6, lv7)
            lv9: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv8, metadata["relax.expr.Constant"][5], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv10: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][6], R.shape([1, 16, 1, 1]))
            lv11: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv9, lv10)
            gv: R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")) = (lv11,)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
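With keep_params_as_input=False the weights are embedded as relax.expr.Constant metadata, which is why each conv2d above reads its kernel from metadata["relax.expr.Constant"]. The commented-out detach_params line points at the alternative; a minimal sketch, assuming the same torch_model and example_args:

# Alternative (sketch): keep the weights as explicit function parameters.
with torch.no_grad():
    ep = export(torch_model, example_args)
    mod_p = from_exported_program(ep, keep_params_as_input=True)
# detach_params returns the module plus a dict mapping each function name
# to its list of parameter tensors, e.g. params["main"].
mod_p, params = relax.frontend.detach_params(mod_p)

With the module in hand, the next step walks main with post_order_visit and collects every conv2d call site: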
from tvm.ir.op import Op
from tvm.relax import expr as _expr
from tvm.relax.analysis import post_order_visit

# Relax registers operators under the "relax." prefix, so "nn.conv2d"
# must be matched as "relax.nn.conv2d".
specific_op_names = ["nn.conv2d"]
specific_op_names = [f"relax.{name}" for name in specific_op_names]
calls = []  # every matching call site, in post order

def fvisit(expr):
    print(type(expr))
    # Only Call nodes whose callee is a primitive Op (not a GlobalVar) can match.
    if isinstance(expr, _expr.Call) and isinstance(expr.op, Op):
        if expr.op.name in specific_op_names:
            # One could also deduplicate nodes here, e.g. by hashing
            # tvm.ir.save_json(expr); see the structural_hash sketch below.
            calls.append(expr)

post_order_visit(run_mod["main"], fvisit)

# Wrap the collected calls in a tuple so they can be dumped as a module.
# The calls still reference free DataflowVars of the original function,
# so the result is for inspection only, not for compilation.
mod = tvm.IRModule.from_expr(_expr.Tuple(calls))
mod.show()
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Var'>
<class 'tvm.relax.expr.Constant'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.Constant'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.Constant'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.Constant'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.Constant'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.Constant'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.Constant'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.ir.op.Op'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Call'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.DataflowVar'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Tuple'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.Var'>
<class 'tvm.relax.expr.ShapeExpr'>
<class 'tvm.relax.expr.SeqExpr'>
<class 'tvm.relax.expr.Function'>
# from tvm.script import ir as I

@I.ir_module
class Module:
    main = None
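The body prints as main = None, presumably because the tuple's conv2d calls still reference free variables and the printer cannot render a well-formed function. For deduplicating the collected calls, a built-in alternative to hashing the save_json output is tvm.ir.structural_hash; a minimal sketch using the calls list collected above:

# Sketch: group the collected calls by structural identity.
unique_calls = {}
for c in calls:
    # map_free_vars=True lets calls that differ only in their free variables
    # collide; the weight Constants and attributes still have to match.
    key = tvm.ir.structural_hash(c, map_free_vars=True)
    unique_calls.setdefault(key, c)
print(len(calls), "conv2d calls,", len(unique_calls), "structurally distinct")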
print(run_mod["main"])
# from tvm.script import relax as R

@R.function
def main(x: R.Tensor((1, 3, 32, 32), dtype="float32")) -> R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")):
    with R.dataflow():
        lv: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(x, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
        lv1: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv)
        lv2: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv1, metadata["relax.expr.Constant"][1], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
        lv3: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][2], R.shape([1, 16, 1, 1]))
        lv4: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv2, lv3)
        lv5: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.relu(lv4)
        lv6: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv5, metadata["relax.expr.Constant"][3], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
        lv7: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][4], R.shape([1, 16, 1, 1]))
        lv8: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv6, lv7)
        lv9: R.Tensor((1, 16, 32, 32), dtype="float32") = R.nn.conv2d(lv8, metadata["relax.expr.Constant"][5], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
        lv10: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][6], R.shape([1, 16, 1, 1]))
        lv11: R.Tensor((1, 16, 32, 32), dtype="float32") = R.add(lv9, lv10)
        gv: R.Tuple(R.Tensor((1, 16, 32, 32), dtype="float32")) = (lv11,)
        R.output(gv)
    return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
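These four conv2d call sites are where simulated quantization would be spliced in: each input (and weight) is passed through a quantize-dequantize pair so the network still computes in float32 but sees quantization error. The arithmetic being simulated, as a standalone NumPy sketch of the per-tensor affine scheme (all names here are illustrative, not TVM API):

import numpy as np

def fake_quant(x, num_bits=8):
    # Quantize-dequantize: x is mapped to (q - zp) * scale with q in [qmin, qmax].
    qmin, qmax = 0, 2**num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zp = round(qmin - float(x.min()) / scale)
    q = np.clip(np.round(x / scale) + zp, qmin, qmax)
    return ((q - zp) * scale).astype(x.dtype)

x = np.random.randn(1, 16, 32, 32).astype("float32")
print("max fake-quant error:", np.abs(fake_quant(x) - x).max())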