import tvm
from tvm import relay
@tvm.instrument.pass_instrument
class PassCounter:
    """Pass instrument that counts ``run_before_pass`` callbacks.

    The counter is set to 0 by both ``enter_pass_ctx`` and ``exit_pass_ctx``
    and is bumped by one on every ``run_before_pass`` call, so it reflects how
    many passes have started since the last enter/exit event.
    """

    # Deliberately bogus starting value: a test can tell whether the
    # enter-context callback actually fired, because that callback
    # overwrites this with 0.
    _UNSET_COUNT = 1234

    def __init__(self):
        self.counts = self._UNSET_COUNT

    def enter_pass_ctx(self):
        """Reset the counter when the owning context is entered."""
        self.counts = 0

    def exit_pass_ctx(self):
        """Reset the counter again when the owning context is exited."""
        self.counts = 0

    def run_before_pass(self, module, info):
        """Record that one more pass is about to run.

        ``module`` and ``info`` are part of the instrument callback signature
        and are not inspected here.
        """
        self.counts = self.counts + 1

    def get_counts(self):
        """Return the number of passes recorded since the last reset."""
        return self.counts
def test_print_debug_callback():
    """Check that a ``PassCounter`` instrument tracks passes run in a context.

    Builds ``(x + x) * 2`` over a float32 tensor of shape (1, 2, 3), then runs
    a three-pass ``Sequential`` (InferType, FoldConstant,
    DeadCodeElimination) inside a ``PassContext`` carrying the counter. It
    then asserts three things:
    - the counter reads 0 right after entering the context;
    - the counter reads 6 after the pipeline has run;
    - the counter reads 0 again once the context has been exited.
    """
    tensor_ty = relay.TensorType((1, 2, 3), "float32")
    inp = relay.var("x", tensor_ty)
    summed = relay.add(inp, inp)
    body = relay.multiply(summed, relay.const(2, "float32"))
    main_fn = relay.Function([inp], body)

    passes = [
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.DeadCodeElimination(),
    ]
    pipeline = tvm.transform.Sequential(passes)

    module = tvm.IRModule({"main": main_fn})

    counter = PassCounter()
    ctx = tvm.transform.PassContext(opt_level=3, instruments=[counter])
    with ctx:
        # Entering the context must have reset the counter (its constructor
        # stores a non-zero placeholder).
        assert counter.get_counts() == 0
        module = pipeline(module)

        # TODO(@jroesch): when the new fn pass behavior is removed, change
        # this expected count back to match the correct behavior.
        assert counter.get_counts() == 6

    # Leaving the context must have reset the counter again.
    assert counter.get_counts() == 0