# VM instrument

import numpy as np
import tvm

from tvm import relax
from tvm.relax.testing import nn
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
def get_exec(data_shape):
    """Build and compile a small Linear/ReLU model with bound random weights.

    The "main" function is Linear -> ReLU -> Linear -> ReLU, both Linear
    layers without bias. Two random 64x64 float32 matrices are bound to the
    parameters named "linear_weight" and "linear_weight1", and the module is
    compiled for the "llvm" target.

    Parameters
    ----------
    data_shape : sequence of int
        Shape of the "data" placeholder; ``data_shape[1]`` is used as the
        input feature size of the first Linear layer.
        NOTE(review): the weights are fixed at 64x64, so this presumably has
        to be 64 for the binding to match -- confirm against callers.

    Returns
    -------
    The result of ``tvm.compile`` on the parameter-bound module.
    """
    bb = relax.BlockBuilder()
    # Draw the weights in this order so the global NumPy RNG stream is
    # consumed the same way on every call.
    first_weight = np.random.randn(64, 64).astype("float32")
    second_weight = np.random.randn(64, 64).astype("float32")
    hidden_dim = first_weight.shape[0]
    second_in, second_out = second_weight.shape

    with bb.function("main"):
        layers = (
            nn.Linear(data_shape[1], hidden_dim, bias=False),
            nn.ReLU(),
            nn.Linear(second_in, second_out, bias=False),
            nn.ReLU(),
        )
        net = nn.Sequential(*layers)
        inp = nn.Placeholder(data_shape, name="data")
        result = net(inp)
        # The data placeholder comes first, followed by the model parameters.
        bb.emit_func_output(result, params=[inp, *net.parameters()])

    bind_weights = relax.transform.BindParams(
        "main",
        {"linear_weight": first_weight, "linear_weight1": second_weight},
    )
    bound_mod = bind_weights(bb.get())

    return tvm.compile(bound_mod, "llvm")


def get_exec_int32(data_shape):
    """Build and compile a single-ReLU model over an int32 input.

    Parameters
    ----------
    data_shape : sequence of int
        Shape of the int32 "data" placeholder.

    Returns
    -------
    The result of ``tvm.compile`` on the built module for the "llvm" target.
    """
    bb = relax.BlockBuilder()

    with bb.function("main"):
        relu = nn.ReLU()
        inp = nn.Placeholder(data_shape, dtype="int32", name="data")
        # The output is emitted first, then the parameter list is collected.
        bb.emit_func_output(relu(inp), params=[inp, *relu.parameters()])

    return tvm.compile(bb.get(), "llvm")

# 测试 conv2d_cpu（自定义 instrument 回调：统计 matmul/relu 的调用次数并跳过 matmul）

# Random float32 input of shape (1, 64).
data_np = np.random.randn(1, 64).astype("float32")
# Compile the model for this input shape and wrap it in a VM on the CPU device.
ex = get_exec(data_np.shape)
vm = relax.VirtualMachine(ex, tvm.cpu())
# Call counter keyed by (name, before_run); updated by `instrument`.
hit_count = {}

def instrument(func, name, before_run, ret_val, *args):
    """Count VM instrument callbacks and ask the VM to skip "matmul".

    Every invocation increments ``hit_count[(name, before_run)]`` in the
    module-level ``hit_count`` dict.

    Parameters
    ----------
    func : callable
        The function being instrumented; asserted to be callable.
    name : str
        Name the callback is invoked with.
    before_run : bool
        When true, ``ret_val`` is asserted to be ``None``.
    ret_val : object
        Value passed by the VM alongside ``before_run``.
    *args
        Remaining arguments; unused here.

    Returns
    -------
    ``relax.VMInstrumentReturnKind.SKIP_RUN`` when ``name == "matmul"``,
    otherwise ``None``.
    """
    key = (name, before_run)
    hit_count[key] = hit_count.get(key, 0) + 1

    assert callable(func)
    assert not before_run or ret_val is None

    return relax.VMInstrumentReturnKind.SKIP_RUN if name == "matmul" else None

# Install the callback, then run "main" once on the random input.
vm.set_instrument(instrument)
vm["main"](tvm.nd.array(data_np))
# Expected counts: "matmul" is seen twice with before_run=True and never with
# before_run=False (the callback returns SKIP_RUN for it), while "relu" is
# seen twice for each value of before_run.
assert hit_count[("matmul", True)] == 2
assert ("matmul", False) not in hit_count
assert hit_count[("relu", True)] == 2
assert hit_count[("relu", False)] == 2
/tmp/ipykernel_4023119/2478082253.py:17: UserWarning: Returning type `vm.Storage` which is not registered via register_object, fallback to Object
  vm["main"](tvm.nd.array(data_np))

# 测试 LibCompareVMInstrument

# Random normal samples cast to int32, shape (1, 64).
data_np = np.random.randn(1, 64).astype("int32")
# Compile the int32 ReLU model and wrap it in a VM on the CPU device.
ex = get_exec_int32(data_np.shape)
vm = relax.VirtualMachine(ex, tvm.cpu())
# Compare against the library module (the VM module's first imported module).
# NOTE(review): presumably each instrumented call is checked against the
# same-named function in that module -- confirm in LibCompareVMInstrument.
cmp = LibCompareVMInstrument(vm.module.imported_modules[0], tvm.cpu(), verbose=False)
vm.set_instrument(cmp)
# Run "main"; as the last expression its result is the cell's displayed value.
vm["main"](tvm.nd.array(data_np))
/tmp/ipykernel_4023119/2973990986.py:7: UserWarning: Returning type `vm.Storage` which is not registered via register_object, fallback to Object
  vm["main"](tvm.nd.array(data_np))
<tvm.nd.NDArray shape=(1, 64), cpu:0>
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 1,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 2,
        0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int32)