'''MobileNetV2 in PyTorch.

See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    '''Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 linear project.

    The projection stage (bn3) is intentionally NOT followed by ReLU
    ("linear bottleneck"). A residual connection is added only when
    stride == 1; when the channel counts also differ, the shortcut is a
    1x1 conv + BN projection so the tensors can be summed.
    '''

    def __init__(self, in_planes, out_planes, expansion, stride):
        """Build one inverted-residual block.

        Args:
            in_planes: number of input channels.
            out_planes: number of output channels.
            expansion: multiplier for the hidden (depthwise) channel count.
            stride: stride of the depthwise conv; 1 keeps the residual path.
        """
        super().__init__()
        self.stride = stride

        planes = expansion * in_planes
        # 1x1 expansion conv
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 depthwise conv (groups == channels)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, groups=planes, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 pointwise projection (linear: no ReLU after bn3)
        self.conv3 = nn.Conv2d(
            planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity shortcut by default; 1x1 projection when channels change.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Run the block; returns a tensor with out_planes channels."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Residual sum only when the spatial size is preserved (stride == 1).
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV2(nn.Module):
    '''MobileNetV2 adapted for CIFAR10 (32x32 inputs).

    Differences from the ImageNet architecture (marked NOTE below):
    the stem conv and the second stage use stride 1 instead of 2, so a
    32x32 input is downsampled to a 4x4 final feature map.
    '''

    # Per-stage config: (expansion, out_planes, num_blocks, stride)
    cfg = [(1, 16, 1, 1),
           (6, 24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
           (6, 32, 3, 2),
           (6, 64, 4, 2),
           (6, 96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=10):
        """Build the network.

        Args:
            num_classes: size of the final classification layer (default 10).
        """
        super().__init__()
        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        # Final 1x1 conv expands 320 -> 1280 features before pooling.
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, num_classes)

    def _make_layers(self, in_planes):
        """Stack Blocks per self.cfg; only each stage's first block strides."""
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            strides = [stride] + [1]*(num_blocks-1)
            for stride in strides:
                layers.append(Block(in_planes, out_planes, expansion, stride))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return (batch, num_classes) logits for an image batch x."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # Global average pooling. Adaptive pooling generalizes the original
        # fixed F.avg_pool2d(out, 4): identical for 32x32 inputs (the final
        # feature map is exactly 4x4, so both take the mean over the whole
        # map) but it also works for any other input resolution, where the
        # fixed kernel would crash on the view/linear shape mismatch.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def test():
    """Smoke test: forward a random CIFAR10-shaped batch and print the size.

    Returns the logits as well (backward-compatible addition: the original
    returned None implicitly) so callers can assert on the output.
    """
    net = MobileNetV2()
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print(y.size())
    return y