构建 Relay 模型

构建 Relay 模型

直接构建:

import testing
import tvm
import tvm.testing
from tvm import relay
from tvm.target.target import Target
from tvm.relay.backend import Runtime, Executor, graph_executor_codegen
def add(shape, dtype):
    """Return an IRModule wrapping a two-input ``relay.add`` function.

    Two Relay variables named "A" and "B" are created with the given
    ``shape`` and ``dtype``; the function built from them applies
    ``relay.add`` to the pair and is wrapped with
    ``tvm.IRModule.from_expr``.
    """
    operands = tuple(
        relay.var(name, shape=shape, dtype=dtype) for name in ("A", "B")
    )
    func = relay.Function(operands, relay.add(*operands))
    return tvm.IRModule.from_expr(func)
# Build the add module for (1, 8) float32 inputs and display it.
mod = add((1, 8), "float32")
mod.show()
def @main(%A: Tensor[(1, 8), float32], %B: Tensor[(1, 8), float32]) {
  add(%A, %B)
}
# Create an "llvm" target and pass it through canon_target_and_host, which
# yields a (target, host) pair; target_host is not used in the lines below.
target = tvm.target.Target("llvm")
target, target_host = tvm.target.Target.canon_target_and_host(target)
# Run relay.optimize for this target; only the first of its two results is kept.
mod, _ = relay.optimize(mod, target)
# Run the graph executor code generator on the module's "main" function.
# codegen returns three values; only the second (lowered_funcs) is kept here.
grc = graph_executor_codegen.GraphExecutorCodegen(None, target)
_, lowered_funcs, _ = grc.codegen(mod, mod["main"])
# Build the lowered functions through the private backend entry point; the
# result is discarded. NOTE(review): _backend is a private API and may change
# between TVM releases - confirm against the installed version.
_ = relay.backend._backend.build(lowered_funcs, target)
"""Test to build a nn model and get schedule_record from build_module"""
from tvm.relay import testing
def check_schedule(executor):
    """Assert that each non-main function stores a consistent schedule.

    Walks ``executor.function_metadata``. Entries whose name contains
    "main" are skipped. For every other entry, the first value of
    ``relay_primfuncs`` must carry a "schedule" attribute whose
    ``schedule_record`` and ``primitive_record`` have equal length.

    Raises ``AssertionError`` when either check fails.
    """
    for name, metadata in executor.function_metadata.items():
        # Only converted ops are inspected.
        if "main" in name:
            continue
        prim_funcs = list(metadata.relay_primfuncs.values())
        first_func = prim_funcs[0]
        # The schedule has to be present in the function attributes.
        assert "schedule" in first_func.attrs
        schedule = first_func.attrs["schedule"]
        num_schedules = len(schedule.schedule_record)
        num_primitives = len(schedule.primitive_record)
        assert num_schedules == num_primitives

# Fetch the MobileNet workload from tvm.relay.testing (batch size 1, float32).
relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32")
target_llvm = tvm.target.Target("llvm")
# NOTE(review): presumably this option is what makes the "schedule" attribute
# that check_schedule looks for available - confirm against the TVM docs.
config = {"te.keep_schedule_record": True}

# Both builds run inside the same pass context (opt_level=3 plus the config above).
with tvm.transform.PassContext(opt_level=3, config=config):
    # First build: explicit "aot" executor with the "cpp" runtime.
    aot_executor_factory = relay.build(
        relay_mod,
        target_llvm,
        runtime=Runtime("cpp"),
        executor=Executor("aot"),
        params=params,
    )
    # Second build: no executor or runtime passed, so relay.build's defaults
    # apply (presumably the graph executor, going by the variable name).
    graph_executor_factory = relay.build(
        relay_mod,
        target_llvm,
        params=params,
    )

# Verify the stored schedule records for both build results.
check_schedule(aot_executor_factory)
check_schedule(graph_executor_factory)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.