Span 设置

%cd ../..
import set_env
from tools.tag_span import _set_span, _create_span, _verify_structural_equal_with_span
/media/pc/data/lxw/ai/tvm-book/doc/read
import numpy as np

from tvm import relay, testing
from tvm.relay.frontend.common import StrAttrsDict, set_span
def test_key_is_present():
    """``StrAttrsDict.has_attr`` must report True for a key it was built with."""
    present_key = "a"
    assert StrAttrsDict({present_key: 1}).has_attr(present_key)


def test_key_is_not_present():
    """``StrAttrsDict.has_attr`` must report False for a key it was never given."""
    attrs_dict = StrAttrsDict({"a": 1})
    found = attrs_dict.has_attr("b")
    assert not found

# Run both StrAttrsDict checks in order; either raises AssertionError on failure.
for _check in (test_key_is_present, test_key_is_not_present):
    _check()

测试 Span 填充开关

def _res(should_fill):
    """Return ``x`` passed through ``set_span(..., "x_var")`` under a chosen mode.

    Args:
        should_fill: when truthy, ``set_span`` runs inside
            ``testing.enable_span_filling()``; otherwise inside
            ``testing.disable_span_filling()``.
    """
    # Only the context manager differs between the two modes, so pick it first.
    span_mode = (
        testing.enable_span_filling if should_fill else testing.disable_span_filling
    )
    with span_mode():
        return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
# Expected values: with filling disabled the var must come back without a span;
# with filling enabled it must carry the "x_var" span.
enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
disable = relay.var("x", shape=(1, 64, 56, 56))

_verify_structural_equal_with_span(_res(should_fill=False), disable)
_verify_structural_equal_with_span(_res(should_fill=True), enable)
print(enable)
free_var %x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */;
%x

应为所有尚未设置 Span 的表达式标记 Span，并在遇到已设置 Span 的表达式时停止。

测试内建元组 Span

def _res():
    """Apply ``set_span(..., "tuple")`` to a plain Python tuple of two constants.

    The first constant already carries span ``"a"``; the second has no span.
    """
    tagged = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
    untagged = relay.const(np.zeros([1, 1, 1]), dtype="int64")
    pair = (tagged, untagged)
    return set_span(pair, "tuple")

def _golden():
    """Expected tuple: ``"a"`` is preserved, the formerly untagged constant gets ``"tuple"``."""
    first = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
    second = relay.const(
        np.zeros([1, 1, 1]), dtype="int64", span=_create_span("tuple")
    )
    return first, second

res_tuple, golden_tuple = _res(), _golden()
# Compare element-wise. zip() silently stops at the shorter input, so the
# lengths are checked explicitly first.
assert len(res_tuple) == len(golden_tuple)
for res_expr, golden_expr in zip(res_tuple, golden_tuple):
    _verify_structural_equal_with_span(res_expr, golden_expr)
print(golden_tuple[0])
meta[relay.Constant][0] /* span=a:0:0 */

测试内建列表 Span

def _res():
    """Apply ``set_span(..., "list")`` to a Python list of two ``TupleGetItem`` nodes.

    Inside the underlying ``relay.Tuple`` only the first constant carries a span
    (``"a"``); the second constant, the tuple and both projections have none.
    """
    const_a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
    const_b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
    tup = relay.Tuple([const_a, const_b])
    items = [relay.TupleGetItem(tup, idx) for idx in (0, 1)]
    return set_span(items, "list")

def _golden():
    """Expected list: every node except ``const_a`` (which keeps ``"a"``) carries ``"list"``."""
    const_a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
    const_b = relay.const(
        np.zeros([1, 1, 1]), dtype="int64", span=_create_span("list")
    )
    tup = relay.Tuple([const_a, const_b], span=_create_span("list"))
    return [
        relay.TupleGetItem(tup, idx, span=_create_span("list")) for idx in (0, 1)
    ]

