import tvm
from tvm import relay


@tvm.instrument.pass_instrument
class PassCounter:
    """Pass instrument that counts passes executed inside a PassContext.

    ``counts`` is reset to 0 when the PassContext is entered and again when it
    is exited. It is incremented by one before every pass runs.
    """

    def __init__(self):
        # Just setting a garbage value to test set_up callback
        self.counts = 1234

    def enter_pass_ctx(self):
        # Overwrite the garbage value from __init__; the test checks that the
        # counter reads 0 immediately after entering the context.
        self.counts = 0

    def exit_pass_ctx(self):
        # Reset on exit too; the test checks the counter is 0 again once the
        # ``with`` block has been left.
        self.counts = 0

    def run_before_pass(self, module, info):
        # Invoked before each pass; ``module`` and ``info`` are not used.
        self.counts += 1

    def get_counts(self):
        """Return the number of passes counted so far."""
        return self.counts


def test_print_debug_callback():
    """Check that PassCounter's context and before-pass callbacks fire.

    Builds the Relay function ``(x + x) * 2`` over a float32 tensor of shape
    (1, 2, 3) and runs a Sequential of InferType, FoldConstant and
    DeadCodeElimination. The run happens under a PassContext instrumented with
    a PassCounter. The test asserts the counter value on entering the
    context, after running the passes, and after leaving the context.
    """
    shape = (1, 2, 3)
    tp = relay.TensorType(shape, "float32")
    x = relay.var("x", tp)
    y = relay.add(x, x)
    y = relay.multiply(y, relay.const(2, "float32"))
    func = relay.Function([x], y)

    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.FoldConstant(),
            relay.transform.DeadCodeElimination(),
        ]
    )

    mod = tvm.IRModule({"main": func})

    pass_counter = PassCounter()
    with tvm.transform.PassContext(opt_level=3, instruments=[pass_counter]):
        # Should be reset when entering pass context
        assert pass_counter.get_counts() == 0
        mod = seq(mod)

        # NOTE(review): three passes are listed above but six are counted;
        # per the TODO below this reflects extra passes triggered by the
        # current function-pass behavior -- confirm against TVM internals.
        # TODO(@jroesch): when we remove new fn pass behavior we need to
        # change this back to match correct behavior
        assert pass_counter.get_counts() == 6

    # Should be cleaned up after exiting pass context
    assert pass_counter.get_counts() == 0