# TVM

from matplotlib import pyplot as plt
import torch

from mod import load_mod

# Switch matplotlib to interactive mode.
plt.ion()
# Load the custom modules via the project helper `load_mod`; this runs before
# `from xinet import CV` below (presumably it is what makes `xinet`
# importable -- TODO confirm against `mod.load_mod`).
load_mod()

from xinet import CV
import torch
from torchvision.models import quantization as models
from torch import nn

def create_combined_model(model_fe):
    """Build a 10-class classifier on top of a ResNet-style feature extractor.

    Args:
        model_fe: Model exposing the attributes ``quant``, ``conv1``, ``bn1``,
            ``relu``, ``maxpool``, ``layer1`` .. ``layer4``, ``avgpool``,
            ``dequant`` and ``fc``. The sub-modules are reused, not copied,
            so the returned model shares its weights with ``model_fe``.

    Returns:
        An ``nn.Sequential`` of (feature extractor, flatten, new head) whose
        head ends in a 10-way linear layer.
    """
    # Step 1: isolate the feature extractor, keeping the quant/dequant stubs
    # at its boundaries.
    model_fe_features = nn.Sequential(
        model_fe.quant,  # quantize the input
        model_fe.conv1,
        model_fe.bn1,
        model_fe.relu,
        model_fe.maxpool,
        model_fe.layer1,
        model_fe.layer2,
        model_fe.layer3,
        model_fe.layer4,
        model_fe.avgpool,
        model_fe.dequant,  # dequantize the output
    )

    # Step 2: create a new "head". Its input width is read from the
    # classifier being replaced, so the function does not depend on a
    # module-level variable being defined before it is called.
    in_features = model_fe.fc.in_features
    new_head = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(in_features, 10),
    )

    # Step 3: combine; the quantization stubs are already part of the
    # feature extractor.
    new_model = nn.Sequential(
        model_fe_features,
        nn.Flatten(1),
        new_head,
    )
    return new_model

# The TVM module below is compiled for a fixed input shape of
# (1, 3, 32, 32) and the evaluation loop reads one label per batch, so the
# loader has to yield exactly one image per batch.
batch_size = 1
train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size)
Files already downloaded and verified
Files already downloaded and verified
import numpy as np
import time
import torch

import sys
# Put a local TVM checkout's `python` and `vta/python` directories on
# sys.path; this has to happen before the `import tvm` that follows.
TVM_HOME = '/media/pc/data/4tb/xinet/tvm'
sys.path.extend([f'{TVM_HOME}/python', f'{TVM_HOME}/vta/python'])

import tvm
from tvm import relay