res_list, golden_list = _res(), _golden()
# Compare element-wise. zip() silently stops at the shorter input, so the
# lengths are checked explicitly first.
assert len(res_list) == len(golden_list)
for res_expr, golden_expr in zip(res_list, golden_list):
    _verify_structural_equal_with_span(res_expr, golden_expr)

测试 relay.var Span

# set_span on a bare relay.var is expected to attach the span to the var itself.
x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
_verify_structural_equal_with_span(x, x_expected)

测试 relay.const Span

# set_span on a bare constant is expected to attach the span to the constant itself.
c_expected = relay.const(
    np.ones([64, 64, 3, 3]),
    dtype="int64",
    span=_create_span("const_c"),
)
c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c")
_verify_structural_equal_with_span(c, c_expected)

测试 relay.Call Span

def _res():
    """Build ``fn(x) { conv2d(x, w) }``, tagging only ``x`` and the conv2d call.

    The weight constant ``w`` is left without a span here.
    """
    data = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
    weight = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
    conv = relay.nn.conv2d(
        data, weight, kernel_size=(3, 3), channels=64, padding=(1, 1)
    )
    return relay.Function([data], set_span(conv, "conv2d"))

def _golden():
    """Expected function: ``x`` keeps ``"x_var"``; ``w`` and the call carry ``"conv2d"``."""
    data = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
    weight = relay.const(
        np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("conv2d")
    )
    conv = relay.nn.conv2d(
        data, weight, kernel_size=(3, 3), channels=64, padding=(1, 1)
    )
    # _set_span comes from tools.tag_span; presumably it sets the span on the
    # call node directly — confirm against that helper.
    return relay.Function([data], _set_span(conv, "conv2d"))


_verify_structural_equal_with_span(_res(), _golden())
print(_golden())
fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) {
  nn.conv2d(%x, meta[relay.Constant][0] /* span=conv2d:0:0 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* span=conv2d:0:0 */
}

测试 relay.Tuple Span

def _res():
    """Wrap ``relay.Tuple([a, b])`` in a Function, tagging ``a`` and the tuple only."""
    first = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
    second = relay.const(np.ones([1, 1, 1]), dtype="int64")
    tagged_tuple = set_span(relay.Tuple([first, second]), "t")
    return relay.Function([], tagged_tuple)

def _golden():
    """Expected function: ``a`` keeps ``"a"``; ``b`` and the tuple carry ``"t"``."""
    fields = [
        relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span(tag))
        for tag in ("a", "t")
    ]
    return relay.Function([], relay.Tuple(fields, span=_create_span("t")))

_verify_structural_equal_with_span(_res(), _golden())

测试 relay.TupleGetItem Span

def _res():
    """Project element 0 of ``Tuple([a, b])``, tagging only ``a`` and the projection."""
    first = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
    second = relay.const(np.ones([1, 1, 1]), dtype="int64")
    item = relay.TupleGetItem(relay.Tuple([first, second]), 0)
    return relay.Function([], set_span(item, "i"))

def _golden():
    """Expected function: ``a`` keeps ``"a"``; ``b``, the tuple and the projection carry ``"i"``."""
    first = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
    second = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("i"))
    tup = relay.Tuple([first, second], span=_create_span("i"))
    return relay.Function([], relay.TupleGetItem(tup, 0, span=_create_span("i")))

_verify_structural_equal_with_span(_res(), _golden())

测试 relay.Let Span

def _res():
    """Build ``fn(x) { (let x = ones(10); x + x) + zeros(10) }`` via ``set_span``.

    ``x``, the ``Let`` node, the zeros constant and the outer ``add`` are tagged
    explicitly; the bound ones constant and the inner ``add`` are left untagged.
    """
    var_x = set_span(relay.Var("x"), "x_var")
    ones = relay.const(np.ones(10))
    let_body = relay.add(var_x, var_x)
    let_expr = set_span(relay.Let(var_x, ones, let_body), "let")

    zeros = set_span(relay.const(np.zeros(10)), "zeros")
    outer_add = set_span(relay.add(let_expr, zeros), "add_2")
    return relay.Function([var_x], outer_add)

