模型剪枝#

原作者: Michela Paganini => pruning_tutorial

最先进的深度学习技术依赖于难以部署的过度参数化模型(over-parametrized models)。相反,已知生物神经网络使用高效的稀疏连接(sparse connectivity)。为了在不牺牲精度的情况下减少内存、电池和硬件的消耗,在设备上部署轻量级模型,并通过私有设备上的计算保证隐私性,确定通过减少模型中的参数数量来压缩模型的最佳技术是很重要的。在研究方面,剪枝(pruning)被用于研究过度参数化(over-parametrized)和欠参数化(under-parametrized)网络之间学习动态的差异,研究 lucky 稀疏子网络和初始化(“lottery tickets”)作为破坏性(destructive)神经结构搜索技术的作用,等等。

目标

学习如何使用 torch.nn.utils.prune 来稀疏化您的神经网络,以及如何扩展它来实现您自定义剪枝技术。

import torch
from torch import nn
from torch.nn.utils import prune
import torch.nn.functional as F

构建模型#

下面以 LeNet([Lecun et al., 1998])为例子。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检查 Module#

检查 LeNet 模型中的(未修剪的)conv1 层。它将包含两个参数 weightbias,目前没有缓冲区(buffers)。

module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[-0.2908, -0.3297,  0.3301],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0005,  0.0149,  0.1317],
          [ 0.0265,  0.2909, -0.2732],
          [-0.1525, -0.0275, -0.0561]]],


        [[[-0.2313,  0.3281, -0.2581],
          [ 0.1683, -0.0615, -0.2187],
          [-0.1147, -0.0558, -0.0907]]],


        [[[ 0.1100, -0.0474,  0.1916],
          [-0.2361,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.2532]]],


        [[[ 0.0928,  0.2640,  0.1735],
          [-0.1389,  0.0455, -0.3115],
          [ 0.1367,  0.1075,  0.2437]]],


        [[[-0.0152,  0.1968,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.2374, -0.3188, -0.0395,  0.1943,  0.2974,  0.0997], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[]

剪枝 Module#

要剪枝 Module(在本例中是 LeNet 架构的 conv1 层),首先从 torch.nn.utils.prune 中选择一种剪枝技术(或者通过子类化 BasePruningMethod 实现自己的剪枝技术)。然后,指定要在该 module 中删除的 module 和参数的名称。最后,使用所选剪枝技术所需的适当关键字参数,指定剪枝参数。

在本例中,将在 conv1 层中随机删除名为 weight 的参数中的 \(30\%\) 的连接。module 作为函数的第一个参数传递;name 使用它的字符串标识符标识 module 中的参数;amount 表示要修剪的连接的百分比(如果是 0.1. 之间的浮点数),或要修剪的连接的绝对数量(如果它是非负整数)。

prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

修剪的方法是从参数中移除 weight,并用名为 weight_orig 的新参数替换它(即在初始参数 name 后追加 "_orig")。weight_orig 存储了张量的未修剪版本。bias 没有被剪除,所以它将保持不变。

print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([-0.2374, -0.3188, -0.0395,  0.1943,  0.2974,  0.0997], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.2908, -0.3297,  0.3301],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0005,  0.0149,  0.1317],
          [ 0.0265,  0.2909, -0.2732],
          [-0.1525, -0.0275, -0.0561]]],


        [[[-0.2313,  0.3281, -0.2581],
          [ 0.1683, -0.0615, -0.2187],
          [-0.1147, -0.0558, -0.0907]]],


        [[[ 0.1100, -0.0474,  0.1916],
          [-0.2361,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.2532]]],


        [[[ 0.0928,  0.2640,  0.1735],
          [-0.1389,  0.0455, -0.3115],
          [ 0.1367,  0.1075,  0.2437]]],


        [[[-0.0152,  0.1968,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0', requires_grad=True))]

由上述选择的剪枝技术生成的剪枝掩码被保存为名为 weight_mask 的模块缓冲区(即,在初始参数 name 后追加 "_mask")。

print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [0., 1., 0.]]],


        [[[0., 0., 1.],
          [0., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0'))]

为了使前向传播在不修改的情况下正常工作,需要存在名为 weight 的属性。在 torch.nn.utils.prune 中实现的剪枝技术通过将掩码与原始参数结合来计算剪枝后的权重,并将它们存储在属性 weight 中。请注意,这不再是模块的参数,现在它只是一个属性。

print(module.weight)
tensor([[[[-0.0000, -0.3297,  0.0000],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0005,  0.0000,  0.1317],
          [ 0.0000,  0.0000, -0.2732],
          [-0.1525, -0.0275, -0.0561]]],


        [[[-0.2313,  0.0000, -0.0000],
          [ 0.0000, -0.0615, -0.2187],
          [-0.0000, -0.0558, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.1916],
          [-0.0000,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.0000]]],


        [[[ 0.0928,  0.2640,  0.1735],
          [-0.1389,  0.0455, -0.3115],
          [ 0.1367,  0.1075,  0.0000]]],


        [[[-0.0152,  0.0000,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

最后,在每次前向传播之前,使用 PyTorch 的 forward_pre_hooks 应用剪枝。具体来说,当模块被剪枝时(就像我们在这里所做的那样),它会为与之关联的每个参数获取一个 forward_pre_hook。在这种情况下,由于到目前为止我们只剪枝了名为 weight 的原始参数,因此只会存在一个钩子。

print(module._forward_pre_hooks)
OrderedDict({0: <torch.nn.utils.prune.RandomUnstructured object at 0x7fc57b5a4680>})

为了完整性,我们现在也可以剪枝 bias,以观察模块的参数、缓冲区、钩子和属性如何变化。仅仅为了尝试另一种剪枝技术,在这里我们通过 L1 范数剪枝偏置中的 3 个最小项,正如在 l1_unstructured 剪枝函数中实现的那样。

prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

我们现在期望命名参数包括 weight_orig (之前的)和 bias_orig。缓冲区将包括 weight_maskbias_mask。两个张量的剪枝版本将作为模块属性存在,模块现在将有两个 forward_pre_hooks

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[-0.2908, -0.3297,  0.3301],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0005,  0.0149,  0.1317],
          [ 0.0265,  0.2909, -0.2732],
          [-0.1525, -0.0275, -0.0561]]],


        [[[-0.2313,  0.3281, -0.2581],
          [ 0.1683, -0.0615, -0.2187],
          [-0.1147, -0.0558, -0.0907]]],


        [[[ 0.1100, -0.0474,  0.1916],
          [-0.2361,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.2532]]],


        [[[ 0.0928,  0.2640,  0.1735],
          [-0.1389,  0.0455, -0.3115],
          [ 0.1367,  0.1075,  0.2437]]],


        [[[-0.0152,  0.1968,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.2374, -0.3188, -0.0395,  0.1943,  0.2974,  0.0997], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [0., 1., 0.]]],


        [[[0., 0., 1.],
          [0., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([1., 1., 0., 0., 1., 0.], device='cuda:0'))]
print(module.bias)
tensor([-0.2374, -0.3188, -0.0000,  0.0000,  0.2974,  0.0000], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict({0: <torch.nn.utils.prune.RandomUnstructured object at 0x7fc57b5a4680>, 1: <torch.nn.utils.prune.L1Unstructured object at 0x7fc5660c6690>})

迭代剪枝#

模块中的同一个参数可以被多次剪枝,各个剪枝调用的效果等同于依次应用的各个掩码的组合。 新的掩码与旧的掩码的组合由 PruningContainercompute_mask 方法处理。

例如,假设我们现在想进一步剪枝 module.weight,这次使用结构化剪枝沿着张量的0轴(0轴对应于卷积层的输出通道,对于 conv1 来说维度为6),基于通道的 L2 范数。这可以通过使用 ln_structured 函数,设置 n=2dim=0 来实现。

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# 我们可以验证,这将使与50%(3/6)的通道对应的所有连接置零,同时保留之前掩码的作用。
print(module.weight)
tensor([[[[-0.0000, -0.3297,  0.0000],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.1916],
          [-0.0000,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[-0.0152,  0.0000,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

相应的钩子现在将属于 torch.nn.utils.prune.PruningContainer 类型,并将存储应用于 weight 参数的剪枝历史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
[<torch.nn.utils.prune.RandomUnstructured object at 0x7fc57b5a4680>, <torch.nn.utils.prune.LnStructured object at 0x7fc5660c47d0>]

序列化剪枝模型#

所有相关的张量,包括用于计算剪枝张量的掩码缓冲区和原始参数,都存储在模型的state_dict中,因此如果需要,可以轻松地进行序列化和保存。

print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

移除剪枝重参数化#

为了使剪枝永久生效,需要移除与weight_origweight_mask相关的重参数化,以及移除forward_pre_hook,我们可以使用torch.nn.utils.prune中的remove功能。请注意,这并不会撤销剪枝,就好像它从未发生过一样。它只是通过将参数weight重新分配给模型的剪枝版本,使其变为永久状态。

在移除重参数化之前:

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[-0.2908, -0.3297,  0.3301],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0005,  0.0149,  0.1317],
          [ 0.0265,  0.2909, -0.2732],
          [-0.1525, -0.0275, -0.0561]]],


        [[[-0.2313,  0.3281, -0.2581],
          [ 0.1683, -0.0615, -0.2187],
          [-0.1147, -0.0558, -0.0907]]],


        [[[ 0.1100, -0.0474,  0.1916],
          [-0.2361,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.2532]]],


        [[[ 0.0928,  0.2640,  0.1735],
          [-0.1389,  0.0455, -0.3115],
          [ 0.1367,  0.1075,  0.2437]]],


        [[[-0.0152,  0.1968,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.2374, -0.3188, -0.0395,  0.1943,  0.2974,  0.0997], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 1.],
          [0., 1., 1.],
          [1., 1., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([1., 1., 0., 0., 1., 0.], device='cuda:0'))]
print(module.weight)
tensor([[[[-0.0000, -0.3297,  0.0000],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.1916],
          [-0.0000,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[-0.0152,  0.0000,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

在移除重参数化之后:

prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([-0.2374, -0.3188, -0.0395,  0.1943,  0.2974,  0.0997], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0000, -0.3297,  0.0000],
          [-0.1059,  0.3224, -0.1656],
          [-0.3119,  0.0924,  0.2647]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.1916],
          [-0.0000,  0.3031, -0.2396],
          [-0.2578,  0.2026,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[-0.0152,  0.0000,  0.3237],
          [ 0.2488,  0.2891,  0.0444],
          [ 0.0297,  0.0734, -0.0335]]]], device='cuda:0', requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([1., 1., 0., 0., 1., 0.], device='cuda:0'))]

在模型中剪枝多个参数#

通过指定所需的剪枝技术和参数,我们可以很容易地根据类型剪枝网络中的多个张量,正如我们将在这个示例中看到的那样。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers 
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全局剪枝#

到目前为止,我们只讨论了通常所说的“局部”剪枝,即逐个剪枝模型中的张量,通过将每个条目的统计数据(权重大小、激活值、梯度等)仅与该张量中的其他条目进行比较。然而,一种更常见且可能更强大的技术是一次性剪枝整个模型,例如,通过移除整个模型中最低的20%连接,而不是每层中最低的20%连接。这可能会导致每层的剪枝百分比不同。让我们看看如何使用torch.nn.utils.prune中的global_unstructured来实现这一点。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

现在我们可以检查每个剪枝参数中引入的稀疏度,这将不会在每一层都是20%。然而,全局稀疏度将(大约)为 \(20\%\)

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Sparsity in conv1.weight: 0.00%
Sparsity in conv2.weight: 8.80%
Sparsity in fc1.weight: 22.08%
Sparsity in fc2.weight: 11.88%
Sparsity in fc3.weight: 11.55%
Global sparsity: 20.00%

扩展 torch.nn.utils.prune 以支持自定义剪枝函数#

要实现自己的剪枝函数,你可以通过继承nn.utils.prune模块中的BasePruningMethod基类来扩展它,所有其他剪枝方法也是如此。基类为你实现了以下方法:__call__apply_maskapplypruneremove。除了一些特殊情况外,你不需要为新的剪枝技术重新实现这些方法。

然而,你需要实现__init__(构造函数)和compute_mask(根据剪枝技术的运算逻辑计算掩码的指令)。此外,你还需要指定该技术实现的剪枝类型(支持的选项是globalstructuredunstructured)。这在需要迭代应用剪枝的情况下是必要的,以确定如何组合掩码。换句话说,当剪枝一个已经剪枝过的参数时,当前的剪枝技术应该作用于参数未被剪枝的部分。指定PRUNING_TYPE将使PruningContainer(处理剪枝掩码的迭代应用)能够正确地识别要剪枝的参数切片。

例如,假设你想实现一种剪枝技术,它在张量中每隔一个条目进行剪枝(或者如果张量已经被剪枝过——在剩余未被剪枝的张量部分中)。这将是PRUNING_TYPE='unstructured',因为它作用于层中的单个连接而不是整个单元/通道('structured'),或跨不同参数('global')。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        return mask

现在,要将此方法应用于nn.Module中的参数,你还应该提供一个简单函数来实例化该方法并应用它。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module) 
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the 
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the 
    original (unpruned) parameter is stored in a new parameter named 
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module
    
    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

让我们试一试吧!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])