VTA Demo

VTA Demo

import set_env
from PIL import Image
import numpy as np
import tvm
from tvm import rpc, autotvm, relay
from tvm.ir.transform import PassContext
from tvm_book.transforms.common import FuseTransform
from tvm_book.transforms import graphpack
import vta
from vta.testing import simulator
# Make sure that TVM was compiled with RPC=1: this demo imports `tvm.rpc`,
# and its (currently commented-out) deployment steps upload the compiled
# library to a remote RPC server. Raise explicitly instead of using
# `assert`, which is stripped when Python runs with `-O`.
if not tvm.runtime.enabled("rpc"):
    raise RuntimeError("TVM was not compiled with RPC support (RPC=1)")

加载前端模型:

import torch
from torch import nn

class Model(nn.Module):
    """Minimal test network: a single 1x1 convolution followed by a flatten.

    The defaults reproduce the original demo layer
    ``nn.Conv2d(3, 50, 1, 1, 0, bias=False, groups=1)``.

    Args:
        in_channels: number of input channels (default 3).
        out_channels: number of convolution output channels (default 50).
    """

    def __init__(self, in_channels=3, out_channels=50):
        super().__init__()
        # 1x1 kernel, stride 1, no padding, no bias, ungrouped.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            groups=1,
        )

    def forward(self, x):
        """Apply the convolution and flatten all axes except the batch axis.

        Because the kernel is 1x1 with stride 1 and no padding, the spatial
        size is preserved and the result has shape
        ``(N, out_channels * H * W)`` for an ``(N, in_channels, H, W)`` input.
        """
        x = self.conv1(x)
        # Equivalent to x.view(x.size(0), -1) here. torch.reshape is kept,
        # presumably so the traced graph lowers to the Relay `reshape` that
        # make_reshape_output_pattern matches -- confirm before changing.
        return torch.reshape(x, (x.size(0), -1))
# Build one random float input of shape (1, 3, 8, 8), register it under the
# Relay input name "data", and trace the model with TorchScript so the Relay
# PyTorch frontend can consume it.
input_shape = (1, 3, 8, 8)
input_data = torch.rand(input_shape).float()
input_shapes = [("data", input_shape)]
model = Model().eval()
trace_model = torch.jit.trace(model, [input_data.clone()]).float().eval()
from tvm.relay.testing import run_opt_pass
import tvm
from tvm import relay
from tvm.relay import op
from tvm.relay import ExprMutator
from vta.top import graphpack
env = vta.get_env()

# Graph packing for the VTA target (see the commented-out pipeline below)
# requires BLOCK_IN == BLOCK_OUT. Check it before the comparatively
# expensive frontend conversion and quantization, and raise explicitly
# because `assert` is stripped under `python -O`.
if env.BLOCK_IN != env.BLOCK_OUT:
    raise ValueError(
        "VTA graph packing requires env.BLOCK_IN == env.BLOCK_OUT, "
        f"got {env.BLOCK_IN} and {env.BLOCK_OUT}"
    )

with autotvm.tophub.context(env.target):
    # Frontend compilation: traced TorchScript module -> Relay module.
    mod, params = relay.frontend.from_pytorch(trace_model, input_shapes)
    with PassContext(opt_level=3):
        # Quantize with a fixed global scale of 8.0. skip_conv_layers=[]
        # asks that no conv2d layer be left unquantized (TVM's qconfig
        # default skips the first one -- confirm for the installed TVM).
        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[]):
            qmod = relay.quantize.quantize(mod, params=params)
        qmod.show()
        # Remaining steps, not enabled yet (kept for reference):
        # anf = run_opt_pass(mod["main"], transform.ToANormalForm())
        # anf = graphpack.get_subgraph(
        #     mod["main"],
        #     start_name="cast",
        #     stop_name="cast",
        #     start_name_idx=None,
        #     stop_name_idx=None,
        #     count_meta={},
        # )
        # print(anf)
        # relay_prog = graphpack.graph_pack(
        #     mod["main"],
        #     env.BATCH,
        #     env.BLOCK_OUT,
        #     env.WGT_WIDTH,
        #     start_name="nn.conv2d",  # pack_dict[model][0],
        #     stop_name="multiply",
        #     device_annot=(env.TARGET == "intelfocl"),
        # )
        # with vta.build_config(
        #     opt_level=3,
        #     disabled_pass={"AlterOpLayout", "tir.CommonSubexprElimTIR"},
        # ):
        #     lib = relay.build(relay_prog, target=env.target, params=params)
        #
        # # Send the inference library to the remote RPC server
        # lib.export_library("graphlib.tar")
        # remote.upload("graphlib.tar")
        # loaded_lib = remote.load_module("graphlib.tar")
from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant

