设计 SRCNN 模型

设计 SRCNN 模型

参考:SRCNN 的例子

from env import temp_dir, root_dir # 配置一些基础环境

定义 SRCNN 网络

SRCNN 是第一个用于单幅图像超分辨率 [DLHT15] 的深度学习方法。为了实现 SRCNN 的网络架构,需要创建文件 mmagic/models/editors/srgan/sr_resnet.py 并实现 class MSRResNet。

在这一步中,通过继承 mmengine.model.BaseModule 来实现 class MSRResNet,并在 __init__ 函数中定义网络架构。特别地,需要使用 @MODELS.register_module() 将 class MSRResNet 的实现添加到 MMagic 的注册器中。

import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import default_init_weights, make_layer

@MODELS.register_module()
class MSRResNet(BaseModule):
    """Modified SRResNet.

    A compacted version modified from SRResNet in "Photo-Realistic Single
    Image Super-Resolution Using a Generative Adversarial Network".

    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factors.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Support x2, x3 and x4.
            Default: 4.

    Raises:
        ValueError: If ``upscale_factor`` is not 2, 3 or 4.
    """
    _supported_upscale_factors = [2, 3, 4]

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):
        super().__init__()

        # Keep the constructor arguments available as plain attributes.
        self.in_channels, self.out_channels = in_channels, out_channels
        self.mid_channels, self.num_blocks = mid_channels, num_blocks
        self.upscale_factor = upscale_factor

        # Every plain convolution in this network shares the same geometry:
        # 3x3 kernel, stride 1, padding 1, with a bias term.
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, bias=True)

        def pixel_shuffle(scale):
            # One upsampling stage that keeps the channel count unchanged.
            return PixelShufflePack(
                mid_channels, mid_channels, scale, upsample_kernel=3)

        # NOTE: submodules are created in a fixed order on purpose; the order
        # determines the parameter / state_dict ordering of the module.
        self.conv_first = nn.Conv2d(in_channels, mid_channels, **conv3x3)
        self.trunk_net = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=mid_channels)

        # Upsampling: x4 is built from two cascaded x2 stages, while x2 and
        # x3 use a single stage with the requested factor.
        if upscale_factor == 4:
            self.upsample1 = pixel_shuffle(2)
            self.upsample2 = pixel_shuffle(2)
        elif upscale_factor in (2, 3):
            self.upsample1 = pixel_shuffle(upscale_factor)
        else:
            raise ValueError(
                f'Unsupported scale factor {upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')

        self.conv_hr = nn.Conv2d(mid_channels, mid_channels, **conv3x3)
        self.conv_last = nn.Conv2d(mid_channels, out_channels, **conv3x3)

        # Bilinear upsampler with the same overall scale factor.
        # NOTE(review): presumably applied to the input image in forward();
        # forward() is not part of this snippet, so confirm there.
        self.img_upsampler = nn.Upsample(
            scale_factor=upscale_factor,
            mode='bilinear',
            align_corners=False)

        # Shared activation function.
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self.init_weights()

    def init_weights(self):
        """Initialize the weights of the plain convolution layers.

        Calls ``default_init_weights`` with the value ``0.1`` on
        ``conv_first``, ``conv_hr`` and ``conv_last``. No other submodule is
        touched by this method, and it takes no arguments.
        """
        # NOTE(review): 0.1 is presumably a scale applied to the initial
        # weights; confirm against mmagic.models.utils.default_init_weights.
        for conv in (self.conv_first, self.conv_hr, self.conv_last):
            default_init_weights(conv, 0.1)