回调 Relay 全局变量#
tvm.relay.expr.Call
可以在模块中回调全局变量。
比如,定义 add
算子:
from tvm import relay
from tvm.ir import IRModule
data = relay.var("data")
bias = relay.var("bias")
add_op = data + bias
初始化 Relay 模块:
mod = IRModule()
创建并绑定 add
全局函数到 mod
:
mod['AddFunc'] = relay.Function([data, bias], add_op)
下面定义三个变量用于定义“连加”运算:
a, b, c = [relay.var(name) for name in "abc"]
获取全局变量 add
:
add_gvar = mod.get_global_var('AddFunc')
定义“连加”运算:
add_01 = relay.Call(add_gvar, [a, b])
add_012 = relay.Call(add_gvar, [c, add_01])
绑定到 mod
:
mod['main'] = relay.Function([a, b, c], add_012)
print(mod)
def @AddFunc(%data, %bias) {
add(%data, %bias)
}
def @main(%a, %b, %c) {
%0 = @AddFunc(%a, %b);
@AddFunc(%c, %0)
}