def make_preprocess_pattern():
    """Return a pattern for ``cast(clip(round(multiply(x, const))))``.

    ``x`` is any expression and ``const`` must be a constant. This is
    presumably the input-quantization chain emitted by ``relay.quantize``.
    """
    pat = is_op("multiply")(wildcard(), is_constant())
    # Wrap the scaled value in the unary ops, innermost first.
    for unary in ("round", "clip", "cast"):
        pat = is_op(unary)(pat)
    return pat

def make_conv2d_bias_pattern():
    """Return a pattern for a conv2d with a constant weight plus its tail.

    Matches ``cast(clip(right_shift(add(nn.conv2d(x, W), c1), c2)))``, where
    ``W``, ``c1`` and ``c2`` must be constants. This is presumably the
    requantization sequence ``relay.quantize`` emits after an int8 conv2d.
    """
    pat = is_op("nn.conv2d")(wildcard(), is_constant())
    # Binary ops whose second operand must be a constant.
    for binary in ("add", "right_shift"):
        pat = is_op(binary)(pat, is_constant())
    # Then the unary ops.
    for unary in ("clip", "cast"):
        pat = is_op(unary)(pat)
    return pat

# def make_output_pattern():
#     r = is_op("cast")(wildcard())
#     r = is_op("multiply")(r, wildcard())
#     return r

def make_reshape_output_pattern():
    """Return a pattern for ``multiply(cast(reshape(x)), const)``.

    The reshape input may optionally be wrapped in an
    ``annotation.stop_fusion`` call; ``const`` must be a constant.
    """
    inp = wildcard()
    # Accept either stop_fusion(inp) or inp itself.
    maybe_annotated = is_op("annotation.stop_fusion")(inp) | inp
    reshaped = is_op("reshape")(maybe_annotated)
    casted = is_op("cast")(reshaped)
    return is_op("multiply")(casted, is_constant())