def _golden():
    """Expected function for the ``Let`` test, with every span spelled out.

    ``x`` keeps ``"x_var"`` and the zeros constant keeps ``"zeros"``. The bound
    constant, the inner ``add`` and the ``Let`` carry ``"let"``, and the outer
    ``add`` carries ``"add_2"``.
    """
    var_x = relay.Var("x", span=_create_span("x_var"))
    bound_value = relay.const(np.ones(10), span=_create_span("let"))
    let_body = _set_span(relay.add(var_x, var_x), "let")
    let_expr = relay.Let(var_x, bound_value, let_body, span=_create_span("let"))

    zeros = relay.const(np.zeros(10), span=_create_span("zeros"))
    outer_add = _set_span(relay.add(let_expr, zeros), "add_2")
    return relay.Function([var_x], outer_add)

_verify_structural_equal_with_span(_res(), _golden())

测试 relay.If Span

def _res():
    """Build ``fn(x, y) { if (x == y) { x + y } else { x - y } }`` via ``set_span``.

    ``x``, ``y``, the true branch and the ``If`` are tagged; the condition and
    the false branch are left untagged.
    """
    lhs = set_span(relay.var("x", shape=[], dtype="float32"), "x_var")
    rhs = set_span(relay.var("y", shape=[], dtype="float32"), "y_var")
    cond = relay.equal(lhs, rhs)

    then_expr = set_span(relay.add(lhs, rhs), "true_branch")
    else_expr = relay.subtract(lhs, rhs)
    branch = set_span(relay.If(cond, then_expr, else_expr), "if")
    return relay.Function([lhs, rhs], branch)

def _golden():
    """Expected function for the ``If`` test.

    ``x``/``y`` keep their own spans and the true branch keeps
    ``"true_branch"``; the condition, the false branch and the ``If`` carry ``"if"``.
    """
    lhs = relay.var("x", shape=[], dtype="float32", span=_create_span("x_var"))
    rhs = relay.var("y", shape=[], dtype="float32", span=_create_span("y_var"))
    cond = _set_span(relay.equal(lhs, rhs), "if")

    then_expr = _set_span(relay.add(lhs, rhs), "true_branch")
    else_expr = _set_span(relay.subtract(lhs, rhs), "if")
    branch = relay.If(cond, then_expr, else_expr, span=_create_span("if"))
    return relay.Function([lhs, rhs], branch)

_verify_structural_equal_with_span(_res(), _golden())

测试 relay.Function Span

def _res():
    """Build ``fn(x) { conv2d(x, w) }``, tagging only ``x`` and the Function itself."""
    data = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
    weight = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
    conv = relay.nn.conv2d(
        data, weight, padding=(1, 1), kernel_size=(3, 3), channels=64
    )
    return set_span(relay.Function([data], conv), "func")

def _golden():
    """Expected function: ``x`` keeps ``"x_var"``; ``w``, the call and the Function carry ``"func"``."""
    data = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
    weight = relay.const(
        np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("func")
    )
    conv = _set_span(
        relay.nn.conv2d(
            data, weight, padding=(1, 1), kernel_size=(3, 3), channels=64
        ),
        "func",
    )
    return relay.Function([data], conv, span=_create_span("func"))

_verify_structural_equal_with_span(_res(), _golden())

# Rebuild the same expected function at module level (mirrors _golden) and print it;
# the top-level names x, w, y and f stay bound to its pieces.
x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
w = relay.const(
    np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("func")
)
y = _set_span(
    relay.nn.conv2d(x, w, padding=(1, 1), kernel_size=(3, 3), channels=64),
    "func",
)
f = relay.Function([x], y, span=_create_span("func"))
print(f)
fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) {
  nn.conv2d(%x, meta[relay.Constant][0] /* span=func:0:0 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* span=func:0:0 */
} /* span=func:0:0 */