{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mod import load_mod\n",
    "# Load the custom local modules onto the path\n",
    "load_mod()\n",
    "\n",
    "from pytorch_book.quantization.qat import QuantizableCustom\n",
    "from xinet import CV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QAT wrapper around a quantizable ResNet-18 backbone\n",
    "qat = QuantizableCustom('resnet18')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the (pretrained) quantizable model with its qconfig attached\n",
    "model = qat.qconfig(quantize=True,\n",
    "                    pretrained=True,\n",
    "                    progress=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size)\n",
    "# `quantize_qat` is provided by the `qat` wrapper (QuantizableCustom),\n",
    "# not by the model it returned — calling it on the model raises\n",
    "# AttributeError: 'QuantizableResNet' object has no attribute 'quantize_qat'.\n",
    "quantized_model = qat.quantize_qat(model,\n",
    "                                   run_fn=CV.train_fine_tuning,\n",
    "                                   run_args=[train_iter, test_iter],\n",
    "                                   run_kwargs={\n",
    "                                       'learning_rate': 1e-3,\n",
    "                                       'num_epochs': 100,\n",
    "                                   },\n",
    "                                   inplace=False)\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "78526419bf48930935ba7e23437b2460cb231485716b036ebb8701887a294fa8"
  },
  "kernelspec": {
   "display_name": "Python 3.10.0 ('torchx')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}