Models and Pre-trained Weights#

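The torchvision.models subpackage contains model definitions for a range of tasks: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.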

Reference: models

General information on pre-trained weights#

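TorchVision offers pre-trained weights for every architecture it provides, distributed through PyTorch's torch.hub. Instantiating a pre-trained model downloads its weights to a cache directory. This directory can be set with the TORCH_HOME environment variable. See torch.hub.load_state_dict_from_url() for details.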
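
How the cache directory is resolved can be sketched with the standard library alone. The helper name `resolve_checkpoint_dir` is hypothetical and not part of PyTorch; it mirrors the lookup order documented for torch.hub (`$TORCH_HOME`, then `$XDG_CACHE_HOME/torch`, then `~/.cache/torch`) and assumes that weights land in the `hub/checkpoints` subdirectory, which is the documented default when no `model_dir` is passed to `torch.hub.load_state_dict_from_url()`.

```python
import os


def resolve_checkpoint_dir(env=None):
    """Hypothetical helper: mirror the documented torch.hub cache lookup."""
    env = os.environ if env is None else env
    # Fallback chain: $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch
    xdg_cache = env.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    torch_home = env.get("TORCH_HOME", os.path.join(xdg_cache, "torch"))
    # Assumption: default checkpoint location is <torch_home>/hub/checkpoints
    return os.path.join(os.path.expanduser(torch_home), "hub", "checkpoints")


# Example with an explicit TORCH_HOME (the path is illustrative only)
print(resolve_checkpoint_dir({"TORCH_HOME": "/data/torch_cache"}))
```

With PyTorch installed, `torch.hub.get_dir()` reports the hub directory actually in use, which is a quick way to confirm that the environment variable was picked up.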

Initializing pre-trained models#

As of v0.13, TorchVision offers a new multi-weight support API that lets the existing model builder methods load different weights:

from torchvision.models import resnet50, ResNet50_Weights

# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None)

Migrating to the new API is straightforward. The following method calls in the two APIs are equivalent:

from torchvision.models import resnet50, ResNet50_Weights

# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True)  # deprecated
resnet50(True)  # deprecated

# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False)  # deprecated
resnet50(False)  # deprecated

Using the pre-trained models#

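Before using a pre-trained model, the images must be preprocessed (resized with the right resolution and interpolation, passed through the inference transforms, rescaled, and so on). There is no standard way to do this, because it depends on how a given model was trained. It can vary across model families, variants, and even weight versions. Using the correct preprocessing method is critical; failing to do so may lead to decreased accuracy or incorrect outputs.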

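All the information needed for the inference transforms of each pre-trained model is provided in its weights documentation. To simplify inference, TorchVision bundles the necessary preprocessing transforms into each set of model weights. They are accessible via the weight.transforms attribute: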

# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Apply it to the input image (img is assumed to be a PIL image or image tensor loaded beforehand)
img_transformed = preprocess(img)

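Some models use modules that behave differently during training and evaluation, such as batch normalization. To switch between these modes, use model.train() or model.eval() as appropriate. See train() or eval() for details.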

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# Set model to eval mode
model.eval()
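
The effect of the two modes can be seen on a single batch normalization layer. This is a minimal sketch using a standalone `torch.nn.BatchNorm1d` rather than a full pre-trained model; the layer size and input values are arbitrary choices for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 5  # batch with non-zero mean and non-unit variance

# Training mode: the forward pass updates the running statistics.
bn.train()
before = bn.running_mean.clone()
bn(x)
after_train = bn.running_mean.clone()

# Evaluation mode: the stored statistics are used and left unchanged.
bn.eval()
bn(x)
after_eval = bn.running_mean.clone()

print(torch.equal(before, after_train))      # running mean changed during training
print(torch.equal(after_train, after_eval))  # running mean frozen during evaluation
```

Forgetting to call eval() before inference therefore both normalizes with batch statistics and silently alters the stored ones.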

Listing and retrieving available models#

As of v0.14, TorchVision offers a new mechanism for listing and retrieving models and weights by name. Here are a few usage examples:

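import torchvision
from torchvision.models import get_model, get_model_weights, get_weight, list_models
from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights

# List available models
all_models = list_models()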
classification_models = list_models(module=torchvision.models)

# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights

weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2

Using models from Hub#

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

import torch

# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

You can also retrieve all the available weights of a specific model via PyTorch Hub:

import torch

weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])