# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline
使用 NumPy 和 SciPy 创建扩展
作者: Adam Paszke
更新者: Adam Dziedzic
1. 创建一个没有参数的神经网络层;该层的实现中会调用 NumPy
2. 创建一个具有可学习权重的神经网络层;该层的实现中会调用 SciPy
import torch
from torch.autograd import Function
这个层并没有做任何有用或数学上正确的事情(其反向传播并不是前向计算的真实梯度),因此它被恰当地命名为 BadFFTFunction
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
    """Toy autograd Function that round-trips through NumPy's FFT.

    The forward pass returns the magnitude of the 2-D real FFT of the input.
    The backward pass applies the inverse real FFT to the incoming gradient.
    That backward is deliberately *not* the true derivative of
    ``abs(rfft2(x))``: the class only demonstrates calling NumPy from inside
    an autograd Function, hence the "Bad" in its name.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``abs(rfft2(input))`` as a tensor with ``input``'s dtype.

        ``rfft2`` transforms the last two axes. A trailing ``(H, W)`` becomes
        ``(H, W // 2 + 1)``.
        """
        # Remember the spatial size: irfft2 cannot recover an odd width from
        # the half-spectrum on its own.
        ctx.spatial_shape = tuple(input.shape[-2:])
        # Detach so the tensor can be exposed as a NumPy array.
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        # NumPy's FFT computes in float64; hand back the caller's dtype.
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``irfft2(grad_output)`` resized to the forward input's shape.

        This is intentionally not the mathematically correct gradient.
        """
        numpy_go = grad_output.detach().numpy()
        # Passing s= guarantees the gradient matches the input's shape, which
        # autograd requires (without it, odd widths come back one column short).
        result = irfft2(numpy_go, s=ctx.spatial_shape)
        return torch.as_tensor(result, dtype=grad_output.dtype)
def incorrect_fft(input):
    """Apply ``BadFFTFunction`` to ``input`` and return its output.

    The layer owns no learnable parameters, so a plain function is enough;
    there is no need to wrap it in an ``nn.Module`` subclass.
    """
    fft_fn = BadFFTFunction.apply
    return fft_fn(input)
# Run the parameter-free layer on a random 8x8 input and show the result.
# Then back-propagate a random upstream gradient through it and show the input.
input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
# The upstream gradient must have result's shape (8x5 for an 8x8 input), so
# derive it from result rather than hard-coding it.
result.backward(torch.randn(result.size()))
print(input)
tensor([[ 1.4051, 13.6226, 9.5634, 3.7766, 5.9795],
[ 0.9438, 3.6594, 13.2123, 1.4907, 11.1425],
[ 6.0420, 3.6192, 6.0886, 13.8669, 4.3223],
[ 1.7715, 5.5878, 4.6413, 2.0451, 6.5434],
[ 6.7019, 11.7479, 5.7401, 9.0250, 0.7763],
[ 1.7715, 7.4761, 2.4372, 2.5261, 6.5434],
[ 6.0420, 1.6046, 7.2918, 3.5582, 4.3223],
[ 0.9438, 6.5624, 7.5281, 2.8141, 11.1425]], grad_fn=<BadFFTFunctionBackward>)
tensor([[-1.6377e+00, 2.0592e-01, 2.0573e-01, 7.8468e-01, 1.7276e+00,
1.3217e-01, -2.4822e-01, -3.9126e-01],
[ 6.2618e-01, 2.6570e-01, -5.2193e-01, 5.0499e-01, -1.4101e+00,
-6.7571e-01, -5.6869e-01, -8.3366e-01],
[ 1.9149e-01, -7.3538e-01, 1.0080e+00, 2.1421e-01, 1.1228e+00,
5.3282e-01, 3.4630e-01, -1.3304e+00],
[ 7.8825e-01, -2.5452e-01, 2.9769e-01, -4.0304e-01, 7.7959e-01,
1.4877e+00, -2.9209e-01, -1.2098e+00],
[-1.6950e+00, 4.3674e-01, -4.5096e-01, 6.6104e-01, 1.0375e+00,
3.0109e-01, 8.0961e-03, 2.7743e-02],
[ 1.1005e-01, -5.9287e-01, -1.8919e+00, 1.5949e+00, 1.1019e+00,
-6.7195e-01, -9.6363e-01, -2.7448e-01],
[-2.8044e-02, -1.0913e+00, -7.7489e-01, 2.2238e+00, 4.3961e-01,
3.5224e-01, -9.1448e-01, 1.3916e+00],
[ 9.5001e-01, 5.7468e-04, 1.9211e-01, 2.7369e-01, -8.5658e-01,
2.6022e-01, -9.6591e-01, 5.0487e-01]], requires_grad=True)
from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
class ScipyConv2dFunction(Function):
    """2-D cross-correlation plus bias, computed with SciPy.

    forward:  out = correlate2d(input, filter, mode='valid') + bias
    backward: gradients w.r.t. input, filter and bias, also computed with
    SciPy. Each gradient is returned in the dtype of the tensor it belongs to,
    so double-precision parameters receive double-precision gradients.
    """

    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.detach().numpy()
        # The bias is added to every output element, so its gradient is the
        # sum of the upstream gradient (kept 2-D to match a (1, 1) bias).
        grad_bias = np.sum(grad_output, keepdims=True)
        # Gradient w.r.t. the input: full convolution of the upstream gradient
        # with the filter.
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        # Gradient w.r.t. the filter: valid correlation of the input with the
        # upstream gradient.
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        # Cast each gradient to its own tensor's dtype instead of hard-coding
        # float32, which would truncate gradients of double-precision params.
        return (
            torch.as_tensor(grad_input, dtype=input.dtype),
            torch.as_tensor(grad_filter, dtype=filter.dtype),
            torch.as_tensor(grad_bias, dtype=bias.dtype),
        )
class ScipyConv2d(Module):
    """Learnable 2-D correlation layer whose math is done by SciPy.

    Registers a ``(filter_width, filter_height)`` filter and a ``(1, 1)``
    bias, both drawn from ``torch.randn``. The forward pass delegates to
    ``ScipyConv2dFunction``.
    """

    def __init__(self, filter_width, filter_height):
        super().__init__()
        filter_shape = (filter_width, filter_height)
        self.filter = Parameter(torch.randn(*filter_shape))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        conv = ScipyConv2dFunction.apply
        return conv(input, self.filter, self.bias)
# Build a 3x3 SciPy-backed layer and inspect its freshly initialised parameters.
module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
# Take the upstream gradient's shape from the output (8x8 here) instead of
# hard-coding it, so changing the input or filter size cannot break this call.
output.backward(torch.randn(output.size()))
print("Gradient for the input map: ", input.grad)
Filter and bias: [Parameter containing:
tensor([[-1.0271, 1.0463, 0.3066],
[ 0.2833, 1.0729, 1.0716],
[ 0.1538, 0.7043, 0.4240]], requires_grad=True), Parameter containing:
tensor([[-1.5248]], requires_grad=True)]
Output from the convolution: tensor([[-3.0503, -2.0735, -2.9300, -3.5769, 0.5267, -3.8087, 1.5216, 3.6531],
[-3.0974, 2.9297, -3.3729, -6.7103, 0.3726, -5.8613, 3.5687, 4.6253],
[-1.5237, 1.0010, -0.1776, -7.9033, -2.9091, -2.9002, 0.9109, 2.4959],
[-4.8617, -2.4489, -4.9739, -6.3394, -3.2192, -3.8084, -0.6294, -0.5166],
[-2.9100, -3.5628, -2.0847, -2.4449, -4.0917, -3.0924, -2.7929, 1.4404],
[-0.3139, -0.0623, -1.9976, -5.0148, -4.9538, -2.1379, -0.3020, -0.6575],
[ 0.8865, -3.4737, -4.3420, -5.2427, -5.6424, -0.8231, -1.4015, -0.8467],
[-3.3883, -3.7128, -2.9577, -3.0117, -5.9017, 1.5836, -0.2116, -6.0861]], grad_fn=<ScipyConv2dFunctionBackward>)
Gradient for the input map: tensor([[ 1.2887e+00, -8.1603e-01, 3.3014e-03, -7.7249e-01, -3.4147e-01,
-2.7340e-02, -1.1833e+00, 1.1249e+00, -1.1679e-01, -1.0715e-01],
[-9.9647e-01, -1.4424e+00, -1.7103e+00, -1.5938e+00, -4.4227e-01,
-1.2288e+00, 1.0444e+00, 9.9767e-01, -5.3399e-01, -6.2767e-01],
[ 2.0647e-01, -2.1494e-01, -1.4147e+00, 3.2380e+00, -1.5417e+00,
1.0907e+00, 3.6141e+00, -1.9452e+00, -1.9549e+00, -1.0021e+00],
[-1.4297e+00, 1.0584e+00, 3.6035e+00, 8.7692e-01, 1.6264e+00,
-1.2343e+00, 2.2948e+00, -3.2134e+00, -2.6682e+00, -5.5733e-02],
[ 1.0553e+00, 8.0772e-01, 7.6035e-01, 3.0935e-01, -2.4132e+00,
1.2373e-01, -1.4208e+00, -6.5298e-01, -6.9090e-01, 6.7234e-01],
[-5.0215e-01, 7.1934e-01, 2.9100e-01, -6.9082e-01, 1.5414e+00,
2.7353e+00, -1.8814e+00, 1.2451e+00, 3.2584e-01, -4.7738e-02],
[-4.6497e-01, 1.3912e+00, 1.5108e+00, -6.6143e-01, 3.2445e+00,
4.2020e+00, -1.8766e+00, -5.2713e-01, 8.4413e-01, -7.1390e-01],
[ 1.6174e+00, 4.3118e-01, -1.7414e+00, -3.2706e+00, 2.5401e+00,
2.7037e+00, -3.7122e-01, 2.3204e+00, 9.5514e-01, -4.5598e-02],
[-3.0912e-01, -1.5266e+00, -3.0772e+00, -1.5461e+00, 2.7041e+00,
4.1396e+00, 2.6272e+00, 3.1130e+00, 1.6704e+00, -3.2189e-01],
[-2.0918e-01, -1.1258e+00, -1.3340e+00, -1.2008e-01, 1.4991e+00,
1.6114e+00, 1.5005e+00, 1.5106e+00, 3.0398e-01, -1.9151e-01]])
from torch.autograd.gradcheck import gradcheck
# Compare the hand-written backward with finite differences. gradcheck is
# meant for double-precision tensors, and it only checks tensors passed as
# inputs. So the module is cast to double, and its filter and bias are
# checked alongside the input map instead of checking the input alone.
moduleConv = ScipyConv2d(3, 3).double()
input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
gradcheck_inputs = (*input, moduleConv.filter, moduleConv.bias)
test = gradcheck(ScipyConv2dFunction.apply, gradcheck_inputs, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)
Are the gradients correct: True