设计 SRCNN 模型
参考：SRCNN 的例子
from env import temp_dir, root_dir # 配置一些基础环境
定义 SRCNN 网络
SRCNN 是第一个用于单幅图像超分辨率 [DLHT15] 的深度学习方法。为了实现 SRCNN 的网络架构，需要创建文件 `mmagic/models/editors/srgan/sr_resnet.py` 并实现 `class MSRResNet`。
在这一步中，通过继承 `mmengine.model.BaseModule` 来实现 `class MSRResNet`，并在 `__init__` 函数中定义网络架构。特别地，需要使用 `@MODELS.register_module()` 将 `class MSRResNet` 的实现添加到 MMagic 的注册器中。
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from mmagic.models.archs import PixelShufflePack, ResidualBlockNoBN
from mmagic.models.utils import default_init_weights, make_layer
@MODELS.register_module()
class MSRResNet(BaseModule):
    """Modified SRResNet.

    A compacted version modified from the SRResNet in "Photo-Realistic Single
    Image Super-Resolution Using a Generative Adversarial Network". It uses
    residual blocks without BN, similar to EDSR. Currently x2, x3 and x4
    upsampling scale factors are supported.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Support x2, x3 and x4.
            Default: 4.

    Raises:
        ValueError: If ``upscale_factor`` is not one of
            ``_supported_upscale_factors``.
    """

    # Single source of truth for the accepted ``upscale_factor`` values.
    _supported_upscale_factors = [2, 3, 4]

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):
        super().__init__()

        # Reject unsupported factors before any layer is built, so a bad
        # config fails fast instead of after allocating the whole trunk.
        if upscale_factor not in self._supported_upscale_factors:
            raise ValueError(
                f'Unsupported scale factor {upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        # Keep the sub-module creation order stable: it fixes the module
        # registration order and which random numbers each layer's default
        # initialization consumes under a fixed seed.
        self.conv_first = nn.Conv2d(
            in_channels, mid_channels, 3, 1, 1, bias=True)
        self.trunk_net = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=mid_channels)

        # upsampling: x2 / x3 use a single PixelShufflePack stage, while x4
        # is built from two consecutive x2 stages.
        if self.upscale_factor in [2, 3]:
            self.upsample1 = PixelShufflePack(
                mid_channels,
                mid_channels,
                self.upscale_factor,
                upsample_kernel=3)
        else:  # upscale_factor == 4, guaranteed by the check above
            self.upsample1 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
            self.upsample2 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)

        self.conv_hr = nn.Conv2d(
            mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=True)

        # Bilinear resize of the input by the same factor as the network
        # output (presumably used as a global skip in ``forward``, which is
        # not shown in this snippet — TODO confirm).
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bilinear',
            align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self.init_weights()

    def init_weights(self):
        """Re-initialize the first, HR and last convolutions.

        Calls ``default_init_weights`` on ``conv_first``, ``conv_hr`` and
        ``conv_last``, passing 0.1 as its second argument (presumably the
        weight scale; see ``default_init_weights``). The residual trunk and
        the upsampling blocks are not touched here and keep the
        initialization they received when they were constructed.
        """
        for m in [self.conv_first, self.conv_hr, self.conv_last]:
            default_init_weights(m, 0.1)