TVM Pass Instrument
import tvm
import tvm.relay as relay
from tvm.relay.testing import resnet
from tvm.contrib.download import download_testdata
from tvm.relay.build_module import bind_params_by_name
from tvm.ir.instrument import (
PassTimingInstrument,
pass_instrument,
)
# Model configuration. These values are passed straight to the workload
# builder below, so editing them here changes the generated model.
batch_size = 1
num_of_image_class = 1000
image_shape = (3, 224, 224)
output_shape = (batch_size, num_of_image_class)

# Build an 18-layer ResNet Relay module together with its parameters.
relay_mod, relay_params = resnet.get_workload(
    num_layers=18,
    batch_size=batch_size,
    num_classes=num_of_image_class,
    image_shape=image_shape,
)

# Register a timing instrument with the PassContext; passes invoked inside
# the `with` block below are profiled by it.
timing_inst = PassTimingInstrument()
with tvm.transform.PassContext(instruments=[timing_inst]):
    relay_mod = relay.transform.InferType()(relay_mod)
    relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
    # Get the profile results before exiting the context.
    profiles = timing_inst.render()
print("Printing results of timing profile...")
print(profiles)
Printing results of timing profile...
InferType: 11228us [11228us] (53.85%; 53.85%)
FoldScaleAxis: 9621us [7us] (46.15%; 46.15%)
        FoldConstant: 9614us [2007us] (46.11%; 99.92%)
                InferType: 7607us [7607us] (36.49%; 79.13%)