# PyTorch resnet18 on TVM Relax: convert a pretrained torchvision
# resnet18 to a Relax IRModule, optimize and fuse it, then verify the
# compiled output against the PyTorch reference.

import torch
from torchvision import models
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# Graph-simplification pipeline: fold batch-norm into the preceding
# conv2d weights, constant-fold, and drop redundant reshape ops.
_fold_passes = [
    relax.transform.FoldBatchnormToConv2D(),
    relax.transform.FoldConstant(),
    relax.transform.RemoveRedundantReshape(),
]
fold_pipeline = tvm.transform.Sequential(_fold_passes)

# Create the pretrained PyTorch model in inference (eval) mode.
torch_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()
shape = [1, 3, 224, 224]
input_info = [(shape, "float32")]
# Trace the model with torch.fx and convert it to a Relax IRModule.
with torch.no_grad():
    graph_model = torch.fx.symbolic_trace(torch_model)
    mod = from_fx(graph_model, input_info)
# First round of optimization using the fold pipeline defined above.
run_mod = fold_pipeline(mod)
run_mod.show()
from tvm.relax.dpl.pattern import wildcard, is_op, is_const, make_fused_bias_activation_pattern


def make_fused_op_bias_activation_pattern(op_name="relax.nn.conv2d", activation="relax.nn.relu"):
    """Build a dataflow pattern matching any of three fused forms of *op_name*:
    op + bias + activation, op + activation (no bias), or op + bias (no activation).
    """
    bias_act = make_fused_bias_activation_pattern(op_name, with_bias=True, activation=activation)
    act_only = make_fused_bias_activation_pattern(op_name, with_bias=False, activation=activation)
    bias_only = make_fused_bias_activation_pattern(op_name, with_bias=True)
    # Same alternation order as listed above: try the fullest fusion first.
    return bias_act | act_only | bias_only

compiler = "ccompiler"
# (suffix, pattern) pairs handed to the external codegen; the compiler
# name is prefixed onto each suffix below.
_named_patterns = [
    ("conv2d_bias_relu", make_fused_op_bias_activation_pattern("relax.nn.conv2d")),
    ("matmul_bias_relu", make_fused_op_bias_activation_pattern("relax.matmul")),
    ("add_activation", make_fused_bias_activation_pattern("relax.add", with_bias=False, activation="relax.nn.relu")),
    ("max_pool2d", is_op("relax.nn.max_pool2d")(wildcard())),
    ("adaptive_avg_pool2d", is_op("relax.nn.adaptive_avg_pool2d")(wildcard())),
    ("reshape", is_op("relax.reshape")(wildcard(), wildcard())),
]
patterns = [(f"{compiler}.{suffix}", pat) for suffix, pat in _named_patterns]
# Fusion pipeline: group pattern matches into composite functions for the
# external compiler, merge them, fuse remaining ops, and drop dead code.
fuse_pipeline = tvm.transform.Sequential([
    relax.transform.FuseOpsByPattern(patterns, bind_constants=True),
    relax.transform.MergeCompositeFunctions(),
    relax.transform.FuseOps(),
    # FIX: the pass lives in relax.transform — `relax.DeadCodeElimination`
    # does not exist and raised AttributeError at pipeline construction.
    relax.transform.DeadCodeElimination(),
])
run_mod2 = fuse_pipeline(run_mod)
run_mod2.show()
import numpy as np

# Compile the fused module to an executable and wrap it in a Relax VM.
with tvm.transform.PassContext(opt_level=3):
    executable = tvm.compile(run_mod2, target="llvm")
    vm = relax.VirtualMachine(executable, tvm.cpu())

device = tvm.cpu()
sample = np.random.rand(1, 3, 224, 224).astype("float32")
vm_args = [tvm.nd.array(sample, device)]

# Run through the stateful VM interface and compare with PyTorch.
vm.set_input("main", *vm_args)
vm.invoke_stateful("main")
relax_out = vm.get_outputs("main")
with torch.no_grad():
    ref_out = torch_model(torch.from_numpy(sample))
np.testing.assert_allclose(relax_out.numpy(), ref_out.numpy(), rtol=1e-7, atol=1e-5)