PyTorch 量化

PyTorch 量化

import sys
from pathlib import Path

# Make the `tests` directory that sits under the fourth ancestor of the
# current working directory importable (it provides `tools.torch_utils`).
# Built with pathlib's `/` operator instead of an f-string that reuses double
# quotes inside the replacement field, which only parses on Python >= 3.12.
sys.path.append(str(Path(".").resolve().parents[3] / "tests"))
# from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
from tools.torch_utils import verify_model
import torch
import torchvision

# Classification
# Run the float MobileNetV3-Large classifier once, in eval mode, on a random
# tensor of shape (1, 3, 224, 224).
# `weights=` replaces the deprecated `pretrained=True`; torchvision's own
# deprecation warning states that the legacy call is equivalent to
# MobileNet_V3_Large_Weights.IMAGENET1K_V1, so the loaded checkpoint is the same.
# For the small variant, use torchvision.models.mobilenet_v3_small with
# MobileNet_V3_Small_Weights.IMAGENET1K_V1 instead.
x = torch.rand(1, 3, 224, 224)
m_classifier = torchvision.models.mobilenet_v3_large(
    weights=torchvision.models.MobileNet_V3_Large_Weights.IMAGENET1K_V1
)
m_classifier.eval()
predictions = m_classifier(x)

# Quantized Classification
# Build MobileNetV3-Large from torchvision.models.quantization and run it
# once, in eval mode, on a random tensor of shape (1, 3, 224, 224).
# `weights=` replaces the deprecated `pretrained=True`; torchvision's own
# deprecation warning states that the legacy call is equivalent to
# MobileNet_V3_Large_Weights.IMAGENET1K_V1, so the loaded checkpoint is the same.
# NOTE(review): `quantize` is not passed, so this presumably loads the float
# checkpoint into the quantizable architecture rather than an int8 model --
# confirm that this is intended.
x = torch.rand(1, 3, 224, 224)
m_classifier = torchvision.models.quantization.mobilenet_v3_large(
    weights=torchvision.models.MobileNet_V3_Large_Weights.IMAGENET1K_V1
)
m_classifier.eval()
predictions = m_classifier(x)
/media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MobileNet_V3_Large_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V3_Large_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /home/ai/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:04<00:00, 5.45MB/s]
from tvm import relay