# 解读 EnvFunc
import tvm
import tvm.testing # 保证 "attrs.TestAttrs" 被注册
# NOTE: the trailing `??` is IPython/Jupyter introspection syntax (it asks the
# interactive shell to display the object's source/documentation). It is not
# valid in a plain Python script and only works inside a notebook/IPython cell.
tvm.ir.EnvFunc??
@tvm.register_func("test.env_func")
def test(x):
    """Return ``x + 1``.

    The decorator registers this function under the global name
    "test.env_func" so it can be looked up by that name.
    """
    incremented = x + 1
    return incremented
# Look up the function registered above under "test.env_func" by its global name.
f = tvm.get_global_func("test.env_func")
# Obtain the same registration as an EnvFunc and check it reports that name.
x = tvm.ir.EnvFunc.get("test.env_func")
assert x.name == "test.env_func"
# Round-trip the EnvFunc through JSON; it is wrapped in a list on save, so the
# loaded value is indexed with [0] to get the EnvFunc back.
json_str = tvm.ir.save_json([x])
y = tvm.ir.load_json(json_str)[0]
assert y.name == x.name
# The reloaded EnvFunc can be called directly or through its `.func` attribute;
# both yield 1 + 1 == 2.
assert y(1) == 2
assert y.func(1) == 2
# Build an "attrs.TestAttrs" node whose `func` field holds the reloaded EnvFunc.
# Note that `x` is rebound here: from now on it is the attrs node, not the EnvFunc.
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4), func=y)
# NOTE(review): the name check below was already disabled in the original; the
# reason is not visible from this file.
# assert x.name == "xx"
# The padding tuple elements are read back through `.value`.
assert x.padding[0].value == 3
assert x.padding[1].value == 4
# `axis` was not passed to make_node; presumably 10 is the default declared by
# TestAttrs -- TODO confirm against the attrs definition.
assert x.axis == 10
# Round-trip the whole node through JSON and check that the embedded function
# is still an EnvFunc and still callable (10 + 1 == 11).
x = tvm.ir.load_json(tvm.ir.save_json(x))
assert isinstance(x.func, tvm.ir.EnvFunc)
assert x.func(10) == 11
# Bare expression: in a notebook cell this displays the three types as the cell output.
type(y), type(f), type(y.func)
# Output:
# (tvm.ir.base.EnvFunc,
#  tvm.runtime.packed_func.PackedFunc,
#  tvm.runtime.packed_func.PackedFunc)
import tvm
import numpy as np
import pytest
from tvm import te
from tvm.ffi.access_path import AccessPath
from tvm.script import tir as T, ir as I
def consistent_equal(x, y, map_free_vars=False):
    """Compare ``x`` and ``y`` structurally and cross-check the answer.

    Structural equality is evaluated in both argument orders and compared
    against equality of the two structural hashes, so an inconsistent
    answer raises instead of being returned silently.

    Parameters
    ----------
    x, y
        The two objects handed to ``tvm.ir.structural_equal`` and
        ``tvm.ir.structural_hash``.
    map_free_vars : bool
        Forwarded unchanged to every equality and hash call.

    Returns
    -------
    The result of ``tvm.ir.structural_equal(x, y, map_free_vars)``.

    Raises
    ------
    ValueError
        If the two argument orders disagree, or if the equality result
        disagrees with whether the two hashes are equal.
    """
    forward = tvm.ir.structural_equal(x, y, map_free_vars)
    backward = tvm.ir.structural_equal(y, x, map_free_vars)
    hash_x = tvm.ir.structural_hash(x, map_free_vars)
    hash_y = tvm.ir.structural_hash(y, map_free_vars)

    # Equality must not depend on the order of the operands.
    if forward != backward:
        message = "Non-commutative {} vs {}, sequal0={}, sequal1={}".format(
            x, y, forward, backward
        )
        raise ValueError(message)

    # NOTE: hash collisions can happen but should be rare; for these test
    # cases we rely on them not happening, so equality and hash equality
    # are expected to agree.
    hashes_match = hash_x == hash_y
    if forward != hashes_match:
        message = "Inconsistent {} vs {}, sequal={}, xhash={}, yhash={}".format(
            x, y, forward, hash_x, hash_y
        )
        raise ValueError(message)

    return forward
@tvm.register_func("test.sequal.env_func")
def test(x):
    """Return ``x + 1``.

    The decorator registers this function under the global name
    "test.sequal.env_func" so it can be looked up by that name.
    """
    incremented = x + 1
    return incremented
# Fetch the same registered function twice; the two EnvFunc handles must
# compare structurally equal (and hash equal) in either argument order.
env_func_name = "test.sequal.env_func"
x = tvm.ir.EnvFunc.get(env_func_name)
y = tvm.ir.EnvFunc.get(env_func_name)
assert consistent_equal(y, x)