TVM Relay Build Module

Reference source code:

  • tvm/python/tvm/relay/build_module.py

  • tvm/src/relay/backend/build_module.cc

A key function to understand:

TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      Map<String, Constant> params = args[1];
      std::unordered_map<std::string, runtime::NDArray> params_;
      for (const auto& kv : params) {
        params_[kv.first] = kv.second->data;
      }
      *rv = relay::backend::BindParamsByName(args[0], params_);
    });

This code registers a global function in TVM that binds parameters into a Relay expression.

The function is registered under the name relay.build_module.BindParamsByName as a global function, via the TVM_REGISTER_GLOBAL macro. Its body is a lambda expression. The lambda takes the map of Constant values passed in as args[1], extracts the underlying runtime::NDArray from each Constant, and then binds that data into the Relay expression passed in as args[0] through the relay::backend::BindParamsByName API. The extraction step is staged through the std::unordered_map variable named params_.
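Because it is registered as a global, the function can be looked up from Python by its registered name. Below is a minimal sketch (assuming a standard TVM build with Relay available) that calls the global directly through tvm.get_global_func; the variable names are illustrative:

import numpy as np
import tvm
from tvm import relay

# Look up the global registered by TVM_REGISTER_GLOBAL; this is the same
# function the Python wrapper bind_params_by_name calls under the hood.
bind_params = tvm.get_global_func("relay.build_module.BindParamsByName")

x = relay.var("x", shape=(1,), dtype="float32")
w = relay.var("w", shape=(1,), dtype="float32")
func = relay.Function([x, w], relay.add(x, w))

# The C++ side expects Map<String, Constant>; the FFI converts a Python
# dict of str -> relay.Constant automatically.
params = {"w": relay.const(np.array([1.0], dtype="float32"))}
bound = bind_params(func, params)
print(bound)  # %w has been replaced by the constant; only %x remains free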

Relay expressions in TVM are typically used to describe deep learning models. After backend compilation and optimization, these expressions are lowered to machine code for fast model prediction and inference. In this process, the model definition, its parameters, and the data flowing through the compiled and optimized computation all need to be bound together effectively and efficiently, and the BindParamsByName API plays an important role in that task.

tvm.relay.create_executor

tvm.relay.create_executor(kind="debug", mod=None, device=None, target="llvm", params=None)

  • kind (str): The type of executor. debug is for the interpreter, graph for the graph executor, aot for the AOT executor, and vm for the virtual machine.

  • mod (IRModule): The Relay module containing a collection of functions.

  • device (Device): The device on which to execute the code.

  • target: Any multi-target-like object; see tvm.target.Target.canon_multi_target(). For homogeneous compilation, this is the unique build target. For heterogeneous compilation, a dictionary or list of possible build targets. Note: although this API allows multiple targets, it does not allow multiple devices, so heterogeneous compilation is not yet supported.

  • params (dict[str, NDArray]): Input parameters to the graph that do not change during inference time.

Returns: tvm.relay.backend.interpreter.Executor

A simple example:

import tvm
import numpy as np

x = tvm.relay.var("x", tvm.relay.TensorType([1], dtype="float32"))
expr = tvm.relay.add(x, tvm.relay.Constant(tvm.nd.array(np.array([1], dtype="float32"))))
executor = tvm.relay.create_executor(
    kind="vm", mod=tvm.IRModule.from_expr(tvm.relay.Function([x], expr))
)
executor.evaluate()(np.array([2], dtype="float32"))
# <tvm.nd.NDArray shape=(1,), cpu(0)>
# array([3.], dtype=float32)
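
The kind argument selects the runtime. Reusing x and expr from above, the same module can also be run with the graph executor:

# kind="graph" builds and runs the module with the graph executor;
# evaluate() again returns a Python callable.
graph_exec = tvm.relay.create_executor(
    kind="graph", mod=tvm.IRModule.from_expr(tvm.relay.Function([x], expr))
).evaluate()
graph_exec(np.array([2], dtype="float32"))  # -> array([3.], dtype=float32)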

The params argument is mainly consumed by:

def bind_params_by_name(func, params):
    """Bind params to function by name.
    This could be useful when assembling custom Relay optimization
    passes that involve constant folding.

    Parameters
    ----------
    func : relay.Function
        The function to bind parameters to.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    Returns
    -------
    func : relay.Function
        The function with parameters bound
    """
    inputs = _convert_param_map(params)
    return _build_module.BindParamsByName(func, inputs)
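
As the docstring suggests, binding parameters as constants is what enables constant folding. A minimal sketch of that workflow (the variable names here are illustrative):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name

x = relay.var("x", shape=(1, 4), dtype="float32")
w = relay.var("w", shape=(4, 4), dtype="float32")
# add(w, w) depends only on the parameter, so it can be folded away
func = relay.Function([x, w], relay.nn.dense(x, relay.add(w, w)))

params = {"w": tvm.nd.array(np.ones((4, 4), dtype="float32"))}
func = bind_params_by_name(func, params)

mod = tvm.IRModule.from_expr(func)
mod = relay.transform.FoldConstant()(mod)
print(mod["main"])  # add(w, w) has been precomputed into a single constant

Inside create_executor itself, the same helper is applied to mod["main"]: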

...
raw_targets = Target.canon_multi_target(target)
if mod is None:
    mod = IRModule()
if device is not None:
    assert device.device_type == raw_targets[0].get_target_device_type()
else:
    # Derive the default device from the first target.
    device = _nd.device(raw_targets[0].get_target_device_type(), 0)

if params is not None:
    mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))
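
In other words, when no device is passed in, create_executor derives the default device from the first canonical target. A small sketch of that derivation (assuming a recent TVM where Target.canon_multi_target is available):

import tvm
from tvm.target import Target

raw_targets = Target.canon_multi_target("llvm")
# For "llvm" the target device type is the CPU, so this yields cpu(0)
dev = tvm.device(raw_targets[0].get_target_device_type(), 0)
print(dev)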

Now look at an example with parameters:

import torch
import numpy as np
import tvm

def assert_shapes_match(tru, est):
    """Verfiy whether the shapes are equal"""
    if tru.shape != est.shape:
        msg = "Output shapes {} and {} don't match"
        raise AssertionError(msg.format(tru.shape, est.shape))

torch.set_grad_enabled(False)

class Conv2D(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, *args):
        return self.softmax(self.conv(args[0]))
        
input_shape = [1, 3, 10, 10]
baseline_model = Conv2D().float().eval()
input_data = torch.rand(input_shape).float()
baseline_input = [input_data]
with torch.no_grad():
    baseline_outputs = baseline_model(*[input.clone() for input in baseline_input])
if isinstance(baseline_outputs, tuple):
    baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
else:
    baseline_outputs = (baseline_outputs.cpu().numpy(),)
trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input])
trace = trace.float().eval()
input_names = [f"input{idx}" for idx, _ in enumerate(baseline_input)]
input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
mod, params = tvm.relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map=None)
print(mod["main"])
for arg in mod["main"].params[: len(input_names)]:
    assert arg.name_hint in input_names
compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input]))
The Relay IR printed by print(mod["main"]):

fn (%input0: Tensor[(1, 3, 10, 10), float32] /* span=aten::_convolution_0.input0:0:0 */, %aten::_convolution_0.weight: Tensor[(6, 3, 7, 7), float32] /* span=aten::_convolution_0.weight:0:0 */) {
  %0 = nn.conv2d(%input0, %aten::_convolution_0.weight, padding=[0, 0, 0, 0], channels=6, kernel_size=[7, 7]) /* span=aten::_convolution_0:0:0 */;
  nn.softmax(%0, axis=1) /* span=aten::softmax_0:0:0 */
}

The rest of the example compiles the module for each target and checks the results against PyTorch:
kind = "graph"
targets = ["llvm"]
# targets = ["llvm", "cuda"]
check_correctness = True
rtol = 1e-5
atol = 1e-5
expected_ops = ['nn.conv2d']
for target in targets:
    if not tvm.runtime.enabled(target):
        continue
    dev = tvm.device(target, 0)
    executor = tvm.relay.create_executor(
        kind, mod=mod, device=dev, target=target, params=params
    ).evaluate()
    result = executor(**compiled_input)
    if not isinstance(result, list):
        result = [result]

    for i, baseline_output in enumerate(baseline_outputs):
        output = result[i].asnumpy()
        assert_shapes_match(baseline_output, output)
        if check_correctness:
            np.testing.assert_allclose(baseline_output, output, rtol=rtol, atol=atol)
    
    def visit(op):
        if isinstance(op, tvm.ir.op.Op):
            if op.name in expected_ops:
                expected_ops.remove(op.name)

    tvm.relay.analysis.post_order_visit(mod["main"].body, visit)

    if expected_ops:
        msg = "TVM Relay do not contain expected ops {}"
        raise AssertionError(msg.format(expected_ops))
Running the example may log the following warning; it only means AutoTVM has no tuning records for these operators, so default schedules are used:

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.