ResNet18 SiLU Linearization on CIFAR-10

This article reports the performance of ResNet18 in experiments on CIFAR-10.
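The article does not spell out what "SiLU linearization" involves. For reference, SiLU is defined as silu(x) = x · σ(x), where σ is the logistic sigmoid. The plain-Python sketch below computes SiLU, its derivative, and, purely as an illustration of what linearizing an activation can mean, its first-order Taylor expansion around a point `x0`. The helper names are hypothetical, and the Taylor expansion is an assumption, not necessarily the scheme used to produce the checkpoint evaluated here.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)  # avoids overflow of exp(-x) for very negative x
    return e / (1.0 + e)


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x * sigmoid(x)


def silu_grad(x: float) -> float:
    """Derivative of SiLU: sigmoid(x) * (1 + x * (1 - sigmoid(x)))."""
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))


def silu_linearized(x: float, x0: float = 0.0) -> float:
    """Hypothetical: first-order Taylor expansion of SiLU around x0."""
    return silu(x0) + silu_grad(x0) * (x - x0)


# Around x0 = 0: silu(0) = 0 and silu'(0) = 0.5, so the tangent line is x / 2.
print(silu_linearized(2.0))  # 1.0
```

Near zero the tangent line x / 2 is a good fit, but for large |x| SiLU behaves like ReLU (about x for large positive x, about 0 for large negative x), so a single linear approximation only holds locally.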

Import the necessary packages:

import logging
import torch
from torch.fx.node import Argument
from typing import Any
from torch import nn, fx
from torchvision.models import resnet18, ResNet18_Weights
from torch_book.vision.classifier import Classifier, evaluate_accuracy
from torch_book.datasets.cifar10 import Cifar10
from torch_book.transforms.cutout import Cutout
torch.cuda.empty_cache()  # release cached, unused GPU memory held by PyTorch's allocator
model = torch.jit.load("params/resnet18_cifar10_silu_cutout.pt")
dataset = Cifar10(root="../data", batch_size=64, num_workers=4)
train_iter = dataset.train_loader()
test_iter = dataset.val_loader()
Files already downloaded and verified
Files already downloaded and verified
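The `Cutout` import and the checkpoint name (`resnet18_cifar10_silu_cutout.pt`) suggest the model was trained with Cutout augmentation, which masks one random square patch of each training image. The implementation of `torch_book.transforms.cutout.Cutout` is not shown here. The class below is a hypothetical minimal sketch, assuming a C×H×W tensor input, a single square patch, and a default side length of 16 pixels (all assumptions).

```python
import torch


class Cutout:
    """Hypothetical sketch: zero out one random square patch of a C x H x W tensor."""

    def __init__(self, length: int = 16) -> None:
        self.length = length  # side length of the square patch (assumed default)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        _, h, w = img.shape
        # Sample the patch centre uniformly over the image; the patch is clipped at the border.
        cy = int(torch.randint(h, (1,)))
        cx = int(torch.randint(w, (1,)))
        half = self.length // 2
        y1, y2 = max(cy - half, 0), min(cy + half, h)
        x1, x2 = max(cx - half, 0), min(cx + half, w)
        out = img.clone()  # leave the caller's tensor untouched
        out[:, y1:y2, x1:x2] = 0.0  # the same patch is masked in every channel
        return out
```

In a torchvision-style training pipeline such a transform would typically be placed after the tensor conversion (and normalization), so that it operates on a C×H×W tensor rather than a PIL image.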
valid_acc = evaluate_accuracy(model, test_iter, device=torch.device("cuda:0"))
valid_acc
0.9463
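`evaluate_accuracy` comes from `torch_book.vision.classifier`, and its source is not shown. Assuming it computes top-1 accuracy, a minimal equivalent might look like the hypothetical `top1_accuracy` below. It puts the model in eval mode, disables gradient tracking, and counts the samples whose arg-max logit matches the label.

```python
from typing import Iterable, Tuple

import torch
from torch import nn


@torch.no_grad()  # no autograd graph is needed for evaluation
def top1_accuracy(
    model: nn.Module,
    loader: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
) -> float:
    """Hypothetical helper: fraction of samples whose arg-max logit equals the label."""
    model.eval()  # disable dropout and use BatchNorm running statistics
    model.to(device)
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)  # predicted class index per sample
        correct += (preds == y).sum().item()
        total += y.numel()
    return correct / total
```

Called as `top1_accuracy(model, test_iter, torch.device("cuda:0"))`, it should play the same role as the `evaluate_accuracy` cell above. If `test_iter` wraps the standard 10,000-image CIFAR-10 test split, an accuracy of 0.9463 corresponds to 9,463 correctly classified images.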