{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 计算机视觉分类\n",
"\n",
"在本教程中,您将学习如何使用迁移学习训练卷积神经网络进行图像分类。你可以在 [cs231n notes](https://cs231n.github.io/transfer-learning/) 中阅读更多关于迁移学习的信息。\n",
"\n",
"```{note}\n",
"在实践中,很少有人从零开始(使用随机初始化)训练整个卷积网络,因为拥有足够大的数据集是相对罕见的。相反,通常是在非常大的数据集(例如 ImageNet,包含 120 万张包含 1000 个类别的图像)上预训练卷积神经网络,然后使用卷积神经网络作为初始化或用于感兴趣的任务的固定特征提取器。\n",
"```\n",
"\n",
"这两种主要的迁移学习场景如下:\n",
"\n",
"1. **微调 ConvNet**:该模型不是随机初始化,而是使用预训练的网络初始化,之后的训练照常进行,但使用不同的数据集。如果有不同数量的输出,通常也替换网络中的头(或它的一部分)。在这种方法中,通常将学习速率设置为较小的数值。之所以这样做,是因为网络已经经过了训练,只需要稍作修改就可以将其“微调”到新的数据集。\n",
"\n",
"2. **ConvNet 作为固定的特征提取器**:在这里,你 [“冻结”](https://arxiv.org/abs/1706.04983) 除了最后几个层(又名“头部”,通常是全连接的层)之外,网络中所有参数的权重。这些最后的层将被替换为新层,并初始化为随机权重,只有这些层会被训练。\n",
"\n",
"你也可以将以上两种方法结合使用:首先,你可以冻结特征提取器,训练头部。在此之后,您可以解冻特征提取器(或部分特征提取器),将学习速率设置为较小的值,然后继续训练。"
]
},
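{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch of the two scenarios, assuming a torchvision ResNet-18 and a two-class head; the optimizer choice and `num_classes` value are illustrative and not the exact configuration used later in this notebook:\n",
"\n",
"```python\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"from torchvision.models import resnet18\n",
"\n",
"num_classes = 2  # illustrative: e.g. ants vs. bees\n",
"\n",
"# Scenario 1: finetune the whole ConvNet with a small learning rate.\n",
"model = resnet18(pretrained=True)\n",
"model.fc = nn.Linear(model.fc.in_features, num_classes)  # replace the head\n",
"opt_finetune = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)\n",
"\n",
"# Scenario 2: freeze the backbone and train only the new head.\n",
"model = resnet18(pretrained=True)\n",
"for p in model.parameters():\n",
"    p.requires_grad = False  # freeze the feature extractor\n",
"model.fc = nn.Linear(model.fc.in_features, num_classes)  # new head is trainable by default\n",
"opt_head_only = optim.SGD(model.fc.parameters(), lr=1e-3, momentum=0.9)\n",
"```"
]
},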
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"from torch.backends import cudnn\n",
"from torchvision.models import resnet18\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# 让程序在开始时花费一点额外时间,\n",
"# 为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。\n",
"cudnn.benchmark = True\n",
"\n",
"# 载入自定义模块\n",
"from mod import load_mod\n",
"load_mod()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{admonition} 任务\n",
"训练对蚂蚁和蜜蜂进行分类的模型。\n",
"```\n",
"\n",
"```{rubric} 加载数据\n",
"```\n",
"\n",
"通常,如果从零开始训练,这是一个非常小的可泛化的数据集。由于使用迁移学习,应该能够相当好地推广。\n",
"\n",
"本文自定义数据加载器 {class}`~pytorch_book.datasets.examples.Hymenoptera`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_book.datasets.examples import Hymenoptera\n",
"# 加载数据\n",
"loader = Hymenoptera()"
]
},
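{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Hymenoptera` hides the usual torchvision boilerplate. Below is a minimal sketch of what such a loader typically wraps, assuming the standard hymenoptera_data folder layout from the PyTorch transfer learning tutorial; the paths, transforms, and batch size here are assumptions, not necessarily what `Hymenoptera` uses internally:\n",
"\n",
"```python\n",
"import os\n",
"import torch\n",
"from torchvision import datasets, transforms\n",
"\n",
"data_dir = 'data/hymenoptera_data'  # assumed layout: train/ and val/ folders with one subfolder per class\n",
"data_transforms = {\n",
"    'train': transforms.Compose([\n",
"        transforms.RandomResizedCrop(224),\n",
"        transforms.RandomHorizontalFlip(),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
"    ]),\n",
"    'val': transforms.Compose([\n",
"        transforms.Resize(256),\n",
"        transforms.CenterCrop(224),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
"    ]),\n",
"}\n",
"image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])\n",
"                  for x in ['train', 'val']}\n",
"dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,\n",
"                                              shuffle=True, num_workers=4)\n",
"               for x in ['train', 'val']}\n",
"class_names = image_datasets['train'].classes\n",
"```"
]
},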
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{rubric} 模型的辅助函数\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from xinet import ModuleTool, CV"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{rubric} 设置设备\n",
"```\n",
"\n",
"设置模型训练的设备:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 微调 ConvNet\n",
"\n",
"加载预训练模型,重置最终的全连接层。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model_ft = resnet18(pretrained=True)\n",
"num_ftrs = model_ft.fc.in_features\n",
"# 这里每个输出样本的大小设置为 2。\n",
"# 或者,它可以推广到 ``nn.Linear(num_ftrs, len(class_names))``。\n",
"model_ft.fc = nn.Linear(num_ftrs, 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{rubric} 训练和评估\n",
"```\n",
"\n",
"在 CPU 上应该需要大约 15-25 分钟。而在 GPU 上,这只需要不到一分钟的时间。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.333, train acc 0.865, test acc 0.922\n",
"35.0 examples/sec on cpu\n"
]
},
{
"data": {
"image/svg+xml": "\n\n\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"train_iter = loader.dataloaders['train']\n",
"test_iter = loader.dataloaders['val']\n",
"CV.train_fine_tuning(model_ft, train_iter, test_iter,\n",
" learning_rate=1e-3,\n",
" num_epochs=25,\n",
" param_group=False)"
]
},
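{
"cell_type": "markdown",
"metadata": {},
"source": [
"`CV.train_fine_tuning` is a helper from the custom `xinet` package. Below is a minimal sketch of what such a routine typically does, assuming a plain SGD optimizer and cross-entropy loss; the helper's actual implementation (logging, throughput measurement, `param_group` handling) may differ:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"\n",
"def train_fine_tuning(net, train_iter, test_iter, learning_rate, num_epochs, device):\n",
"    net = net.to(device)\n",
"    loss_fn = nn.CrossEntropyLoss()\n",
"    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)\n",
"    for epoch in range(num_epochs):\n",
"        net.train()\n",
"        for X, y in train_iter:\n",
"            X, y = X.to(device), y.to(device)\n",
"            optimizer.zero_grad()\n",
"            loss = loss_fn(net(X), y)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"        # evaluate accuracy on the validation set after each epoch\n",
"        net.eval()\n",
"        correct = total = 0\n",
"        with torch.no_grad():\n",
"            for X, y in test_iter:\n",
"                X, y = X.to(device), y.to(device)\n",
"                correct += (net(X).argmax(dim=1) == y).sum().item()\n",
"                total += y.numel()\n",
"        print(f'epoch {epoch + 1}: test acc {correct / total:.3f}')\n",
"```"
]
},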
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"for xs, _ in loader.dataloaders['val']:\n",
" break\n",
"\n",
"MT = ModuleTool(xs)\n",
"MT.imshow(model_ft, loader.class_names, device,\n",
" num_rows=1, num_cols=4, scale=2);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ConvNet 作为固定的特征提取器\n",
"\n",
"在这里,需要冻结除了最后一层之外的所有网络。需要设置 ``requires_grad = False`` 来冻结参数,这样梯度就不会在 ``backward()`` 中计算出来。\n",
"\n",
"也可以阅读 [excluding-subgraphs-from-backward](https://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward) 了解更多内容。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model_conv = resnet18(pretrained=True)\n",
"for param in model_conv.parameters():\n",
" param.requires_grad = False\n",
"\n",
"# 新构造模块的参数默认为 requires_grad=True\n",
"num_ftrs = model_conv.fc.in_features\n",
"model_conv.fc = nn.Linear(num_ftrs, 2)\n",
"\n",
"model_conv = model_conv.to(device)"
]
},
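{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since only the new head should be updated, the optimizer only needs to be given the parameters of `model_conv.fc`. A minimal sketch, assuming a plain SGD optimizer (how `CV.train_fine_tuning` constructs its optimizer internally is not shown here):\n",
"\n",
"```python\n",
"from torch import optim\n",
"\n",
"# Only the parameters of the final layer are being optimized;\n",
"# all other parameters are frozen (requires_grad=False) and never updated.\n",
"optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=1e-3, momentum=0.9)\n",
"```"
]
},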
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{rubric} 训练和评估\n",
"```\n",
"\n",
"在 CPU 上,与之前的场景相比,这将花费大约一半的时间。这是预期的,因为梯度不需要计算的大多数网络。但是,forward 确实需要计算。"
]
},
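{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check you can confirm that only the new head is trainable. A minimal sketch using the `model_conv` defined above:\n",
"\n",
"```python\n",
"trainable = [name for name, p in model_conv.named_parameters() if p.requires_grad]\n",
"print(trainable)  # expected to list only the fc.weight and fc.bias parameters\n",
"num_trainable = sum(p.numel() for p in model_conv.parameters() if p.requires_grad)\n",
"num_total = sum(p.numel() for p in model_conv.parameters())\n",
"print(f'{num_trainable} trainable parameters out of {num_total}')\n",
"```"
]
},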
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.392, train acc 0.857, test acc 0.954\n",
"102.5 examples/sec on cpu\n"
]
},
{
"data": {
"image/svg+xml": "\n\n\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"CV.train_fine_tuning(model_conv, train_iter, test_iter,\n",
" learning_rate=1e-3,\n",
" num_epochs=25,\n",
" param_group=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{rubric} 可视化\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"for xs, _ in loader.dataloaders['val']:\n",
" break\n",
"\n",
"MT = ModuleTool(xs)\n",
"MT.imshow(model_ft, loader.class_names, device,\n",
" num_rows=1, num_cols=4, scale=2)\n",
"plt.ioff()\n",
"plt.show()"
]
}
],
"metadata": {
"interpreter": {
"hash": "7a45eadec1f9f49b0fdfd1bc7d360ac982412448ce738fa321afc640e3212175"
},
"kernelspec": {
"display_name": "Python 3.10.4 ('torchx')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 0
}