SRCNN 实现

SRCNN 实现#

参考:mmagic/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py

%%file utils/srcnn.py
from torch import nn

class SRCNN(nn.Module):
    """SRCNN 网络结构用于图像超分辨率。

    SRCNN包含三个卷积层。对于每一层,可以定义输入通道数、输出通道数和卷积核大小。
    输入图像首先会使用双三次插值法进行上采样,然后在高分辨率空间尺寸中进行超分辨处理。

    论文:Learning a Deep Convolutional Network for Image Super-Resolution.

    Args:
        channels (tuple[int]): 元组,包含了每一层的通道数,包括输入和输出的通道数。默认值:(3, 64, 32, 3)。
        kernel_sizes (tuple[int]): 元组,包含了每个卷积层的卷积核大小。默认值:(9, 1, 5)。
        upscale_factor (int): 上采样因子。默认值:4。
    """

    def __init__(self,
                 upscale_factor=4,
                 channels=(3, 64, 32, 3),
                 kernel_sizes=(9, 1, 5),
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert len(channels) == 4, (f'通道元组的长度应为4,但实际得到的长度是 {len(channels)}')
        assert len(kernel_sizes) == 3, f"kernel 元组的长度应为3,但得到的是{len(kernel_sizes)}"
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False)

        self.conv1 = nn.Conv2d(
            channels[0],
            channels[1],
            kernel_size=kernel_sizes[0],
            padding=kernel_sizes[0] // 2)
        self.conv2 = nn.Conv2d(
            channels[1],
            channels[2],
            kernel_size=kernel_sizes[1],
            padding=kernel_sizes[1] // 2)
        self.conv3 = nn.Conv2d(
            channels[2],
            channels[3],
            kernel_size=kernel_sizes[2],
            padding=kernel_sizes[2] // 2)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.img_upsampler(x)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out
Overwriting utils/srcnn.py
from utils.srcnn import SRCNN
net = SRCNN(upscale_factor=4,)
net
SRCNN(
  (img_upsampler): Upsample(scale_factor=4.0, mode='bicubic')
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu): ReLU()
)