{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试 — CIFAR-10 fine-tuning and quantization\n",
    "\n",
    "Fine-tunes `model_ft` on CIFAR-10 with `xinet.CV.train_fine_tuning`, converts it with `torch.ao.quantization.convert`, and saves the state dicts of both the trained and the converted model under `../models/`.\n",
    "\n",
    "**TODO:** the cell that builds `model_ft` is missing from this notebook, so *Restart & Run All* currently stops at the check in the *Model* section. The stored outputs come from a session executed out of order (execution counts 1, 2, 4, 11)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "from torch.ao.quantization import convert\n",
    "\n",
    "from mod import load_mod\n",
    "\n",
    "plt.ion()\n",
    "# Load the project's custom modules. `xinet` is imported only after load_mod(),\n",
    "# presumably because that call makes it importable — keep this order.\n",
    "load_mod()\n",
    "\n",
    "from xinet import CV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings a reader might want to tune.\n",
    "SEED = 0\n",
    "BATCH_SIZE = 128\n",
    "LEARNING_RATE = 1e-3\n",
    "NUM_EPOCHS = 100\n",
    "MODEL_DIR = Path('../models')  # relative to the notebook's working directory\n",
    "\n",
    "# Seed torch's RNGs so reruns are more repeatable (multi-GPU/CUDA runs may still vary).\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "CIFAR-10 train/test iterators from `xinet.CV`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pc/xinet/anaconda3/envs/ai/lib/python3.9/site-packages/torch/ao/quantization/observer.py:172: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n",
      " warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "train_iter, test_iter = CV.load_data_cifar10(batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "`model_ft` must be defined here. Since it is later passed to `convert`, it is presumably a model prepared for quantization (e.g. with `torch.ao.quantization.prepare_qat`) — confirm when restoring it.\n",
    "\n",
    "**TODO:** restore the cell that builds `model_ft`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail fast with an actionable message instead of a bare NameError during training.\n",
    "if 'model_ft' not in globals():\n",
    "    raise NameError('model_ft is not defined: restore the cell that builds it '\n",
    "                    '(see the TODO above) before running the rest of the notebook.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tune\n",
    "\n",
    "The training plot is produced when this cell runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.008, train acc 0.997, test acc 0.756\n",
      "814.4 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2)]\n"
     ]
    }
   ],
   "source": [
    "CV.train_fine_tuning(model_ft, train_iter, test_iter,\n",
    "                     learning_rate=LEARNING_RATE,\n",
    "                     num_epochs=NUM_EPOCHS,\n",
    "                     param_group=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert and save\n",
    "\n",
    "`model_ft` is moved to the CPU and switched to eval mode, then a copy of it is converted (`inplace=False`). `cifar10.pt` holds the unconverted state dict, `cifar10_convert.pt` the converted one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Eager-mode quantized kernels run on CPU, and PyTorch's quantization workflow\n",
    "# puts the model in eval mode before convert().\n",
    "model_ft.cpu().eval()\n",
    "model_convert = convert(model_ft, inplace=False)  # converts a copy; model_ft keeps its modules\n",
    "\n",
    "torch.save(model_ft.state_dict(), MODEL_DIR / 'cifar10.pt')\n",
    "torch.save(model_convert.state_dict(), MODEL_DIR / 'cifar10_convert.pt')"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "61447c3ddb95e77cf825d46d6f70227c36ed08aec1baab023b2c78b109ed3829"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 ('ai')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}