# TVM 初探

from pathlib import Path

# 载入自定义模块
from mod import torchq

import set_env
from torch import nn, jit
from torchvision.models import quantization as qmodels
from torch.ao.quantization import get_default_qat_qconfig
def create_model(model_name='resnet18',
                 quantize=False,
                 pretrained=True):
    """Instantiate a model builder from ``torchvision.models.quantization``.

    Args:
        model_name: Name of the builder function looked up on the
            ``qmodels`` module (for example ``'resnet18'``).
        quantize: Forwarded unchanged to the builder as ``quantize``.
        pretrained: Forwarded unchanged to the builder as ``pretrained``.

    Returns:
        The model object produced by the selected builder.
    """
    builder = getattr(qmodels, model_name)
    model = builder(quantize=quantize, pretrained=pretrained)
    return model
# 设置 warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module='.*'
)
warnings.filterwarnings(
    action='ignore',
    module='torch.ao.quantization'
)
# 载入自定义模块
from mod import torchq

from torchq.helper import evaluate, print_size_of_model, load_model

def print_info(model, model_type, criterion, test_iter):
    """Print a model's size and evaluation accuracy, then return the meters.

    Prints the model size via ``print_size_of_model``, runs ``evaluate``
    on ``test_iter``, and prints the ``avg`` of the first returned meter
    (top-1) together with the number of batches in ``test_iter``.

    Args:
        model: Model to measure and evaluate.
        model_type: Label printed above the accuracy line.
        criterion: Loss object passed through to ``evaluate``.
        test_iter: Evaluation data; ``len(test_iter)`` is reported as the
            number of batches evaluated.

    Returns:
        The ``(top1, top5)`` pair returned by ``evaluate``.
    """
    batch_count = len(test_iter)
    print_size_of_model(model)
    top1, top5 = evaluate(model, criterion, test_iter)
    summary = (f'\n{model_type}\n\t'
               f'在 {batch_count} 批次图片上评估 accuracy 为: {top1.avg:2.5f}')
    print(summary)
    return top1, top5
from utils.loader import get_val_loader

# Paths and evaluation settings shared by the cells below.
saved_model_dir = 'models/'
model_name = 'resnet18'
float_model_file = f'{model_name}_pretrained_float.pth'
float_model_path = f'{saved_model_dir}{float_model_file}'
batch_size = 8
backend = 'fbgemm'

val_loader = get_val_loader(batch_size)

# Build the float model, fuse its modules and attach a QAT qconfig for the
# chosen backend. No prepare step is applied here, so the evaluation below
# still measures the float model.
qat_model = create_model(model_name, quantize=False)
qat_model.fuse_model()
qat_model.qconfig = get_default_qat_qconfig(backend=backend)
model_type = '浮点模型'  # label: "float model"
criterion = nn.CrossEntropyLoss(reduction="none")
top1, top5 = print_info(model=qat_model, model_type=model_type,
                        criterion=criterion, test_iter=val_loader)
模型大小:46.837645 MB
Batch 0 ~ Acc@1  75.00 ( 75.00)	 Acc@5 100.00 (100.00)
Batch 500 ~ Acc@1 100.00 ( 74.28)	 Acc@5 100.00 ( 91.89)
Batch 1000 ~ Acc@1  75.00 ( 77.78)	 Acc@5  87.50 ( 92.87)
Batch 1500 ~ Acc@1   0.00 ( 75.97)	 Acc@5 100.00 ( 93.03)
Batch 2000 ~ Acc@1  87.50 ( 75.61)	 Acc@5 100.00 ( 93.35)
Batch 2500 ~ Acc@1  50.00 ( 76.28)	 Acc@5  87.50 ( 93.47)
Batch 3000 ~ Acc@1  50.00 ( 74.50)	 Acc@5  75.00 ( 92.22)
Batch 3500 ~ Acc@1 100.00 ( 72.89)	 Acc@5 100.00 ( 91.26)
Batch 4000 ~ Acc@1  87.50 ( 72.02)	 Acc@5  87.50 ( 90.55)
Batch 4500 ~ Acc@1  50.00 ( 71.30)	 Acc@5  75.00 ( 90.17)
Batch 5000 ~ Acc@1  87.50 ( 70.59)	 Acc@5  87.50 ( 89.57)
Batch 5500 ~ Acc@1  87.50 ( 69.91)	 Acc@5  87.50 ( 89.13)
Batch 6000 ~ Acc@1  75.00 ( 69.82)	 Acc@5 100.00 ( 89.07)

浮点模型:
	在 6250 批次图片上评估 accuracy 为: 69.76000
# Notebook display: average top-1 / top-5 accuracy meters of the model evaluated above.
top1.avg, top5.avg
(tensor(69.7600), tensor(89.0820))
backend = 'fbgemm'

