# 注解目标设备
import os
import sys
import numpy as np
import tvm
import tvm.relay.testing
from tvm.relay import transform
from tvm import relay
from tvm import runtime
from tvm.contrib import utils
## 注解 DNNL
def annotated(dtype, ishape, w1shape):
    """Build a module with two chained grouped conv2d ops whose outputs are added.

    Both convolutions use the same ``weight1`` variable (3x3 kernel, padding 1,
    32 groups); the second convolution consumes the output of the first.

    Args:
        dtype: Data type given to the ``data`` and ``weight1`` variables.
        ishape: Shape of the ``data`` variable.
        w1shape: Shape of the ``weight1`` variable.

    Returns:
        The ``tvm.IRModule`` created from the function over ``data`` and ``weight1``.
    """
    conv_attrs = {"kernel_size": (3, 3), "padding": (1, 1), "groups": 32}
    inp = relay.var("data", shape=(ishape), dtype=dtype)
    kernel = relay.var("weight1", shape=(w1shape), dtype=dtype)
    first = relay.nn.conv2d(inp, kernel, **conv_attrs)
    second = relay.nn.conv2d(first, kernel, **conv_attrs)
    func = relay.Function([inp, kernel], relay.add(first, second))
    return tvm.IRModule.from_expr(func)
# Build the example module for a float32 (1, 32, 14, 14) input and a (32, 1, 3, 3) weight.
dtype = "float32"
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
mod = annotated(dtype, ishape, w1shape)
# Run the passes in sequence: annotate for "dnnl", infer types, then partition the graph.
mod = transform.AnnotateTarget("dnnl")(mod)
mod = relay.transform.InferType()(mod)
mod = transform.PartitionGraph()(mod)
print(mod)
def @main(%data: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %weight1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */) -> Tensor[(1, 32, 14, 14), float32] {
%0 = @tvmgen_default_dnnl_main_0(%data, %weight1) /* ty=Tensor[(1, 32, 14, 14), float32] */;
%1 = @tvmgen_default_dnnl_main_3(%0, %weight1) /* ty=Tensor[(1, 32, 14, 14), float32] */;
@tvmgen_default_dnnl_main_2(%0, %1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}
def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_0_i1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_0") -> Tensor[(1, 32, 14, 14), float32] {
nn.conv2d(%dnnl_0_i0, %dnnl_0_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */
}
def @tvmgen_default_dnnl_main_2(%dnnl_2_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_2_i1: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_2") -> Tensor[(1, 32, 14, 14), float32] {
add(%dnnl_2_i0, %dnnl_2_i1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}
def @tvmgen_default_dnnl_main_3(%dnnl_3_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_3_i1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_3") -> Tensor[(1, 32, 14, 14), float32] {
nn.conv2d(%dnnl_3_i0, %dnnl_3_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */
}
## 注解多后端设备
@tvm.ir.register_op_attr("nn.relu", "target.test")
def relu(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``nn.relu`` under ``target.test``; accepts every call."""
    return True
def before():
    """Return a module computing ``add(abs(r), abs(r))`` with ``r = relu(x)``.

    ``x`` is a (10, 10) variable; the single relu result feeds two separate
    ``abs`` calls.
    """
    inp = relay.var("x", shape=(10, 10))
    activated = relay.nn.relu(inp)
    # Two distinct abs nodes sharing the same relu input.
    branches = [relay.abs(activated) for _ in range(2)]
    func = relay.Function([inp], relay.add(*branches))
    return tvm.IRModule.from_expr(func)
# Annotate the module for the "test" target and print the annotated IR.
result = transform.AnnotateTarget("test")(before())
print(result)
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
%0 = annotation.compiler_begin(%x, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%1 = nn.relu(%0) /* ty=Tensor[(10, 10), float32] */;
%2 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%3 = annotation.compiler_begin(%2, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%4 = abs(%3) /* ty=Tensor[(10, 10), float32] */;
%5 = annotation.compiler_end(%4, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%6 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%7 = annotation.compiler_begin(%6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%8 = abs(%7) /* ty=Tensor[(10, 10), float32] */;
%9 = annotation.compiler_end(%8, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%10 = annotation.compiler_begin(%5, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%11 = annotation.compiler_begin(%9, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%12 = add(%10, %11) /* ty=Tensor[(10, 10), float32] */;
annotation.compiler_end(%12, compiler="default") /* ty=Tensor[(10, 10), float32] */
}
target = "test_type_propagation"


@tvm.ir.register_op_attr("nn.relu", "target." + target)
def relu(expr):  # pylint: disable=unused-variable
    """Accept an ``nn.relu`` call only when its first argument's checked dtype is float32."""
    first_arg = expr.args[0]
    return first_arg.checked_type.dtype == "float32"
def before():
    """Return a module applying ``nn.relu`` twice to a (10, 10) variable ``x``."""
    inp = relay.var("x", shape=(10, 10))
    twice = relay.nn.relu(relay.nn.relu(inp))
    return tvm.IRModule.from_expr(relay.Function([inp], twice))
# True is passed as the second positional argument of AnnotateTarget.
# NOTE(review): presumably the "annotate non-call ops" switch used by later cells — confirm.
res = transform.AnnotateTarget(target, True)(before())
# Bare expression; NOTE(review): presumably shown as notebook cell output — confirm.
res
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
%0 = annotation.compiler_begin(%x, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */;
%1 = nn.relu(%0) /* ty=Tensor[(10, 10), float32] */;
%2 = annotation.compiler_end(%1, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */;
%3 = annotation.compiler_begin(%2, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */;
%4 = nn.relu(%3) /* ty=Tensor[(10, 10), float32] */;
annotation.compiler_end(%4, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */
}
## read & write
target = "relu"


@tvm.ir.register_op_attr("nn.relu", "target." + target)
def annotate(expr):
    """Predicate registered for ``nn.relu`` under ``target.relu``; accepts every call."""
    return True
def before():
    """Return a module that writes ``relu(read(ref))`` back into a reference holding 1.0."""
    cell = relay.expr.RefCreate(relay.const(1.0))
    activated = relay.nn.relu(relay.expr.RefRead(cell))
    return tvm.IRModule.from_expr(relay.expr.RefWrite(cell, activated))
def after(annotate_non_call_ops):
    """Build the hand-annotated counterpart of ``before()``.

    The relu call is always wrapped in begin/end annotations for the module-level
    ``target``. When ``annotate_non_call_ops`` is true, the constant, the
    reference, the read and the write are additionally wrapped in "default"
    annotations.

    Args:
        annotate_non_call_ops: Whether the non-call nodes get "default" annotations.

    Returns:
        The ``tvm.IRModule`` created from the (optionally annotated) RefWrite.
    """

    def begin_default(node):
        # Wrap only when non-call ops are being annotated.
        if annotate_non_call_ops:
            return relay.annotation.compiler_begin(node, "default")
        return node

    def end_default(node):
        if annotate_non_call_ops:
            return relay.annotation.compiler_end(node, "default")
        return node

    created = relay.expr.RefCreate(begin_default(relay.const(1.0)))
    # The write side and the read side each get their own end/begin pair.
    write_ref = begin_default(end_default(created))
    read_ref = begin_default(end_default(created))
    loaded = end_default(relay.expr.RefRead(read_ref))
    relu_in = relay.annotation.compiler_begin(loaded, target)
    relu_out = relay.annotation.compiler_end(relay.nn.relu(relu_in), target)
    stored = relay.expr.RefWrite(write_ref, begin_default(relu_out))
    return tvm.IRModule.from_expr(end_default(stored))
# Run AnnotateTarget with only the target name, leaving the second argument at its default.
result = transform.AnnotateTarget(target)(before())
# Bare expression; NOTE(review): presumably shown as notebook cell output — confirm.
result
def @main() -> () {
%0 = annotation.compiler_begin(1f /* ty=float32 */, compiler="default") /* ty=float32 */;
%1 = ref(%0);
%2 = annotation.compiler_end(%1, compiler="default") /* ty=ref(float32) */;
%3 = annotation.compiler_begin(%2, compiler="default") /* ty=ref(float32) */;
%4 = annotation.compiler_end(%1, compiler="default") /* ty=ref(float32) */;
%5 = annotation.compiler_begin(%4, compiler="default") /* ty=ref(float32) */;
%6 = ref_read(%5);
%7 = annotation.compiler_end(%6, compiler="default") /* ty=float32 */;
%8 = annotation.compiler_begin(%7, compiler="relu") /* ty=float32 */;
%9 = nn.relu(%8) /* ty=float32 */;
%10 = annotation.compiler_end(%9, compiler="relu") /* ty=float32 */;
%11 = annotation.compiler_begin(%10, compiler="default") /* ty=float32 */;
%12 = ref_write(%3, %11);
annotation.compiler_end(%12, compiler="default") /* ty=() */
}
# Same input module, this time passing False explicitly as the second argument.
result = transform.AnnotateTarget(target, False)(before())
# Bare expression; NOTE(review): presumably shown as notebook cell output — confirm.
result
def @main() -> () {
%0 = ref(1f /* ty=float32 */);
%1 = ref_read(%0);
%2 = annotation.compiler_begin(%1, compiler="relu") /* ty=float32 */;
%3 = nn.relu(%2) /* ty=float32 */;
%4 = annotation.compiler_end(%3, compiler="relu") /* ty=float32 */;
ref_write(%0, %4)
}
## tuple
target = "test_tuple"


@tvm.ir.register_op_attr("nn.relu", "target." + target)
def relu(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``nn.relu`` under ``target.test_tuple``; accepts every call."""
    return True
@tvm.ir.register_op_attr("concatenate", "target." + target)
def concatenate(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``concatenate`` under the current ``target``; accepts every call."""
    return True
def before():
    """Return a module concatenating ``relu(x)`` and ``relu(y)`` along axis 1.

    ``x`` and ``y`` are both (10, 5) variables.
    """
    lhs = relay.var("x", shape=(10, 5))
    rhs = relay.var("y", shape=(10, 5))
    activated = (relay.nn.relu(lhs), relay.nn.relu(rhs))
    joined = relay.concatenate(activated, axis=1)
    return tvm.IRModule.from_expr(relay.Function([lhs, rhs], joined))
def after(annotate_non_call_ops):
    """Build the hand-annotated counterpart of ``before()`` for the tuple case.

    Each relu and the concatenate are wrapped in begin/end annotations for the
    module-level ``target``. When ``annotate_non_call_ops`` is true, the tuple
    feeding the concatenate is wrapped as well.

    Args:
        annotate_non_call_ops: Whether the tuple node gets its own annotations.

    Returns:
        The ``tvm.IRModule`` created from the annotated function.
    """
    x, y = (relay.var(name, shape=(10, 5)) for name in ("x", "y"))
    relu_begins = [relay.annotation.compiler_begin(v, target) for v in (x, y)]
    relu_calls = [relay.nn.relu(b) for b in relu_begins]
    relu_ends = [relay.annotation.compiler_end(r, target) for r in relu_calls]

    if annotate_non_call_ops:
        # The tuple gets its own region: begin on each field, end on the tuple.
        fields = [relay.annotation.compiler_begin(e, target) for e in relu_ends]
        packed = relay.annotation.compiler_end(relay.Tuple(fields), target)
    else:
        packed = relay.Tuple(relu_ends)

    concat_in = relay.annotation.compiler_begin(packed, target)
    concat_out = relay.op._make.concatenate(concat_in, 1)
    func = relay.Function([x, y], relay.annotation.compiler_end(concat_out, target))
    return tvm.IRModule.from_expr(func)
# Compare the pass output against the hand-built module for both flag values.
for annotate_non_call_ops in (False, True):
    result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
    expected = transform.InferType()(after(annotate_non_call_ops))
    assert tvm.ir.structural_equal(expected, result)
## composite_function
def before():
    """Return a module whose main function calls a composite add+relu function.

    The inner function carries the attribute ``Composite="test.add_relu"`` and
    is called on the outer (10, 10) parameters ``a`` and ``b``.
    """
    outer_a = relay.var("a", shape=(10, 10))
    outer_b = relay.var("b", shape=(10, 10))

    # Composite add_relu function with its own parameters.
    lhs = relay.var("in_1", shape=(10, 10))
    rhs = relay.var("in_2", shape=(10, 10))
    body = relay.nn.relu(relay.add(lhs, rhs))
    composite = relay.Function([lhs, rhs], body).with_attr("Composite", "test.add_relu")

    # Outer function that calls the composite on its own inputs.
    call = relay.Call(composite, [outer_a, outer_b])
    return tvm.IRModule.from_expr(relay.Function([outer_a, outer_b], call))
def after():
    """Return the hand-annotated module for the composite-function case.

    Same structure as ``before()``, with each call argument wrapped in
    ``compiler_begin`` and the call result wrapped in ``compiler_end`` for the
    "test" compiler.
    """
    outer_a = relay.var("a", shape=(10, 10))
    outer_b = relay.var("b", shape=(10, 10))

    # Composite add_relu function with its own parameters.
    lhs = relay.var("in_1", shape=(10, 10))
    rhs = relay.var("in_2", shape=(10, 10))
    composite = relay.Function([lhs, rhs], relay.nn.relu(relay.add(lhs, rhs)))
    composite = composite.with_attr("Composite", "test.add_relu")

    # Annotated call of the composite function.
    wrapped_args = [relay.annotation.compiler_begin(arg, "test") for arg in (outer_a, outer_b)]
    call_end = relay.annotation.compiler_end(relay.Call(composite, wrapped_args), "test")
    return tvm.IRModule.from_expr(relay.Function([outer_a, outer_b], call_end))
# Annotate for "test" and compare against the type-inferred hand-built module.
result = transform.AnnotateTarget("test")(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)
## double_target
@tvm.ir.register_op_attr("nn.relu", "target.double.A")
def relu(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``nn.relu`` under ``target.double.A``; accepts every call."""
    return True
def before():
    """Return a module built directly from ``nn.relu`` of a (10, 5) variable ``x``."""
    activated = relay.nn.relu(relay.var("x", shape=(10, 5)))
    return tvm.IRModule.from_expr(activated)
# Applying the same annotation pass a second time must leave the module unchanged.
for annotate_non_call_ops in (True, False):
    mod = before()
    mod1 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod)
    mod2 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod1)
    assert tvm.ir.structural_equal(mod1, mod2)
## different_target
@tvm.ir.register_op_attr("nn.relu", "target.different.A")
def relu(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``nn.relu`` under ``target.different.A``; accepts every call."""
    return True
# NOTE(review): this function is named `relu` but is registered for the `add`
# op; the name looks like a copy-paste leftover — confirm before renaming.
@tvm.ir.register_op_attr("add", "target.different.B")
def relu(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``add`` under ``target.different.B``; accepts every call."""
    return True
def before():
    """Return a module computing ``add(r, r)`` with ``r = relu(x)`` for a (10, 5) ``x``."""
    activated = relay.nn.relu(relay.var("x", shape=(10, 5)))
    doubled = relay.add(activated, activated)
    return tvm.IRModule.from_expr(doubled)
# Annotating for A then B in two passes must match one pass given both targets.
for annotate_non_call_ops in (True, False):
    mod = before()
    mod1 = transform.AnnotateTarget("different.A", annotate_non_call_ops)(mod)
    mod1 = transform.AnnotateTarget("different.B", annotate_non_call_ops)(mod1)
    mod2 = transform.AnnotateTarget(["different.A", "different.B"], annotate_non_call_ops)(mod)
    assert tvm.ir.structural_equal(mod1, mod2)
## multiple_runs
@tvm.ir.register_op_attr("nn.relu", "target.A")
def relu(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``nn.relu`` under ``target.A``; accepts every call."""
    return True
@tvm.ir.register_op_attr("add", "target.B")
def add(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``add`` under ``target.B``; accepts every call."""
    return True
def before():
    """Return a module computing ``add(abs(r), relu(r))`` with ``r = relu(x)``.

    ``x`` is a (10, 5) variable.
    """
    inp = relay.var("x", shape=(10, 5))
    first_relu = relay.nn.relu(inp)
    total = relay.add(relay.abs(first_relu), relay.nn.relu(first_relu))
    return tvm.IRModule.from_expr(relay.Function([inp], total))
# Two successive single-target passes must match one pass given both targets.
for annotate_non_call_ops in (True, False):
    mod = transform.AnnotateTarget("A", annotate_non_call_ops)(before())
    mod = transform.AnnotateTarget("B", annotate_non_call_ops)(mod)
    expected = transform.AnnotateTarget(["A", "B"], annotate_non_call_ops)(before())
    assert tvm.ir.structural_equal(expected, mod)
## ends_with_tuple
trgt = "clip"


# NOTE(review): named `relu` although it is registered for the `clip` op —
# the name looks like a copy-paste leftover; confirm before renaming.
@tvm.ir.register_op_attr("clip", "target." + trgt)
def relu(expr):  # pylint: disable=unused-variable
    """Predicate registered for ``clip`` under ``target.clip``; accepts every call."""
    return True
def get_model(get_item):
    """Return a model ending in a tuple of two clips (or one item of that tuple).

    A shared ``clip(a, 0, 255)`` feeds two further clips, which are packed into
    a tuple as ``(clip 16..31, clip 0..15)``.

    Args:
        get_item: When true the function returns element 1 of the tuple,
            otherwise the tuple itself.

    Returns:
        The ``tvm.IRModule`` created from the function over ``a``.
    """
    a = relay.var("a", shape=(1, 16, 16, 4), dtype="uint8")
    z = relay.op.clip(a, 0, 255)
    b = relay.op.clip(z, 0, 15)
    c = relay.op.clip(z, 16, 31)
    t = relay.Tuple((c, b))
    tgi = relay.TupleGetItem(t, 1) if get_item else t
    foo = relay.Function([a], tgi)
    # Build the module from the explicit function; it was previously created
    # and then ignored in favour of the bare expression.
    return tvm.IRModule.from_expr(foo)
def get_expected(annotate_non_call_ops, get_item):
    """Build the hand-annotated counterpart of ``get_model(get_item)``.

    Every clip call is wrapped in begin/end annotations for the module-level
    ``trgt``. When ``annotate_non_call_ops`` is true, the tuple (and the
    tuple-get-item, if present) are wrapped as well.

    Args:
        annotate_non_call_ops: Whether the non-call nodes get annotations.
        get_item: Whether the function returns element 1 of the tuple instead
            of the tuple itself.

    Returns:
        The ``tvm.IRModule`` created from the annotated function.
    """

    def begin(node):
        return relay.annotation.compiler_begin(node, trgt)

    def end(node):
        return relay.annotation.compiler_end(node, trgt)

    def begin_non_call(node):
        # Extra annotation applied only when non-call ops are annotated.
        return begin(node) if annotate_non_call_ops else node

    def end_non_call(node):
        return end(node) if annotate_non_call_ops else node

    inp = relay.var("a", shape=(1, 16, 16, 4), dtype="uint8")
    shared = relay.op.clip(begin(inp), 0, 255)
    # Each consumer of the shared clip gets its own end/begin pair.
    low = begin_non_call(end(relay.op.clip(begin(end(shared)), 0, 15)))
    high = begin_non_call(end(relay.op.clip(begin(end(shared)), 16, 31)))
    result = end_non_call(relay.Tuple((high, low)))
    if get_item:
        result = end_non_call(relay.TupleGetItem(begin_non_call(result), 1))
    return tvm.IRModule.from_expr(relay.Function([inp], result))
# Check every combination of tuple/get-item output and annotation flag.
for get_item in (True, False):
    for annotate_non_call_ops in (False, True):
        mod = get_model(get_item)
        mod = transform.AnnotateTarget("clip", annotate_non_call_ops)(mod)
        expected = transform.InferType()(get_expected(annotate_non_call_ops, get_item))
        assert tvm.ir.structural_equal(expected, mod)
## 注解目标-其他
def test_if_else():
    """Test that If-else nodes compile correctly when surrounded by supported nodes."""
    target = "test_if_else"

    # Register equal, tanh, sigmoid and erf as supported by the test target.
    @tvm.ir.register_op_attr("equal", "target." + target)
    def equal(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("tanh", "target." + target)
    def tanh(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("sigmoid", "target." + target)
    def sigmoid(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("erf", "target." + target)
    def erf(expr):  # pylint: disable=unused-variable
        return True

    def before():
        """Build ``erf(if equal(e1, e2) then tanh(data) else sigmoid(data))``."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")
        eq = relay.equal(eq1, eq2)

        true_branch = relay.tanh(data)
        false_branch = relay.sigmoid(data)
        ife = relay.If(eq, true_branch, false_branch)
        out = relay.erf(ife)
        func = relay.Function([data, eq1, eq2], out)
        mod = tvm.IRModule.from_expr(func)

        return mod

    def after():
        """Build the same graph with every call wrapped in begin/end annotations."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")

        cb_1 = relay.annotation.compiler_begin(eq1, target)
        cb_2 = relay.annotation.compiler_begin(eq2, target)

        equality_condition = relay.equal(cb_1, cb_2)
        ce_1 = relay.annotation.compiler_end(equality_condition, target)

        # if condition
        cb_3 = relay.annotation.compiler_begin(data, target)
        true_branch = relay.tanh(cb_3)
        ce_2 = relay.annotation.compiler_end(true_branch, target)

        # else condition
        cb_4 = relay.annotation.compiler_begin(data, target)
        false_branch = relay.sigmoid(cb_4)
        ce_3 = relay.annotation.compiler_end(false_branch, target)

        if_condition = relay.If(ce_1, ce_2, ce_3)
        cb_5 = relay.annotation.compiler_begin(if_condition, target)
        erf_out = relay.erf(cb_5)
        ce_4 = relay.annotation.compiler_end(erf_out, target)
        func = relay.Function([data, eq1, eq2], ce_4)
        mod = tvm.IRModule.from_expr(func)
        return mod

    # The same expected module is used for both flag values.
    expected = transform.InferType()(after())
    for annotate_non_call_ops in [True, False]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        assert tvm.ir.structural_equal(expected, result)
def test_while_let():
    """Test that let nodes compile correctly when surrounded by other nodes."""
    target = "test_while_let"

    # Register less, add and zeros_like as supported by the test target.
    @tvm.ir.register_op_attr("less", "target." + target)
    def less(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("add", "target." + target)
    def add(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("zeros_like", "target." + target)
    def zeros_like(expr):  # pylint: disable=unused-variable
        return True

    def before():
        """Build a let-bound recursive loop function and its initial call."""
        var1 = relay.var("var1", shape=(2,))
        var2 = relay.var("var2", shape=(), dtype="int32")
        var3 = relay.var("var3", shape=(2,))
        cond = relay.less(var2, relay.const(10, dtype="int32"))

        loop = relay.var("while_loop")
        ii = var2 + relay.const(1, dtype="int32")
        ss = var3 + var1
        true_branch = loop(ii, ss)
        ife = relay.If(cond, true_branch, var3)
        func_1 = relay.Function([var2, var3], ife)

        ret = relay.Let(loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1)))
        func_2 = relay.Function([var1], ret)
        mod = tvm.IRModule.from_expr(func_2)
        return mod

    def after(annotate_non_call_ops):
        """Build the hand-annotated graph.

        Op calls are wrapped for ``target``; the calls to the loop variable and
        their arguments get "default" annotations only when
        ``annotate_non_call_ops`` is true.
        """
        var1 = relay.var("var1", shape=(2,))
        var2 = relay.var("var2", shape=(), dtype="int32")
        var3 = relay.var("var3", shape=(2,))
        var4 = relay.const(10, dtype="int32")

        cb_1 = relay.annotation.compiler_begin(var2, target)
        cb_2 = relay.annotation.compiler_begin(var4, target)

        less_condition = relay.less(cb_1, cb_2)
        ce_1 = relay.annotation.compiler_end(less_condition, target)

        loop = relay.var("while_loop")

        # if condition
        cb_3 = relay.annotation.compiler_begin(var2, target)
        cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target)
        add_op_1 = relay.add(cb_3, cb_4)
        ce_2 = relay.annotation.compiler_end(add_op_1, target)

        cb_5 = relay.annotation.compiler_begin(ce_2, "default") if annotate_non_call_ops else ce_2

        cb_6 = relay.annotation.compiler_begin(var3, target)
        cb_7 = relay.annotation.compiler_begin(var1, target)
        add_op_2 = relay.add(cb_6, cb_7)
        ce_3 = relay.annotation.compiler_end(add_op_2, target)

        cb_8 = relay.annotation.compiler_begin(ce_3, "default") if annotate_non_call_ops else ce_3

        true_branch = loop(cb_5, cb_8)  # while loop
        ce_4 = (
            relay.annotation.compiler_end(true_branch, "default")
            if annotate_non_call_ops
            else true_branch
        )
        if_condition = relay.If(ce_1, ce_4, var3)
        const_1 = relay.const(0, dtype="int32")
        cb_9 = (
            relay.annotation.compiler_begin(const_1, "default")
            if annotate_non_call_ops
            else const_1
        )
        cb_10 = relay.annotation.compiler_begin(var1, target)
        zeros_like = relay.zeros_like(cb_10)
        ce_5 = relay.annotation.compiler_end(zeros_like, target)
        cb_11 = relay.annotation.compiler_begin(ce_5, "default") if annotate_non_call_ops else ce_5
        while_condition = loop(cb_9, cb_11)
        ce_6 = (
            relay.annotation.compiler_end(while_condition, "default")
            if annotate_non_call_ops
            else while_condition
        )

        func_1 = relay.Function([var2, var3], if_condition)
        ret = relay.Let(loop, func_1, ce_6)
        func_2 = relay.Function([var1], ret)
        mod = tvm.IRModule.from_expr(func_2)
        return mod

    for annotate_non_call_ops in [False, True]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        expected = transform.InferType()(after(annotate_non_call_ops))
        assert tvm.ir.structural_equal(expected, result)
def test_if_free_vars():
    """Test that If-else nodes compile correctly when surrounded by free variables."""
    target = "test_if_free_vars"

    # Register equal, sigmoid and erf as supported by the test target.
    @tvm.ir.register_op_attr("equal", "target." + target)
    def equal(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("sigmoid", "target." + target)
    def sigmoid(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("erf", "target." + target)
    def erf(expr):  # pylint: disable=unused-variable
        return True

    def before():
        """Build ``erf(if equal(e1, e2) then zeros else sigmoid(data))``."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")
        eq = relay.equal(eq1, eq2)

        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
        false_branch = relay.sigmoid(data)
        ife = relay.If(eq, true_branch, false_branch)
        out = relay.erf(ife)

        func = relay.Function([data, eq1, eq2], out)
        mod = tvm.IRModule.from_expr(func)

        return mod

    def after():
        """Build the annotated graph; the ``zeros`` branch is left unwrapped."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")

        cb_1 = relay.annotation.compiler_begin(eq1, target)
        cb_2 = relay.annotation.compiler_begin(eq2, target)

        equality_condition = relay.equal(cb_1, cb_2)
        ce_1 = relay.annotation.compiler_end(equality_condition, target)

        # if condition
        true_branch = relay.zeros(shape=(1, 32), dtype="float32")

        # else condition
        cb_3 = relay.annotation.compiler_begin(data, target)
        false_branch = relay.sigmoid(cb_3)
        ce_2 = relay.annotation.compiler_end(false_branch, target)

        if_condition = relay.If(ce_1, true_branch, ce_2)
        cb_4 = relay.annotation.compiler_begin(if_condition, target)
        erf_out = relay.erf(cb_4)
        ce_3 = relay.annotation.compiler_end(erf_out, target)
        func = relay.Function([data, eq1, eq2], ce_3)
        mod = tvm.IRModule.from_expr(func)
        return mod

    for annotate_non_call_ops in [True, False]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        expected = transform.InferType()(after())
        assert tvm.ir.structural_equal(expected, result)
def test_free_vars_zeros():
    """Test that free variables compile correctly on their own."""
    target = "test_free_vars_zeros"

    def before():
        """Build a zero-argument function returning ``zeros`` with ``shape=(0)``."""
        func = relay.Function([], relay.zeros(shape=(0), dtype="float32"))
        mod = tvm.IRModule.from_expr(func)
        return mod

    def after():
        """Build the expected module; identical in construction to ``before()``."""
        func = relay.Function([], relay.zeros(shape=(0), dtype="float32"))
        mod = tvm.IRModule.from_expr(func)
        return mod

    result = transform.AnnotateTarget(target)(before())
    expected = transform.InferType()(after())
    assert tvm.ir.structural_equal(expected, result)
def test_empty_tuple():
    """An empty tuple should behave just like a call with no args (see above test)."""
    target = "test_empty_tuple"

    def before():
        """Build a zero-argument function returning an empty tuple."""
        func = relay.Function([], relay.Tuple([]))
        mod = tvm.IRModule.from_expr(func)
        return mod

    def after():
        """Build the expected module; identical in construction to ``before()``."""
        func = relay.Function([], relay.Tuple([]))
        mod = tvm.IRModule.from_expr(func)
        return mod

    for annotate_non_call_ops in [True, False]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        expected = transform.InferType()(after())
        assert tvm.ir.structural_equal(expected, result)