compiler_name = "pack_special"
# (name, pattern) pairs for MergeComposite; every name is prefixed with
# the compiler name, e.g. "pack_special.conv2d_bias".
pattern_table = [
    (f"{compiler_name}.{suffix}", build_pattern())
    for suffix, build_pattern in (
        ("preprocess", make_preprocess_pattern),
        ("conv2d_bias", make_conv2d_bias_pattern),
        ("reshape_output", make_reshape_output_pattern),
    )
]
merge_passes = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.MergeComposite(pattern_table),
        # Not enabled yet:
        # relay.transform.AnnotateTarget([compiler_name]),
        # relay.transform.PartitionGraph(),
        relay.transform.InferType(),
        FuseTransform(),
        relay.transform.InferType(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    run_qmod = merge_passes(qmod)
run_qmod.show()
from tvm_book.transforms import graphpack
class ExprPack(ExprMutator):
    """Visitor to perform graph packing on an AST.

    Every ``nn.conv2d`` call whose result dtype is ``int32`` is rebuilt with
    VTA's blocked layouts:

    * data: ``NCHW<bfactor>n<cfactor>c``
    * kernel: ``OIHW<cfactor>o<cfactor>i``

    The kernel is packed, and for weights narrower than 8 bits it is also
    bit-packed. All other calls are rebuilt from their visited operator and
    arguments.

    NOTE(review): only the kernel is packed here. The conv2d data input is
    passed through unchanged, so it must already be in the packed layout
    (or be packed by another step) -- confirm.

    Args:
        bfactor: batch blocking factor (the demo passes ``env.BATCH``).
        cfactor: channel blocking factor (the demo passes ``env.BLOCK_OUT``).
        weight_bits: weight bit width; must be a positive divisor of 8.
    """

    def __init__(self, bfactor, cfactor, weight_bits):
        self.bfactor = bfactor
        self.cfactor = cfactor
        self.weight_bits = weight_bits
        self.start_pack = False
        # Cache Operator the algorithm matches against. Only ``conv2d`` is
        # matched in visit_call so far; the others are presumably kept for
        # packing rules that are not implemented here yet.
        self.conv2d = op.op.get("nn.conv2d")
        self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
        self.add = op.op.get("add")
        self.multiply = op.op.get("multiply")
        self.bias_add = op.op.get("nn.bias_add")
        self.pad = op.op.get("nn.pad")
        self.upsampling = op.op.get("nn.upsampling")
        self.reshape = op.op.get("reshape")
        # Number of int32 conv2d calls rewritten so far.
        self.number_of_conv2d = 0
        self.unpack_transpose = True
        super().__init__()

    def visit_call(self, call):
        """Return the rewritten version of ``call``.

        Arguments are visited first. An ``int32`` conv2d is then replaced
        by its packed form; any other call is rebuilt around the visited
        operator and arguments.

        Raises:
            ValueError: if a conv2d must be packed and ``weight_bits`` is
                not a positive divisor of 8.
        """
        odtype = graphpack._get_tensor_type(call)
        args = [self.visit(arg) for arg in call.args]

        if call.op == self.conv2d and odtype == "int32":
            self.number_of_conv2d += 1
            if self.weight_bits <= 0 or 8 % self.weight_bits != 0:
                raise ValueError(
                    "weight_bits must be a positive divisor of 8, "
                    f"got {self.weight_bits}"
                )
            # Number of weights that fit into one 8-bit lane.
            w_lanes = 8 // self.weight_bits
            data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
            kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
            data, weight = args
            # The shape comes from the type-checked original (unvisited) arg.
            kernel_shape = graphpack._to_shape(call.args[1].checked_type.shape)
            channels = call.attrs.channels
            weight, kernel_shape, channels = graphpack._weight_shape_match(
                weight, kernel_shape, channels, self.cfactor
            )
            kernel = graphpack._pack_weight(weight, kernel_shape, self.cfactor)
            # Insert bit packing for sub-byte weights. w_lanes always divides
            # 8 once weight_bits does, so no further check is needed.
            if w_lanes != 1:
                kernel = op.bitpack(kernel, lanes=w_lanes)

            return op.nn.conv2d(
                data,
                kernel,
                strides=call.attrs.strides,
                padding=call.attrs.padding,
                dilation=call.attrs.dilation,
                groups=call.attrs.groups,
                channels=channels,
                kernel_size=call.attrs.kernel_size,
                data_layout=data_layout,
                kernel_layout=kernel_layout,
                out_dtype=call.attrs.out_dtype,
            )
        # Same reconstruction as ExprMutator's default visit_call: keep the
        # call's type arguments and source span instead of dropping them.
        return relay.Call(
            self.visit(call.op), args, call.attrs, call.type_args, call.span
        )
# Take one composite function produced by the merge passes and pack its
# int32 conv2d using the VTA environment's blocking factors.
func = run_qmod["pack_special.conv2d_bias_1"]
transform = ExprPack(
    bfactor=env.BATCH,
    cfactor=env.BLOCK_OUT,
    weight_bits=env.WGT_WIDTH,
)
func = transform.visit(func)

# Inspect the packed function's first parameter and build a Relay variable
# with the same name, shape and dtype. The bare expressions below are
# notebook-style inspections; outside a notebook their results are
# discarded.
param = func.params[0]
oshape, odtype = (
    graphpack._get_tensor_shape(param),
    graphpack._get_tensor_type(param),
)
checked_type = param.checked_type
checked_type.shape
relay.var(param.name_hint, shape=oshape, dtype=odtype)
# binds = {
#     for param in func.params
# }
# new_func = relay.bind(func, binds)
# transform = PreprocessPack(1, 16)
# func = transform.visit(run_qmod["pack_special.preprocess_0"])
# from tvm.contrib import graph_executor, download
# # 下载 ImageNet categories
# categ_url = "https://github.com/uwsampl/web-data/raw/main/vta/models"
# categ_fn = "synset.txt"
# download.download(f"{categ_url}/{categ_fn}", categ_fn)
# synset = eval(open(categ_fn).read())
# # 准备用于推理的测试图像
# image = Image.open("tests/cat.jpg").resize((32, 32))
# # plt.imshow(image)
# # plt.show()
# image = np.array(image) - np.array([123.0, 117.0, 104.0])
# image /= np.array([58.395, 57.12, 57.375])
# image = image.transpose((2, 0, 1))
# image = image[np.newaxis, :]
# image = np.repeat(image, env.BATCH, axis=0)

# with autotvm.tophub.context(env.target):
#     # 生成图执行器(graph executor) `m`。
#     m = graph_executor.GraphModule(lib["default"](tvm.ext_dev(0)))
#     # 设置网络参数和输入
#     m.set_input(**params)
#     m.set_input("data", image)
#     num = 4  # 为单个度量运行模块的次数
#     rep = 3  # 测量的数量(由此得出 std dev)
#     timer = m.module.time_evaluator("run",
#                                     tvm.ext_dev(0),
#                                     number=num,
#                                     repeat=rep)
#     simulator.clear_stats()
#     timer()
#     sim_stats = simulator.stats()
#     print("\nExecution statistics:")
#     for k, v in sim_stats.items():
#         # 由于多次执行 workload,需要 normalize 统计数据。
#         # 注意,总是有一次 warm up 运行
#         # 因此,将整体统计数据除以 (num * rep + 1)
#         print(f"\t{k:<16}: {v // (num * rep + 1):>16}")
# tvm_output = m.get_output(0)
# lib.ir_mod.show()