{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 测试：CIFAR-10 数据加载与模型训练"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"import torch\n",
"\n",
"from mod import load_mod\n",
"\n",
"# Turn on matplotlib interactive mode.\n",
"plt.ion()\n",
"# Load the custom (project-local) modules.\n",
"# NOTE(review): presumably this makes the `xinet` package importable, which is\n",
"# why `from xinet import CV` comes after this call - confirm in mod.py.\n",
"load_mod()\n",
"\n",
"from xinet import CV"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/pc/xinet/anaconda3/envs/ai/lib/python3.9/site-packages/torch/ao/quantization/observer.py:172: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"batch_size = 128  # mini-batch size handed to the CIFAR-10 loader\n",
"train_iter, test_iter = CV.load_data_cifar10(\n",
"    batch_size=batch_size,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.008, train acc 0.997, test acc 0.756\n",
"814.4 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2)]\n"
]
},
{
"data": {
"image/svg+xml": "\n\n\n",
"text/plain": [
"