回调 Relay 全局变量

回调 Relay 全局变量#

tvm.relay.expr.Call 可以在模块中回调全局变量。

比如,定义 add 算子:

from tvm import relay
from tvm.ir import IRModule

data = relay.var("data")
bias = relay.var("bias")
add_op = data + bias

初始化 Relay 模块:

mod = IRModule()

创建并绑定 add 全局函数到 mod

mod['AddFunc'] = relay.Function([data, bias], add_op)

下面定义三个变量用于定义“连加”运算:

a, b, c = [relay.var(name) for name in "abc"]

获取全局变量 add

add_gvar = mod.get_global_var('AddFunc')

定义“连加”运算:

add_01 = relay.Call(add_gvar, [a, b])
add_012 = relay.Call(add_gvar, [c, add_01])

绑定到 mod

mod['main'] = relay.Function([a, b, c], add_012)
print(mod)
def @AddFunc(%data, %bias) {
  add(%data, %bias)
}

def @main(%a, %b, %c) {
  %0 = @AddFunc(%a, %b);
  @AddFunc(%c, %0)
}