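Models and pre-trained weights#
The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.
Reference: models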
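General information on pre-trained weights#
TorchVision offers pre-trained weights for every provided architecture, using the PyTorch torch.hub. Instantiating a pre-trained model downloads its weights to a cache directory. This directory can be set using the TORCH_HOME environment variable. See torch.hub.load_state_dict_from_url() for details.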
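To make the effect of TORCH_HOME concrete, the following is a minimal standard-library sketch of how the cache location is resolved. It mirrors the lookup order used by torch.hub (TORCH_HOME first, then XDG_CACHE_HOME, then ~/.cache), but it is a simplified illustration rather than the library's code. The function names `torch_hub_cache_dir` and `checkpoint_dir` are hypothetical.

```python
import os


def torch_hub_cache_dir(env=None):
    """Approximate the hub cache directory (simplified sketch, not torch.hub itself)."""
    env = os.environ if env is None else env
    # Fallback when TORCH_HOME is not set: $XDG_CACHE_HOME/torch, else ~/.cache/torch
    default_home = os.path.join(env.get("XDG_CACHE_HOME", "~/.cache"), "torch")
    torch_home = os.path.expanduser(env.get("TORCH_HOME", default_home))
    return os.path.join(torch_home, "hub")


def checkpoint_dir(env=None):
    """Downloaded weight files are assumed to land in a 'checkpoints' subfolder."""
    return os.path.join(torch_hub_cache_dir(env), "checkpoints")


# Example: a custom TORCH_HOME redirects every download below that directory.
print(checkpoint_dir({"TORCH_HOME": "/data/torch"}))
```

With torch installed, torch.hub.get_dir() reports the directory actually in use, which is the value to trust if the two ever disagree.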
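Initializing pre-trained models#
As of v0.13, TorchVision offers a new Multi-weight support API for loading different weights into the existing model builder methods: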
from torchvision.models import resnet50, ResNet50_Weights
# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)
# Strings are also supported
resnet50(weights="IMAGENET1K_V2")
# No weights - random initialization
resnet50(weights=None);
Migrating to the new API is straightforward. The following method calls between the two APIs are all equivalent:
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated
# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False) # deprecated
resnet50(False); # deprecated
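When a codebase still has callers that pass the deprecated `pretrained` flag, a thin compatibility helper can translate it into the new `weights` argument in one place. The sketch below is illustrative only: `resolve_weights` is a hypothetical name, not a TorchVision function, and it assumes that `pretrained=True` corresponds to the `IMAGENET1K_V1` weights, as in the equivalences listed above.

```python
import warnings

_UNSET = object()  # sentinel: distinguishes "not passed" from pretrained=False


def resolve_weights(weights=None, pretrained=_UNSET, legacy_default="IMAGENET1K_V1"):
    """Map a legacy `pretrained` flag onto a value for the `weights` argument."""
    if pretrained is _UNSET:
        # New-style call: pass the value through unchanged.
        return weights
    warnings.warn(
        "'pretrained' is deprecated; pass 'weights' instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Old-style call: True selects the original weights, False means random init.
    return legacy_default if pretrained else None


print(resolve_weights(pretrained=True))
print(resolve_weights(weights="IMAGENET1K_V2"))
```

The returned value can then be forwarded as `resnet50(weights=resolve_weights(...))`, so the deprecated keyword never reaches the model builder.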
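Using the pre-trained models#
Before using a pre-trained model, the image must be preprocessed (resized with the right resolution and interpolation, passed through the inference transforms, its values rescaled, and so on). There is no standard way to do this, because it depends on how a given model was trained. It can vary across model families, variants, or even weight versions. Using the correct preprocessing is critical; failing to do so may lead to decreased accuracy or incorrect outputs.
All the information needed for the inference transforms of each pre-trained model is provided in its weights documentation. To simplify inference, TorchVision bundles the necessary preprocessing transforms into each model weight. These are accessible via the weight.transforms attribute: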
# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply it to the input image
# (`img` is assumed to be a PIL image or image tensor loaded beforehand)
img_transformed = preprocess(img)
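To illustrate the "rescale the values" step, here is a minimal pure-Python sketch of the arithmetic applied to a single RGB pixel: rescale from 0..255 to [0, 1], then standardize each channel. The mean and standard deviation below are the ImageNet statistics commonly used by TorchVision's classification presets. Treat them as an assumption for any particular weight and read the real values from its documentation or from `weights.transforms()`. Resizing and cropping are omitted, and `normalize_pixel` is a hypothetical helper.

```python
# Commonly used ImageNet channel statistics (assumed here; check the weight's docs)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Rescale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple(
        (value / 255.0 - m) / s
        for value, m, s in zip(rgb_uint8, mean, std)
    )


# A mid-grey pixel ends up close to zero in every channel.
print(normalize_pixel((124, 116, 104)))
```

Feeding a model raw 0..255 values, or values normalized with different statistics, shifts its inputs far from what it saw during training, which is why the bundled transforms are the safer choice.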
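Some models use modules that behave differently during training and evaluation, such as batch normalization. To switch between these modes, use model.train() or model.eval() as appropriate. See train() or eval() for details.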
# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# Set model to eval mode
model.eval()
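The difference between the two modes can be seen in a toy, single-feature batch normalization written in plain Python. This is a simplified sketch and not PyTorch's implementation: `ToyBatchNorm` is a hypothetical class, it has no learnable scale or shift, and it uses the biased variance everywhere. In training mode it normalizes with the statistics of the current batch and updates running averages; in evaluation mode it only uses the stored running averages.

```python
class ToyBatchNorm:
    """Single-feature batch norm without affine parameters (illustrative only)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, batch):
        if self.training:
            # Use the statistics of the current batch ...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ... and fold them into the running averages.
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            # Evaluation: rely only on what was accumulated during training.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


bn = ToyBatchNorm()
train_out = bn([1.0, 3.0])  # normalized with the batch's own mean and variance
bn.eval()
eval_out = bn([1.0, 3.0])   # normalized with the stored running statistics
print(train_out, eval_out)
```

The same input produces different outputs in the two modes, which is why forgetting to call model.eval() before inference can silently change predictions.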
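Listing and retrieving available models#
As of v0.14, TorchVision offers a new mechanism that allows listing and retrieving models and weights by name. Here are a few examples of how to use it: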
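import torchvision
from torchvision.models import get_model, get_model_weights, get_weight, list_models
from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights

# List available models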
all_models = list_models()
classification_models = list_models(module=torchvision.models)
# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
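Looking models up by name is an instance of the registry pattern: builder functions are stored in a dictionary under a string key, and a lookup function calls the matching builder. The sketch below shows the idea in plain Python. It is a simplified illustration and not TorchVision's implementation. All names in it (`register_model`, `list_registered`, `build`, `tiny_net`, `wide_net`) are hypothetical and local to this example.

```python
from typing import Callable, Dict, List

# Maps a model name to the function that builds it.
_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register_model(name: str):
    """Decorator that records a builder function under `name`."""
    def decorator(builder):
        if name in _REGISTRY:
            raise ValueError(f"A model named {name!r} is already registered")
        _REGISTRY[name] = builder
        return builder
    return decorator


def list_registered() -> List[str]:
    """Return all registered names in sorted order."""
    return sorted(_REGISTRY)


def build(name: str, **config):
    """Look up a builder by name and call it with the given configuration."""
    try:
        builder = _REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown model {name!r}") from None
    return builder(**config)


@register_model("tiny_net")
def tiny_net(width=8):
    return {"arch": "tiny_net", "width": width}


@register_model("wide_net")
def wide_net(width=64):
    return {"arch": "wide_net", "width": width}


print(list_registered())
print(build("tiny_net", width=16))
```

A name-based lookup like this lets configuration files or command-line arguments select a model without the calling code importing each builder explicitly.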
Using models from Hub#
Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
import torch
# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", name="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
You can also retrieve all the available weights of a specific model via PyTorch Hub:
import torch
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])