# Dissecting TVM's Automatic Quantization Process
This section walks through TVM's automatic quantization pipeline, using PyTorch's resnet18 model as a running example.
## Translating the PyTorch model into a Relay module
Load the PyTorch model:
def load_model(input_shape):
    from torchvision.models import resnet18, ResNet18_Weights
    import torch
    # Load pretrained weights and trace the model into TorchScript
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    data = torch.randn(*input_shape)
    model = torch.jit.trace(model.eval(), data)
    return model.eval()
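relay.frontend.from_pytorch expects a TorchScript module, which is why the model is traced first; tracing against a dummy input also freezes the input shape. An optional parity check (a sketch using torch.testing) confirms that tracing did not change the network's behavior:

import torch
from torchvision.models import resnet18, ResNet18_Weights

eager = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).eval()
traced = torch.jit.trace(eager, torch.randn(1, 3, 224, 224))
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    # traced and eager outputs should agree to float tolerance
    torch.testing.assert_close(traced(x), eager(x))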
Translate the traced PyTorch model into a Relay module:
import tvm
from tvm import relay

input_shape = 1, 3, 224, 224
input_name = "data"
traced_model = load_model(input_shape)
mod, params = relay.frontend.from_pytorch(
    traced_model,
    [(input_name, input_shape)],
    # use_parser_friendly_name=True
)
with tvm.transform.PassContext(opt_level=3):  # pre-quantization cleanup
    opt_mod = relay.quantize.prerequisite_optimize(mod, params)
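Internally, prerequisite_optimize is a small cleanup pipeline that, among other things, folds batch_norm into the neighboring convolution weights so the quantizer sees a simpler graph. A sketch of the equivalent pass sequence (based on relay/quantize/quantize.py; treat as illustrative):

# Roughly the passes prerequisite_optimize applies (illustrative sketch):
sketch_optimize = tvm.transform.Sequential([
    relay.transform.SimplifyInference(),  # expand batch_norm into mul/add
    relay.transform.FoldConstant(),
    relay.transform.FoldScaleAxis(),      # fold the mul/add into conv weights
    relay.transform.CanonicalizeOps(),
    relay.transform.FoldConstant(),
])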
Load the calibration and validation data:
import tvm.testing
from tvm import relay
from tvm.relay import transform, build_module
from tvm.relay.testing import run_opt_pass
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np
from tvm_book.data.classification import ImageFolderDataset

def preprocess_image(
    image: np.ndarray,
    size: tuple[int, int] = (224, 224),
    mean: tuple[float, ...] = (0.485, 0.456, 0.406),
    std: tuple[float, ...] = (0.229, 0.224, 0.225)
):
    """Resize to 256x256, center-crop to `size`, then normalize."""
    im = Image.fromarray(image)
    im = im.resize((256, 256), Image.Resampling.BILINEAR)
    ori_W, ori_H = im.size  # PIL reports (width, height)
    H, W = size
    space_W, space_H = (ori_W - W)//2, (ori_H - H)//2
    im = im.crop((space_W, space_H, ori_W - space_W, ori_H - space_H))
    image = np.array(im, dtype="float32")
    im.close()
    image /= 255  # scale to [0, 1] before mean/std normalization
    image -= mean
    image /= std
    return image.astype(np.float32)

@dataclass
class ImageNet:
    root: str
    size: tuple[int, int] = (224, 224)
    mean: tuple[float, ...] = (0.485, 0.456, 0.406)
    std: tuple[float, ...] = (0.229, 0.224, 0.225)

    def __post_init__(self):
        self.root = Path(self.root)  # dataset root directory
        self.valset = ImageFolderDataset(f"{self.root}/val")
        self.trainset = ImageFolderDataset(f"{self.root}/train")

    def calibrateset(self, calibrate_num: int = 200):
        """Calibration dataset generator for TVM quantization."""
        for k, (data, label) in tqdm(enumerate(self.trainset)):
            if k >= calibrate_num:
                break
            image = preprocess_image(data, self.size, self.mean, self.std)
            images = np.expand_dims(image, 0)
            images = images.transpose((0, 3, 1, 2))  # NHWC -> NCHW
            yield {"data": images}
dataset = ImageNet("/media/pc/data/lxw/home/data/datasets/ILSVRC/")
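Before calibrating, it is worth pulling a single batch to confirm the generator yields NCHW float32 tensors keyed by the input name (a quick sanity check, assuming the dataset path above is valid):

# Pull one calibration batch and verify its layout and dtype
batch = next(iter(dataset.calibrateset(calibrate_num=1)))
assert batch["data"].shape == (1, 3, 224, 224)
assert batch["data"].dtype == np.float32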
## Operator fusion for resnet18
from tvm.relay import Call
from tvm.relay.function import Function, FunctionWithFields
from tvm.relay.quantize._partition import QPartitionExpr

@tvm.relay.transform.function_pass(opt_level=1)
class QPartitionTransform:
    """Wrap each fused function body in a QPartitionExpr."""
    def transform_function(self, func, mod, ctx):
        class Replace(tvm.relay.ExprMutator):
            def visit_function(self, fn):
                new_params = [self.visit(x) for x in fn.params]
                new_body = self.visit(fn.body)
                if not isinstance(new_body.op, Function):  # avoid wrapping with QPartitionExpr repeatedly
                    new_body = QPartitionExpr(new_body).realize()
                if new_params == list(fn.params) and new_body == fn.body:
                    new_fn = fn
                else:
                    new_fn = FunctionWithFields(fn, list(new_params), new_body)
                return new_fn
        return Replace().visit(func)
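QPartitionExpr(...).realize() wraps an expression with annotation.cast_hint(dtype="int8") followed by annotation.stop_fusion, which is how the quantizer later recognizes partition boundaries. A tiny illustration (the shapes and the relu op are chosen arbitrarily):

x = relay.var("x", shape=(1, 4))
wrapped = QPartitionExpr(relay.nn.relu(x)).realize()
# Expect: annotation.stop_fusion(annotation.cast_hint(nn.relu(x), dtype="int8"))
print(wrapped)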
@tvm.relay.transform.function_pass(opt_level=1)
class SplitGraphTransform:
    """Lift each fused subgraph into its own sub-function."""
    def __init__(self):
        self.reset()

    def reset(self):
        self._func_index = 0

    def transform_function(self, func, mod, ctx):
        obj = self
        class Replace(tvm.relay.ExprMutator):
            def visit_call(self, call):
                new_fn = self.visit(call.op)
                new_args = [self.visit(arg) for arg in call.args]
                if isinstance(new_fn, Function):
                    func_name = f"f_{obj._func_index:04d}"
                    new_fn = run_opt_pass(new_fn, relay.transform.FoldConstant())
                    mod[func_name] = new_fn
                    new_fn = mod.get_global_var(func_name)
                    obj._func_index += 1
                if new_fn == call.op and new_args == list(call.args):
                    new_call = call
                else:
                    new_call = Call(new_fn, new_args, call.attrs, call.type_args, call.span)
                return new_call
        return Replace().visit(func)
from tvm_book.tvm_utils.relay_pattern import *

# Configure the fusion rules
compiler_name = "ccompiler"
pattern_table = [
    (f"{compiler_name}.conv_add_relu_max_pool2d", make_conv_add_relu_max_pool2d_pattern()),
    (f"{compiler_name}.conv2d_transpose_add_activate", make_conv2d_transpose_add_activate_pattern()),
    (f"{compiler_name}.conv_add_activate", make_conv_add_activate_pattern()),
    (f"{compiler_name}.max_pool2d", make_max_pool2d_pattern()),
    (f"{compiler_name}.dense_add", make_dense_add_pattern()),
    (f"{compiler_name}.adaptive_avg_pool2d", make_adaptive_avg_pool2d_pattern()),
    (f"{compiler_name}.avg_pool2d", make_avg_pool2d_pattern()),
    (f"{compiler_name}.add_multiply_add", make_add_multiply_add_pattern()),
    (f"{compiler_name}.add", make_add_pattern()),
    (f"{compiler_name}.multiply", make_multiply_pattern()),
    # (f"{compiler_name}.strided_slice", make_strided_slice_pattern()),
]
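The make_*_pattern helpers come from tvm_book.tvm_utils.relay_pattern. To give a feel for what they contain, a conv2d -> add -> activation matcher could be written with Relay's dataflow pattern language like this (an illustrative sketch, not the tvm_book implementation):

from tvm.relay.dataflow_pattern import is_op, wildcard

def sketch_conv_add_activate_pattern():
    """Match nn.conv2d -> add, optionally followed by nn.relu."""
    conv = is_op("nn.conv2d")(wildcard(), wildcard())
    add = is_op("add")(conv, wildcard())
    return add.optional(lambda x: is_op("nn.relu")(x))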
merge_passes = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.MergeComposite(pattern_table),
    QPartitionTransform(),  # attach `QPartitionExpr` to the fused functions
    # relay.transform.DefuseOps(),
    # relay.transform.MergeComposite(pattern_table),
    SplitGraphTransform(),
])
with tvm.transform.PassContext(opt_level=3):
    run_mod = merge_passes(opt_mod)
print(run_mod["f_0000"])
fn (%FunctionVar_19_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %FunctionVar_19_1: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %FunctionVar_19_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_activate") -> Tensor[(1, 64, 112, 112), float32] {
%0 = nn.conv2d(%FunctionVar_19_0, %FunctionVar_19_1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%1 = add(%0, %FunctionVar_19_2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */;
%3 = annotation.cast_hint(%2, dtype="int8") /* ty=Tensor[(1, 64, 112, 112), float32] */;
annotation.stop_fusion(%3) /* ty=Tensor[(1, 64, 112, 112), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(64, 3, 7, 7), float32], Tensor[(64, 1, 1), float32]) -> Tensor[(1, 64, 112, 112), float32] */
print(run_mod["main"])
fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 1000), float32] {
%0 = @f_0000(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%1 = @f_0001(%0) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%2 = @f_0002(%1, meta[relay.Constant][2] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%3 = @f_0003(%2, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%4 = @f_0004(%3, %1) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%5 = @f_0005(%4, meta[relay.Constant][6] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][7] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%6 = @f_0006(%5, meta[relay.Constant][8] /* ty=Tensor[(64, 64, 3, 3), float32] */, meta[relay.Constant][9] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%7 = @f_0007(%6, %4) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%8 = @f_0008(%7, meta[relay.Constant][10] /* ty=Tensor[(128, 64, 3, 3), float32] */, meta[relay.Constant][11] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;
%9 = @f_0009(%8, meta[relay.Constant][12] /* ty=Tensor[(128, 128, 3, 3), float32] */, meta[relay.Constant][13] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;
%10 = @f_0010(%7, meta[relay.Constant][14] /* ty=Tensor[(128, 64, 1, 1), float32] */, meta[relay.Constant][15] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;
%11 = @f_0011(%9, %10) /* ty=Tensor[(1, 128, 28, 28), float32] */;
%12 = @f_0012(%11, meta[relay.Constant][16] /* ty=Tensor[(128, 128, 3, 3), float32] */, meta[relay.Constant][17] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;
%13 = @f_0013(%12, meta[relay.Constant][18] /* ty=Tensor[(128, 128, 3, 3), float32] */, meta[relay.Constant][19] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 28, 28), float32] */;
%14 = @f_0014(%13, %11) /* ty=Tensor[(1, 128, 28, 28), float32] */;
%15 = @f_0015(%14, meta[relay.Constant][20] /* ty=Tensor[(256, 128, 3, 3), float32] */, meta[relay.Constant][21] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;
%16 = @f_0016(%15, meta[relay.Constant][22] /* ty=Tensor[(256, 256, 3, 3), float32] */, meta[relay.Constant][23] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;
%17 = @f_0017(%14, meta[relay.Constant][24] /* ty=Tensor[(256, 128, 1, 1), float32] */, meta[relay.Constant][25] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;
%18 = @f_0018(%16, %17) /* ty=Tensor[(1, 256, 14, 14), float32] */;
%19 = @f_0019(%18, meta[relay.Constant][26] /* ty=Tensor[(256, 256, 3, 3), float32] */, meta[relay.Constant][27] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;
%20 = @f_0020(%19, meta[relay.Constant][28] /* ty=Tensor[(256, 256, 3, 3), float32] */, meta[relay.Constant][29] /* ty=Tensor[(256, 1, 1), float32] */) /* ty=Tensor[(1, 256, 14, 14), float32] */;
%21 = @f_0021(%20, %18) /* ty=Tensor[(1, 256, 14, 14), float32] */;
%22 = @f_0022(%21, meta[relay.Constant][30] /* ty=Tensor[(512, 256, 3, 3), float32] */, meta[relay.Constant][31] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;
%23 = @f_0023(%22, meta[relay.Constant][32] /* ty=Tensor[(512, 512, 3, 3), float32] */, meta[relay.Constant][33] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;
%24 = @f_0024(%21, meta[relay.Constant][34] /* ty=Tensor[(512, 256, 1, 1), float32] */, meta[relay.Constant][35] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;
%25 = @f_0025(%23, %24) /* ty=Tensor[(1, 512, 7, 7), float32] */;
%26 = @f_0026(%25, meta[relay.Constant][36] /* ty=Tensor[(512, 512, 3, 3), float32] */, meta[relay.Constant][37] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;
%27 = @f_0027(%26, meta[relay.Constant][38] /* ty=Tensor[(512, 512, 3, 3), float32] */, meta[relay.Constant][39] /* ty=Tensor[(512, 1, 1), float32] */) /* ty=Tensor[(1, 512, 7, 7), float32] */;
%28 = @f_0028(%27, %25) /* ty=Tensor[(1, 512, 7, 7), float32] */;
%29 = @f_0029(%28) /* ty=Tensor[(1, 512, 1, 1), float32] */;
%30 = reshape(%29, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 512, 1, 1), float32] span=aten::flatten_0:0:0 */;
%31 = squeeze(%30, axis=[2, 3]) /* ty=Tensor[(1, 512), float32] span=aten::flatten_0:0:0 */;
@f_0030(%31, meta[relay.Constant][40] /* ty=Tensor[(1000, 512), float32] */, meta[relay.Constant][41] /* ty=Tensor[(1000), float32] */) /* ty=Tensor[(1, 1000), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] */
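A quick check that SplitGraphTransform produced the expected sub-functions, f_0000 through f_0030 (31 fused subgraphs) plus main:

# List the generated global functions
print(sorted(gv.name_hint for gv in run_mod.get_global_vars()))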
## Annotating the resnet18 model
with tvm.transform.PassContext(opt_level=3):
    run_mod = merge_passes(opt_mod)
    # run_mod = relay.quantize.annotate()(run_mod)
print(run_mod["f_0000"])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
...
ValueError: Cannot find global var "f_0000" in the Module
candidates are: ["f_0058", "main", "f_0044", "f_0037", "f_0056", "f_0042", "f_0049", "f_0061", "f_0035", "f_0054", "f_0040", "f_0047", "f_0033", "f_0052", "f_0059", "f_0045", "f_0031", "f_0050", "f_0038", "f_0057", "f_0043", "f_0036", "f_0055", "f_0041", "f_0048", "f_0060", "f_0034", "f_0053", "f_0046", "f_0032", "f_0039", "f_0051"]

The lookup fails because the pass instances inside merge_passes are stateful: SplitGraphTransform keeps counting with the same _func_index across invocations, so this second application names the sub-functions f_0031 through f_0061 and f_0000 no longer exists in the returned module.
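One way around this is to rebuild the pipeline so the counter starts from zero again, e.g. (a sketch):

def build_merge_passes():
    # Fresh pass instances, so SplitGraphTransform numbering restarts at f_0000
    return tvm.transform.Sequential([
        relay.transform.InferType(),
        relay.transform.MergeComposite(pattern_table),
        QPartitionTransform(),
        SplitGraphTransform(),
    ])

with tvm.transform.PassContext(opt_level=3):
    run_mod = build_merge_passes()(opt_mod)  # "f_0000" is present again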
# from tvm.ir import IRModule, structural_equal
import tvm

with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        calibrate_mode="kl_divergence",
        weight_scale="max",
        round_for_shift=True,
        # rounding="TONEAREST",  # "UPWARD" or "TONEAREST"
        calibrate_skip_layers=[],
        skip_dense_layer=False,
    ):
        qmod = relay.quantize.quantize(mod, params, dataset.calibrateset(calibrate_num=200))
200it [00:10, 19.39it/s]
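Note that relay.quantize.quantize is handed the raw mod and params rather than opt_mod, since it runs prerequisite_optimize internally before annotation, calibration, and realization. With calibrate_mode="kl_divergence", the calibrator picks each activation's clipping threshold by minimizing the KL divergence between the fp32 histogram and its simulated int8 counterpart, while weight_scale="max" simply scales each weight tensor by its absolute maximum. A heavily simplified numpy illustration of such a threshold search (not TVM's exact algorithm):

import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) between two unnormalized histograms."""
    p = p.astype(np.float64) / p.sum()
    q = q.astype(np.float64) / q.sum()
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], eps))))

def search_clip_threshold(samples, num_bins=2048, num_candidates=64):
    """Return the clipping threshold whose int8 simulation best matches fp32."""
    amax = float(np.abs(samples).max())
    ref_hist, _ = np.histogram(np.abs(samples), bins=num_bins, range=(0.0, amax))
    best_t, best_kl = amax, np.inf
    for t in np.linspace(amax / num_candidates, amax, num_candidates):
        scale = t / 127.0
        sim = np.clip(np.round(samples / scale), -127, 127) * scale  # fake-quantize
        sim_hist, _ = np.histogram(np.abs(sim), bins=num_bins, range=(0.0, amax))
        kl = kl_divergence(ref_hist, sim_hist)
        if kl < best_kl:
            best_kl, best_t = kl, t
    return best_t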
Finally, measure the accuracy of the float and quantized resnet18 models on the validation set:
from tqdm import tqdm
from tvm.runtime.vm import VirtualMachine
from tvm_book.metric.classification import Accuracy, TopKAccuracy

with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(run_mod, target="llvm", params=params)
vm = VirtualMachine(vm_exec, tvm.cpu())
with tvm.transform.PassContext(opt_level=3):
    qvm_exec = relay.vm.compile(qmod, target="llvm", params=params)
qvm = VirtualMachine(qvm_exec, tvm.cpu())

metric_top1 = Accuracy("浮点")   # label: float (fp32) model
metric_top5 = TopKAccuracy(top_k=5)
qmetric_top1 = Accuracy("量化")  # label: quantized (int8) model
qmetric_top5 = TopKAccuracy(top_k=5)
for k, (data, label) in tqdm(enumerate(dataset.valset)):
    image = preprocess_image(data, dataset.size, dataset.mean, dataset.std)
    images = np.expand_dims(image, 0)
    images = images.transpose((0, 3, 1, 2))
    input_dict = {"data": images}
    output = vm.run(**input_dict).asnumpy()
    quant_output = qvm.run(**input_dict).asnumpy()
    label = np.array([label])
    # update the accuracy metrics
    metric_top1.update(preds=output, labels=label)
    metric_top5.update(preds=output, labels=label)
    qmetric_top1.update(preds=quant_output, labels=label)
    qmetric_top5.update(preds=quant_output, labels=label)
    if k % 1000 == 0:
        print(f"浮点: {metric_top1.get(), metric_top5.get()}||量化: {qmetric_top1.get(), qmetric_top5.get()}")
        # break
浮点: (('浮点', 0.0), ('top_5_accuracy', 0.0))||量化: (('量化', 0.0), ('top_5_accuracy', 0.0))
浮点: (('浮点', 0.8381618381618382), ('top_5_accuracy', 0.954045954045954))||量化: (('量化', 0.8331668331668332), ('top_5_accuracy', 0.954045954045954))
浮点: (('浮点', 0.7886056971514243), ('top_5_accuracy', 0.9300349825087456))||量化: (('量化', 0.7831084457771115), ('top_5_accuracy', 0.9300349825087456))
浮点: (('浮点', 0.7534155281572809), ('top_5_accuracy', 0.915361546151283))||量化: (('量化', 0.7444185271576141), ('top_5_accuracy', 0.9136954348550483))
浮点: (('浮点', 0.7275681079730068), ('top_5_accuracy', 0.9092726818295426))||量化: (('量化', 0.7223194201449638), ('top_5_accuracy', 0.9072731817045738))
浮点: (('浮点', 0.7518496300739852), ('top_5_accuracy', 0.9164167166566687))||量化: (('量化', 0.7458508298340332), ('top_5_accuracy', 0.9154169166166767))
浮点: (('浮点', 0.7518746875520746), ('top_5_accuracy', 0.9161806365605732))||量化: (('量化', 0.74704215964006), ('top_5_accuracy', 0.9155140809865022))
浮点: (('浮点', 0.7587487501785459), ('top_5_accuracy', 0.9190115697757463))||量化: (('量化', 0.7541779745750607), ('top_5_accuracy', 0.9181545493500929))
浮点: (('浮点', 0.762779652543432), ('top_5_accuracy', 0.9220097487814023))||量化: (('量化', 0.7582802149731284), ('top_5_accuracy', 0.9212598425196851))
浮点: (('浮点', 0.7561382068659038), ('top_5_accuracy', 0.9223419620042218))||量化: (('量化', 0.7512498611265415), ('top_5_accuracy', 0.9213420731029885))
浮点: (('浮点', 0.7503249675032497), ('top_5_accuracy', 0.9234076592340766))||量化: (('量化', 0.7457254274572542), ('top_5_accuracy', 0.922107789221078))
浮点: (('浮点', 0.7503863285155895), ('top_5_accuracy', 0.92573402417962))||量化: (('量化', 0.747022997909281), ('top_5_accuracy', 0.9243705117716571))
浮点: (('浮点', 0.7462711440713274), ('top_5_accuracy', 0.9252562286476127))||量化: (('量化', 0.7422714773768853), ('top_5_accuracy', 0.9237563536371969))
浮点: (('浮点', 0.745250365356511), ('top_5_accuracy', 0.9276978693946619))||量化: (('量化', 0.7406353357434043), ('top_5_accuracy', 0.9259287747096377))
浮点: (('浮点', 0.7438754374687523), ('top_5_accuracy', 0.9284336833083351))||量化: (('量化', 0.7393043354046139), ('top_5_accuracy', 0.9267909435040355))
浮点: (('浮点', 0.7448170121991867), ('top_5_accuracy', 0.9284714352376509))||量化: (('量化', 0.7402839810679288), ('top_5_accuracy', 0.9271381907872809))
浮点: (('浮点', 0.7429535654021624), ('top_5_accuracy', 0.92819198800075))||量化: (('量化', 0.7378288856946441), ('top_5_accuracy', 0.9270670583088557))
浮点: (('浮点', 0.7496029645314981), ('top_5_accuracy', 0.929533556849597))||量化: (('量化', 0.7444856184930299), ('top_5_accuracy', 0.928239515322628))
浮点: (('浮点', 0.7482917615688017), ('top_5_accuracy', 0.929892783734237))||量化: (('量化', 0.743236486861841), ('top_5_accuracy', 0.9288372868173991))
浮点: (('浮点', 0.7500657860112626), ('top_5_accuracy', 0.9293721383085101))||量化: (('量化', 0.7456449660544182), ('top_5_accuracy', 0.928214304510289))
浮点: (('浮点', 0.7497125143742813), ('top_5_accuracy', 0.9290035498225089))||量化: (('量化', 0.7453127343632818), ('top_5_accuracy', 0.9278536073196341))
浮点: (('浮点', 0.7441074234560259), ('top_5_accuracy', 0.9260511404218846))||量化: (('量化', 0.7392505118803866), ('top_5_accuracy', 0.9249559544783582))
浮点: (('浮点', 0.7413299395482024), ('top_5_accuracy', 0.9241852643061679))||量化: (('量化', 0.7368301440843598), ('top_5_accuracy', 0.9228216899231854))
浮点: (('浮点', 0.7370549106560584), ('top_5_accuracy', 0.921090387374462))||量化: (('量化', 0.7328377027085778), ('top_5_accuracy', 0.9199600017390548))
浮点: (('浮点', 0.7325944752301987), ('top_5_accuracy', 0.9175034373567768))||量化: (('量化', 0.728719636681805), ('top_5_accuracy', 0.9165451439523353))
浮点: (('浮点', 0.7274509019639215), ('top_5_accuracy', 0.9144034238630455))||量化: (('量化', 0.7237310507579697), ('top_5_accuracy', 0.9135634574617015))
浮点: (('浮点', 0.7225875927848929), ('top_5_accuracy', 0.9120803046036691))||量化: (('量化', 0.7190108072766432), ('top_5_accuracy', 0.9112726433598708))
浮点: (('浮点', 0.7200103699862967), ('top_5_accuracy', 0.9108551535128329))||量化: (('量化', 0.716306803451724), ('top_5_accuracy', 0.9098922262138439))
浮点: (('浮点', 0.7173315238741473), ('top_5_accuracy', 0.9092175279454305))||量化: (('量化', 0.7138316488696832), ('top_5_accuracy', 0.9082175636584408))
浮点: (('浮点', 0.7188717630426537), ('top_5_accuracy', 0.9096238060756525))||量化: (('量化', 0.7154236060825488), ('top_5_accuracy', 0.9085548774180201))
浮点: (('浮点', 0.7153428219059365), ('top_5_accuracy', 0.9070030998966702))||量化: (('量化', 0.7119762674577514), ('top_5_accuracy', 0.9055364821172628))
浮点: (('浮点', 0.7146866230121609), ('top_5_accuracy', 0.9054869197767814))||量化: (('量化', 0.7113964065675301), ('top_5_accuracy', 0.903841811554466))
浮点: (('浮点', 0.7093215837005094), ('top_5_accuracy', 0.9031280272491484))||量化: (('量化', 0.7062279303771757), ('top_5_accuracy', 0.9010968407237274))
浮点: (('浮点', 0.7076755249840914), ('top_5_accuracy', 0.9014575315899518))||量化: (('量化', 0.7043422926577982), ('top_5_accuracy', 0.8994272900821187))
浮点: (('浮点', 0.7063615776006588), ('top_5_accuracy', 0.9002676391870827))||量化: (('量化', 0.702949913237846), ('top_5_accuracy', 0.8981500544101644))
浮点: (('浮点', 0.7044941573097911), ('top_5_accuracy', 0.8994028742035942))||量化: (('量化', 0.7011799662866776), ('top_5_accuracy', 0.8973457901202823))
浮点: (('浮点', 0.7038693369628621), ('top_5_accuracy', 0.8990861364962084))||量化: (('量化', 0.7008694202938808), ('top_5_accuracy', 0.8969750840254438))
浮点: (('浮点', 0.7021972379124889), ('top_5_accuracy', 0.8973811518607605))||量化: (('量化', 0.6989270560255129), ('top_5_accuracy', 0.8951920218372477))
浮点: (('浮点', 0.6996921133654378), ('top_5_accuracy', 0.8958974763822005))||量化: (('量化', 0.6963237809531329), ('top_5_accuracy', 0.8935554327517697))
浮点: (('浮点', 0.6986487525960873), ('top_5_accuracy', 0.8945411656111382))||量化: (('量化', 0.6952642239942566), ('top_5_accuracy', 0.8922078921053306))
浮点: (('浮点', 0.697257568560786), ('top_5_accuracy', 0.8931526711832204))||量化: (('量化', 0.6938826529336767), ('top_5_accuracy', 0.8907527311817205))
浮点: (('浮点', 0.6957147386649106), ('top_5_accuracy', 0.8919294651349967))||量化: (('量化', 0.6924465256944953), ('top_5_accuracy', 0.8896368381258993))
浮点: (('浮点', 0.6931977810052141), ('top_5_accuracy', 0.8900026189852622))||量化: (('量化', 0.689840718078141), ('top_5_accuracy', 0.8878121949477393))
浮点: (('浮点', 0.6917048440733936), ('top_5_accuracy', 0.8890258366084509))||量化: (('量化', 0.6881002767377503), ('top_5_accuracy', 0.8868165856607986))
浮点: (('浮点', 0.6905524874434672), ('top_5_accuracy', 0.8884798072771073))||量化: (('量化', 0.6870071134746938), ('top_5_accuracy', 0.8863207654371491))
浮点: (('浮点', 0.6892069065131886), ('top_5_accuracy', 0.8874691673518367))||量化: (('量化', 0.6854736561409747), ('top_5_accuracy', 0.8854247683384814))
浮点: (('浮点', 0.6880719984348166), ('top_5_accuracy', 0.886785069889785))||量化: (('量化', 0.6843764265994218), ('top_5_accuracy', 0.8845242494728376))
浮点: (('浮点', 0.6882619518733644), ('top_5_accuracy', 0.8871300610625306))||量化: (('量化', 0.6845811791238484), ('top_5_accuracy', 0.8850237228995128))
浮点: (('浮点', 0.6899022937022146), ('top_5_accuracy', 0.8883356596737568))||量化: (('量化', 0.6861940376242162), ('top_5_accuracy', 0.8860232078498365))
浮点: (('浮点', 0.6872920960796718), ('top_5_accuracy', 0.8874308687577804))||量化: (('量化', 0.6833942164445623), ('top_5_accuracy', 0.8849411236505378))
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
50000it [1:51:37, 7.47it/s]

At the last printed checkpoint (k = 49000) the float model reaches 68.73% top-1 / 88.74% top-5, against 68.34% top-1 / 88.49% top-5 for the quantized model, so this quantization flow costs roughly 0.4 points of top-1 accuracy on ImageNet.
As an aside, the environment used throughout is torch 2.1.0+cu121, which already lists tvm among the available torch.compile backends:

import torch
torch.__version__
'2.1.0+cu121'
torch.compiler.list_backends()
['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'openxla_eval', 'tvm']