# NameSupply tests (also covers GlobalVarSupply)

import tvm
import tvm.testing

from tvm import relay
from tvm.ir import GlobalVar, structural_equal
from tvm.ir.supply import NameSupply
from tvm.ir.supply import GlobalVarSupply


def test_name_supply():
    """NameSupply honours reserved names and prefixes the names it generates."""
    supply = NameSupply("prefix")
    supply.reserve_name("test")
    assert supply.contains_name("test")

    # "test" is already taken, so the generated name gets a numeric suffix
    # (and the supply's prefix).
    generated = supply.fresh_name("test")
    assert generated == "prefix_test_1"

    # The default lookup finds "test_1" (the generated, prefixed name), but
    # the lookup with False as the second argument does not: the default
    # applies the prefix, False checks the bare name.
    assert supply.contains_name("test_1")
    assert not supply.contains_name("test_1", False)

    # Only one fresh name has been generated so far, so no "_2" variant exists.
    assert not supply.contains_name("test_2")


def test_global_var_supply_from_none():
    """A default-constructed GlobalVarSupply resolves reserved globals and mints distinct fresh ones."""
    supply = GlobalVarSupply()
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    # Looking up the reserved name yields a global structurally equal to it.
    looked_up = supply.unique_global_for("test")
    assert structural_equal(looked_up, reserved)

    # A fresh global requested under the same name must differ from it.
    fresh = supply.fresh_global("test")
    assert not structural_equal(fresh, reserved)


def test_global_var_supply_from_name_supply():
    """A GlobalVarSupply backed by a prefixed NameSupply separates bare and default lookups."""
    prefixed_names = NameSupply("prefix")
    supply = GlobalVarSupply(prefixed_names)
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    # With False as the second argument, "test" resolves to the reserved global.
    bare_lookup = supply.unique_global_for("test", False)
    assert structural_equal(bare_lookup, reserved)

    # The default lookup produces a global that is not the reserved one.
    default_lookup = supply.unique_global_for("test")
    assert not structural_equal(default_lookup, reserved)


def test_global_var_supply_from_ir_mod():
    """A GlobalVarSupply seeded from an IRModule resolves the module's existing global."""
    lhs, rhs = relay.var("x"), relay.var("y")
    module = tvm.IRModule()
    existing = GlobalVar("test")
    module[existing] = relay.Function([lhs, rhs], relay.add(lhs, rhs))
    supply = GlobalVarSupply(module)

    # Generate a fresh global first, so the lookups below run after it exists.
    fresh = supply.fresh_global("test", False)

    # With False as the second argument, "test" maps to the module's global ...
    assert structural_equal(supply.unique_global_for("test", False), existing)
    # ... while the default lookup and the fresh global both differ from it.
    assert not structural_equal(supply.unique_global_for("test"), existing)
    assert not structural_equal(fresh, existing)