# With quantize=True the torchvision builder returns a model that is already
# quantized (its conv weights are qint8 and it is roughly 4x smaller than the
# float model), so no fuse_model()/qconfig/prepare/convert steps are needed.
# NOTE: the variable name `qat_model` is reused from the previous cell; this
# model was not produced by a QAT flow here.
qat_model = create_model(model_name, quantize=True)
model_type = '量化模型'  # label: "quantized model" (previously mislabeled as the float model)
criterion = nn.CrossEntropyLoss(reduction="none")
top1, top5 = print_info(qat_model, model_type, criterion, val_loader)
模型大小:11.838625 MB
Batch 0 ~ Acc@1  75.00 ( 75.00)	 Acc@5 100.00 (100.00)
Batch 500 ~ Acc@1 100.00 ( 74.23)	 Acc@5 100.00 ( 92.02)
Batch 1000 ~ Acc@1  75.00 ( 77.68)	 Acc@5  87.50 ( 93.01)
Batch 1500 ~ Acc@1   0.00 ( 75.77)	 Acc@5 100.00 ( 93.12)
Batch 2000 ~ Acc@1  87.50 ( 75.32)	 Acc@5 100.00 ( 93.35)
Batch 2500 ~ Acc@1  50.00 ( 75.97)	 Acc@5  75.00 ( 93.43)
Batch 3000 ~ Acc@1  62.50 ( 74.21)	 Acc@5  87.50 ( 92.13)
Batch 3500 ~ Acc@1 100.00 ( 72.59)	 Acc@5 100.00 ( 91.12)
Batch 4000 ~ Acc@1  75.00 ( 71.76)	 Acc@5  87.50 ( 90.46)
Batch 4500 ~ Acc@1  50.00 ( 71.03)	 Acc@5  75.00 ( 90.05)
Batch 5000 ~ Acc@1  87.50 ( 70.31)	 Acc@5  87.50 ( 89.46)
Batch 5500 ~ Acc@1  87.50 ( 69.66)	 Acc@5  87.50 ( 89.01)
Batch 6000 ~ Acc@1  87.50 ( 69.58)	 Acc@5 100.00 ( 88.90)

浮点模型:
	在 6250 批次图片上评估 accuracy 为: 69.48800
# Average top-1 / top-5 accuracy of the quantize=True model evaluated above.
print(top1.avg, top5.avg)
tensor(69.4880) tensor(88.9000)
# In the quantized model `conv1.weight` is called as a method; its dtype shows how the weights are stored.
qat_model.conv1.weight().dtype
torch.qint8
# Script the model held in `qat_model` with TorchScript and save it under saved_model_dir.
scripted_qat_model_file = 'qat_resnet18.pth'
jit.save(jit.script(qat_model), f'{saved_model_dir}{scripted_qat_model_file}')
from torchvision.models import quantization
from torch import jit
import torch
# NOTE(review): machine-specific absolute path to a scripted, quantized
# MobileNetV2 (not the ResNet-18 file saved above); adjust for your setup.
scripted_qat_model_file = '/media/pc/data/4tb/lxw/work/torch-quantization/docs/tutorial/models/mobilenet_qat_scripted_quantized.pth'
# Load the TorchScript module and switch it to inference mode so layers such
# as the classifier's Dropout behave deterministically during the forward pass.
m = jit.load(scripted_qat_model_file).eval()
input_shape = 1, 3, 224, 224
input_data = torch.randn(input_shape)
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    y = m(input_data)
m
RecursiveScriptModule(
  original_name=QuantizableMobileNetV2
  (features): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(
      original_name=ConvNormActivation
      (0): RecursiveScriptModule(original_name=ConvReLU2d)
      (1): RecursiveScriptModule(original_name=Identity)
      (2): RecursiveScriptModule(original_name=Identity)
    )
    (1): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(original_name=Conv2d)
        (2): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (2): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (3): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (4): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (5): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (6): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (7): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (8): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (9): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (10): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (11): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (12): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (13): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (14): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (15): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (16): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (17): RecursiveScriptModule(
      original_name=QuantizableInvertedResidual
      (conv): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (1): RecursiveScriptModule(
          original_name=ConvNormActivation
          (0): RecursiveScriptModule(original_name=ConvReLU2d)
          (1): RecursiveScriptModule(original_name=Identity)
          (2): RecursiveScriptModule(original_name=Identity)
        )
        (2): RecursiveScriptModule(original_name=Conv2d)
        (3): RecursiveScriptModule(original_name=Identity)
      )
      (skip_add): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
    )
    (18): RecursiveScriptModule(
      original_name=ConvNormActivation
      (0): RecursiveScriptModule(original_name=ConvReLU2d)
      (1): RecursiveScriptModule(original_name=Identity)
      (2): RecursiveScriptModule(original_name=Identity)
    )
  )
  (classifier): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Dropout)
    (1): RecursiveScriptModule(
      original_name=Linear
      (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
    )
  )
  (quant): RecursiveScriptModule(original_name=Quantize)
  (dequant): RecursiveScriptModule(original_name=DeQuantize)
)