def find_topk(array, k, axis=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) values along ``axis`` and their indices.

    Args:
        array: Array-like input; lists and tuples are accepted.
        k: Number of entries to select, ``1 <= k <= size along axis``.
        axis: Axis to select along. ``None`` operates on the flattened array.
        largest: Select the largest entries when true, the smallest otherwise.
        sorted: When true the result is ordered (descending for ``largest``,
            ascending otherwise); when false the order within the ``k``
            selected entries is unspecified. The name shadows the builtin and
            is kept only because it is part of the public signature.

    Returns:
        A ``(values, indices)`` tuple of NumPy arrays.

    Raises:
        AssertionError: If ``k`` is outside ``[1, size along axis]``.
    """
    # Convert first so that plain Python sequences have .size / .shape.
    array = np.asanyarray(array)
    axis_size = array.size if axis is None else array.shape[axis]
    assert 1 <= k <= axis_size

    # argpartition gives the k extreme entries without a full sort.
    if largest:
        index_array = np.argpartition(array, axis_size - k, axis=axis)
        topk_indices = np.take(index_array, -np.arange(k) - 1, axis=axis)
    else:
        index_array = np.argpartition(array, k - 1, axis=axis)
        topk_indices = np.take(index_array, np.arange(k), axis=axis)
    topk_values = np.take_along_axis(array, topk_indices, axis=axis)
    if not sorted:
        return topk_values, topk_indices

    # Only the k selected entries need to be ordered.
    order = np.argsort(topk_values, axis=axis)
    if largest:
        order = np.flip(order, axis=axis)
    sorted_topk_values = np.take_along_axis(topk_values, order, axis=axis)
    sorted_topk_indices = np.take_along_axis(topk_indices, order, axis=axis)
    return sorted_topk_values, sorted_topk_indices


def calibrate_dataset(val_loader, calibration_samples, batch_size):
    """Yield calibration batches as ``{"input": ndarray}`` dictionaries.

    Iterates ``val_loader`` (which yields ``(data, label)`` pairs) and stops
    as soon as the batches already emitted, counted as
    ``batch index * batch_size``, reach ``calibration_samples``.
    """
    for batch_idx, (images, _labels) in enumerate(val_loader):
        already_emitted = batch_idx * batch_size
        if already_emitted >= calibration_samples:
            return
        yield {"input": images.numpy()}


def quantize(mod, params, data_aware, val_loader, calibration_samples=500, batch_size=1):
    """Quantize a Relay module with TVM's ``relay.quantize``.

    With ``data_aware`` set, scales are calibrated ("kl_divergence" mode) on
    batches drawn from ``val_loader`` through ``calibrate_dataset``;
    otherwise a fixed global scale of 8.0 is used and ``val_loader`` is
    ignored. Returns the quantized module.
    """
    if not data_aware:
        print("tvm global scale quantize begin---------------------------->>")
        global_cfg = relay.quantize.qconfig(
            calibrate_mode="global_scale", global_scale=8.0)
        with global_cfg:
            quantized = relay.quantize.quantize(mod, params)
        print("tvm global scale quantize end---------------------------->>")
        return quantized

    print("tvm calibration quantize begin---------------------------->>")
    calibration_cfg = relay.quantize.qconfig(
        calibrate_mode="kl_divergence",
        weight_scale="max",
        skip_conv_layers=[0],
        skip_dense_layer=True,
    )
    with calibration_cfg:
        samples = calibrate_dataset(val_loader, calibration_samples, batch_size)
        quantized = relay.quantize.quantize(mod, params, dataset=samples)
    print("tvm calibration quantize end---------------------------->>")
    return quantized


def run_tvm_model(mod, params, target="llvm"):
    """Compile a Relay module and wrap it in a graph executor.

    Args:
        mod: Relay module to build.
        params: Parameters passed to ``relay.build``.
        target: TVM target string; also used to select the device
            (device index 0).

    Returns:
        A ``GraphModule`` ready for ``set_input`` / ``run`` / ``get_output``.
    """
    # Import the submodule explicitly: `import tvm` on its own does not
    # promise that `tvm.contrib.graph_executor` is bound as an attribute.
    from tvm.contrib import graph_executor

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    dev = tvm.device(target, 0)
    return graph_executor.GraphModule(lib["default"](dev))


def tvm_model(model, batch_size):
    """Trace a PyTorch model and import it into Relay.

    The model is traced with a random input of shape
    ``(batch_size, 3, 32, 32)`` registered under the name ``"input"``.

    Tracing runs in eval mode and without gradient tracking. In training
    mode BatchNorm refuses inputs that have a single value per channel
    (for example a ``(1, 512, 1, 1)`` activation at batch size 1) and Dropout
    would be recorded as active. The train/eval flag of every sub-module is
    restored before returning, so the caller's model is left as it was.

    Returns:
        ``(mod, params)`` as produced by ``relay.frontend.from_pytorch``.
    """
    input_shape = (batch_size, 3, 32, 32)
    shape_list = [("input", input_shape)]
    input_data = torch.randn(input_shape)

    # Remember per-module flags: a blanket model.train(flag) would overwrite
    # sub-modules that were deliberately left in a different mode.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            scripted_model = torch.jit.trace(model, input_data).eval()
        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    return mod, params


def tvm_test(model, val_loader,
             batch_size, data_aware,
             calibration_samples=500,
             print_freq=100,
             pre_quantization=False):
    """Compile ``model`` with TVM, then report accuracy and mean latency.

    Args:
        model: PyTorch model to trace and compile.
        val_loader: Sized iterable of ``(data, label)`` batches. It is used
            for calibration, accuracy and timing. Each batch is assumed to
            match the compiled input shape ``(batch_size, 3, 32, 32)``
            -- TODO confirm for the last batch of the loader.
        batch_size: Batch size the TVM module is compiled for.
        data_aware: Forwarded to ``quantize`` (calibrated vs. global scale).
        calibration_samples: Number of samples used for calibration.
        print_freq: Print running accuracy every ``print_freq`` batches.
        pre_quantization: When true the Relay quantization step is skipped.

    Returns:
        None. Results are printed.
    """
    mod, params = tvm_model(model, batch_size)
    if not pre_quantization:
        mod = quantize(mod, params, data_aware, val_loader,
                       calibration_samples=calibration_samples,
                       batch_size=batch_size)
    runtime = run_tvm_model(mod, params)

    test_nums = len(val_loader)
    top1_correct = 0
    top5_correct = 0
    seen = 0  # samples evaluated so far; accuracy is per sample, not per batch
    print('llvm inference-------------->>')
    for i, (images, labels) in enumerate(val_loader, 1):
        runtime.set_input('input', images)
        runtime.run()
        output = runtime.get_output(0).asnumpy()

        # Indices of the five highest scores for every sample in the batch.
        _, preds = find_topk(output, 5)
        targets = np.asarray(labels).reshape(-1)
        top1_correct += int(np.sum(preds[:, 0] == targets))
        top5_correct += int(np.sum((preds == targets[:, None]).any(axis=1)))
        seen += targets.shape[0]

        if i % print_freq == 0:
            print('Test: [{}/{}] \t'
                  'Acc@1 {:.4f} \t'
                  'Acc@5 {:.4f}'.format(
                      i, test_nums, top1_correct / seen, top5_correct / seen))

    # Guard against an empty loader instead of dividing by zero.
    top1 = top1_correct / seen if seen else 0.0
    top5 = top5_correct / seen if seen else 0.0
    print(' * Acc@1 {:.4f} Acc@5 {:.4f}'
          .format(top1, top5))

    # Latency: run the first batch repeatedly. The batch is fetched before
    # the clock starts so that data loading is not part of the measurement.
    repeat = 100
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        return
    images, _ = first_batch
    time_start = time.perf_counter()
    for _ in range(repeat):
        runtime.set_input('input', images)
        runtime.run()
        runtime.get_output(0).asnumpy()
    time_end = time.perf_counter()
    print("平均推理时间:", (time_end - time_start) / repeat)
# Note `quantize=False`: the float (non-quantized) model is requested.
model = models.resnet18(pretrained=True, progress=True, quantize=False)
num_ftrs = model.fc.in_features

# Reuse the pretrained backbone with a fresh 10-class head.
model_ft = create_combined_model(model)

# The model is only traced and run for inference here, so it must be in eval
# mode. In training mode BatchNorm rejects the (1, 512, 1, 1) activation that
# a single 32x32 image produces ("Expected more than 1 value per channel when
# training"), and Dropout in the new head would stay active.
model_ft.eval()

for param in model_ft.parameters():
    param.requires_grad = True


tvm_test(model_ft, test_iter,
         batch_size=1, data_aware=True,
         calibration_samples=500,
         print_freq=100,
         pre_quantization=False)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb Cell 5' in <cell line: 15>()
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000004vscode-remote?line=10'>11</a> for param in model_ft.parameters():
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000004vscode-remote?line=11'>12</a>   param.requires_grad = True
---> <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000004vscode-remote?line=14'>15</a> tvm_test(model_ft, test_iter,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000004vscode-remote?line=15'>16</a>          batch_size=1, data_aware=True,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000004vscode-remote?line=16'>17</a>          calibration_samples=500,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000004vscode-remote?line=17'>18</a>          print_freq=100,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000004vscode-remote?line=18'>19</a>          pre_quantization=False)

/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb Cell 4' in tvm_test(model, val_loader, batch_size, data_aware, calibration_samples, print_freq, pre_quantization)
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=79'>80</a> def tvm_test(model, val_loader,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=80'>81</a>              batch_size, data_aware,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=81'>82</a>              calibration_samples=500,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=82'>83</a>              print_freq=100,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=83'>84</a>              pre_quantization=False):
---> <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=84'>85</a>     mod, params = tvm_model(model, batch_size)
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=85'>86</a>     if not pre_quantization:
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=86'>87</a>         mod = quantize(mod, params, data_aware, val_loader,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=87'>88</a>                        calibration_samples=calibration_samples,
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=88'>89</a>                        batch_size=batch_size)

/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb Cell 4' in tvm_model(model, batch_size)
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=72'>73</a> shape_list = [("input", input_shape)]
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=73'>74</a> input_data = torch.randn(input_shape)
---> <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=74'>75</a> scripted_model = torch.jit.trace(model, input_data).eval()
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=75'>76</a> mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
     <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/4tb/xinet/web/quantization/docs/study/transfer-learning/tvm.ipynb#ch0000003vscode-remote?line=76'>77</a> return mod, params

File ~/.local/lib/python3.8/site-packages/torch/jit/_trace.py:733, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=729'>730</a>     return func
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=731'>732</a> if isinstance(func, torch.nn.Module):
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=732'>733</a>     return trace_module(
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=733'>734</a>         func,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=734'>735</a>         {"forward": example_inputs},
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=735'>736</a>         None,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=736'>737</a>         check_trace,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=737'>738</a>         wrap_check_inputs(check_inputs),
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=738'>739</a>         check_tolerance,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=739'>740</a>         strict,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=740'>741</a>         _force_outplace,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=741'>742</a>         _module_class,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=742'>743</a>     )
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=744'>745</a> if (
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=745'>746</a>     hasattr(func, "__self__")
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=746'>747</a>     and isinstance(func.__self__, torch.nn.Module)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=747'>748</a>     and func.__name__ == "forward"
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=748'>749</a> ):
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=749'>750</a>     return trace_module(
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=750'>751</a>         func.__self__,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=751'>752</a>         {"forward": example_inputs},
   (...)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=758'>759</a>         _module_class,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=759'>760</a>     )

File ~/.local/lib/python3.8/site-packages/torch/jit/_trace.py:934, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=931'>932</a> func = mod if method_name == "forward" else getattr(mod, method_name)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=932'>933</a> example_inputs = make_tuple(example_inputs)
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=933'>934</a> module._c._create_method_from_trace(
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=934'>935</a>     method_name,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=935'>936</a>     func,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=936'>937</a>     example_inputs,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=937'>938</a>     var_lookup_fn,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=938'>939</a>     strict,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=939'>940</a>     _force_outplace,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=940'>941</a> )
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=941'>942</a> check_trace_method = module._c._get_method(method_name)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/jit/_trace.py?line=943'>944</a> # Check the trace against new traces created from user-specified inputs

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:887, in Module._call_impl(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=883'>884</a>     input = bw_hook.setup_input_hook(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=885'>886</a> if torch._C._get_tracing_state():
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=886'>887</a>     result = self._slow_forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=887'>888</a> else:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=888'>889</a>     result = self.forward(*input, **kwargs)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:860, in Module._slow_forward(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=857'>858</a>         recording_scopes = False
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=858'>859</a> try:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=859'>860</a>     result = self.forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=860'>861</a> finally:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=861'>862</a>     if recording_scopes:

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/container.py:119, in Sequential.forward(self, input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=116'>117</a> def forward(self, input):
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=117'>118</a>     for module in self:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=118'>119</a>         input = module(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=119'>120</a>     return input

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:887, in Module._call_impl(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=883'>884</a>     input = bw_hook.setup_input_hook(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=885'>886</a> if torch._C._get_tracing_state():
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=886'>887</a>     result = self._slow_forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=887'>888</a> else:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=888'>889</a>     result = self.forward(*input, **kwargs)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:860, in Module._slow_forward(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=857'>858</a>         recording_scopes = False
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=858'>859</a> try:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=859'>860</a>     result = self.forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=860'>861</a> finally:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=861'>862</a>     if recording_scopes:

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/container.py:119, in Sequential.forward(self, input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=116'>117</a> def forward(self, input):
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=117'>118</a>     for module in self:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=118'>119</a>         input = module(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=119'>120</a>     return input

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:887, in Module._call_impl(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=883'>884</a>     input = bw_hook.setup_input_hook(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=885'>886</a> if torch._C._get_tracing_state():
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=886'>887</a>     result = self._slow_forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=887'>888</a> else:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=888'>889</a>     result = self.forward(*input, **kwargs)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:860, in Module._slow_forward(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=857'>858</a>         recording_scopes = False
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=858'>859</a> try:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=859'>860</a>     result = self.forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=860'>861</a> finally:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=861'>862</a>     if recording_scopes:

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/container.py:119, in Sequential.forward(self, input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=116'>117</a> def forward(self, input):
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=117'>118</a>     for module in self:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=118'>119</a>         input = module(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/container.py?line=119'>120</a>     return input

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:887, in Module._call_impl(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=883'>884</a>     input = bw_hook.setup_input_hook(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=885'>886</a> if torch._C._get_tracing_state():
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=886'>887</a>     result = self._slow_forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=887'>888</a> else:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=888'>889</a>     result = self.forward(*input, **kwargs)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:860, in Module._slow_forward(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=857'>858</a>         recording_scopes = False
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=858'>859</a> try:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=859'>860</a>     result = self.forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=860'>861</a> finally:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=861'>862</a>     if recording_scopes:

File ~/.local/lib/python3.8/site-packages/torchvision/models/quantization/resnet.py:32, in QuantizableBasicBlock.forward(self, x)
     <a href='file:///home/pc/.local/lib/python3.8/site-packages/torchvision/models/quantization/resnet.py?line=28'>29</a> identity = x
     <a href='file:///home/pc/.local/lib/python3.8/site-packages/torchvision/models/quantization/resnet.py?line=30'>31</a> out = self.conv1(x)
---> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torchvision/models/quantization/resnet.py?line=31'>32</a> out = self.bn1(out)
     <a href='file:///home/pc/.local/lib/python3.8/site-packages/torchvision/models/quantization/resnet.py?line=32'>33</a> out = self.relu(out)
     <a href='file:///home/pc/.local/lib/python3.8/site-packages/torchvision/models/quantization/resnet.py?line=34'>35</a> out = self.conv2(out)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:887, in Module._call_impl(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=883'>884</a>     input = bw_hook.setup_input_hook(input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=885'>886</a> if torch._C._get_tracing_state():
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=886'>887</a>     result = self._slow_forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=887'>888</a> else:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=888'>889</a>     result = self.forward(*input, **kwargs)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:860, in Module._slow_forward(self, *input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=857'>858</a>         recording_scopes = False
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=858'>859</a> try:
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=859'>860</a>     result = self.forward(*input, **kwargs)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=860'>861</a> finally:
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/module.py?line=861'>862</a>     if recording_scopes:

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py:135, in _BatchNorm.forward(self, input)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=132'>133</a> assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=133'>134</a> assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
--> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=134'>135</a> return F.batch_norm(
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=135'>136</a>     input,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=136'>137</a>     # If buffers are not to be tracked, ensure that they won't be updated
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=137'>138</a>     self.running_mean if not self.training or self.track_running_stats else None,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=138'>139</a>     self.running_var if not self.training or self.track_running_stats else None,
    <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py?line=139'>140</a>     self.weight, self.bias, bn_training, exponential_average_factor, self.eps)

File ~/.local/lib/python3.8/site-packages/torch/nn/functional.py:2147, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2133'>2134</a>     return handle_torch_function(
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2134'>2135</a>         batch_norm,
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2135'>2136</a>         (input,),
   (...)
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2143'>2144</a>         eps=eps,
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2144'>2145</a>     )
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2145'>2146</a> if training:
-> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2146'>2147</a>     _verify_batch_size(input.size())
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2148'>2149</a> return torch.batch_norm(
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2149'>2150</a>     input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2150'>2151</a> )

File ~/.local/lib/python3.8/site-packages/torch/nn/functional.py:2114, in _verify_batch_size(size)
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2111'>2112</a>     size_prods *= size[i + 2]
   <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2112'>2113</a> if size_prods == 1:
-> <a href='file:///home/pc/.local/lib/python3.8/site-packages/torch/nn/functional.py?line=2113'>2114</a>     raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])