快速运行 MMagic

快速运行 MMagic

pip install controlnet_aux diffusers click av einops face-alignment facexlib lpips mediapipe resize_right transformers accelerate
# Notebook setup: make the project's "src" and "tests" directories importable
# and prepare an output directory for the examples below.
import sys
from pathlib import Path
# parents[1] is the grandparent of the resolved current working directory,
# i.e. the project root is taken to be two levels above where this runs.
root_dir = Path(".").resolve().parents[1]
sys.path.extend([
    f"{root_dir}/src",
    f"{root_dir}/tests"
])
# This import must stay after the sys.path extension above; set_env is
# presumably located in one of the two directories just added -- TODO confirm.
from set_env import temp_dir
# Create an "output" directory under temp_dir; no error if it already exists.
(temp_dir/"output").mkdir(exist_ok=True)

从文本生成图像:

# Text-to-image example: build an inferencer by model name and run it on a prompt.
from mmagic.apis import MMagicInferencer
sd_inferencer = MMagicInferencer(model_name='stable_diffusion')
text_prompts = 'A panda is having dinner at KFC'
# NOTE(review): this is a relative path, so it presumably resolves against the
# current working directory rather than the temp_dir/"output" directory created
# during setup -- confirm this is intended.
result_out_dir = 'output/sd_res.png'
# The prompt is passed as `text`; the result location is passed as `result_out_dir`.
sd_inferencer.infer(text=text_prompts, result_out_dir=result_out_dir)

MMagic 的超分辨率:

# Super-resolution example: run ESRGAN with an explicit config file and checkpoint.
from mmagic.apis import MMagicInferencer
# Config file taken from the MMagic checkout located under temp_dir.
config = f'{temp_dir}/mmagic/configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py'
# Checkpoint is given as a URL, so it is fetched over HTTP when the model is built.
checkpoint = 'https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k_20200508-f8ccaf3b.pth'
# Low-quality input image shipped with the MMagic test data.
img_path = temp_dir/'mmagic/tests/data/image/lq/baboon_x4.png'
# Passing model_config / model_ckpt overrides the defaults registered for 'esrgan'
# (the inferencer emits a UserWarning for each override).
editor = MMagicInferencer('esrgan', model_config=config, model_ckpt=checkpoint)
# NOTE(review): relative path, presumably resolved against the current working
# directory -- confirm. It is reused below to load the result image.
result_out_dir = 'images/output.png'
output = editor.infer(img=img_path, result_out_dir=result_out_dir)
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import \
/media/pc/data/lxw/ai/torch-book/tests/.temp/tasks/mmagic/mmagic/apis/mmagic_inferencer.py:186: UserWarning: esrgan's default config is overridden by /media/pc/data/lxw/ai/torch-book/tests/.temp/tasks/mmagic/configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py
  warnings.warn(
/media/pc/data/lxw/ai/torch-book/tests/.temp/tasks/mmagic/mmagic/apis/mmagic_inferencer.py:193: UserWarning: esrgan's default checkpoint is overridden by https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k_20200508-f8ccaf3b.pth
  warnings.warn(
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k_20200508-f8ccaf3b.pth
11/21 13:52:38 - mmengine - WARNING - Failed to search registry with scope "mmagic" in the "function" registry tree. As a workaround, the current "function" registry in "mmengine" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmagic" is a correct scope, or whether the registry is initialized.
11/21 13:52:38 - mmengine - WARNING - Cannot find key 'gt_img' in data sample.
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/mmengine/visualization/visualizer.py:196: UserWarning: Failed to add <class 'mmengine.visualization.vis_backend.LocalVisBackend'>, please provide the `save_dir` argument.
  warnings.warn(f'Failed to add {vis_backend.__class__}, '
# Show the low-resolution input next to the super-resolved output.
from PIL import Image
import numpy as np
# Open both files in a `with` block so the file handles are closed as soon as
# the pixel data has been copied into the array (the original left them open).
with Image.open(img_path) as im1, Image.open(result_out_dir) as im2:
    # Resize the input to the output's size so both arrays have the same
    # height, then join them along the width axis (input left, output right).
    side_by_side = np.concatenate([im1.resize(im2.size), im2], axis=1)
# Keep this as the cell's final expression so the notebook renders the image.
Image.fromarray(side_by_side)
../../../_images/266b843517e4cb894ae466a7db7715acc7762f46d84109f68b52d55b5c4303e4.png
!python {temp_dir}/mmagic/tools/analysis_tools/print_config.py {config}
Config:
custom_hooks = [
    dict(interval=1, type='BasicVisualizationHook'),
]
data_root = 'data'
dataset_type = 'BasicImageDataset'
default_hooks = dict(
    checkpoint=dict(
        by_epoch=False,
        interval=5000,
        max_keep_ckpts=10,
        out_dir='./work_dirs/',
        rule='greater',
        save_best='PSNR',
        save_optimizer=True,
        type='CheckpointHook'),
    logger=dict(interval=100, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'))
default_scope = 'mmagic'
div2k_data_root = 'data/DIV2K'
div2k_dataloader = dict(
    dataset=dict(
        ann_file='meta_info_DIV2K100sub_GT.txt',
        data_prefix=dict(
            gt='DIV2K_train_HR_sub', img='DIV2K_train_LR_bicubic/X4_sub'),
        data_root='data/DIV2K',
        metainfo=dict(dataset_type='div2k', task_name='sisr'),
        pipeline=[
            dict(
                channel_order='rgb',
                color_type='color',
                imdecode_backend='cv2',
                key='img',
                type='LoadImageFromFile'),
            dict(
                channel_order='rgb',
                color_type='color',
                imdecode_backend='cv2',
                key='gt',
                type='LoadImageFromFile'),
            dict(type='PackInputs'),
        ],
        type='BasicImageDataset'),
    drop_last=False,
    num_workers=4,
    persistent_workers=False,
    sampler=dict(shuffle=False, type='DefaultSampler'))
div2k_evaluator = dict(
    metrics=[
        dict(crop_border=4, prefix='DIV2K', type='PSNR'),
        dict(crop_border=4, prefix='DIV2K', type='SSIM'),
    ],
    type='Evaluator')
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=4))
experiment_name = 'esrgan_x4c64b23g32_1xb16-400k_div2k'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False, type='LogProcessor', window_size=100)
model = dict(
    data_preprocessor=dict(
        mean=[
            0.0,
            0.0,
            0.0,
        ],
        std=[
            255.0,
            255.0,
            255.0,
        ],
        type='DataPreprocessor'),
    discriminator=dict(in_channels=3, mid_channels=64, type='ModifiedVGG'),
    gan_loss=dict(
        fake_label_val=0,
        gan_type='vanilla',
        loss_weight=0.005,
        real_label_val=1.0,
        type='GANLoss'),
    generator=dict(
        growth_channels=32,
        in_channels=3,
        init_cfg=dict(
            checkpoint=
            'https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_psnr_x4c64b23g32_1x16_1000k_div2k_20200420-bf5c993c.pth',
            prefix='generator.',
            type='Pretrained'),
        mid_channels=64,
        num_blocks=23,
        out_channels=3,
        type='RRDBNet',
        upscale_factor=4),
    perceptual_loss=dict(
        layer_weights=dict({'34': 1.0}),
        norm_img=False,
        perceptual_weight=1.0,
        style_weight=0,
        type='PerceptualLoss',
        vgg_type='vgg19'),
    pixel_loss=dict(loss_weight=0.01, reduction='mean', type='L1Loss'),
    test_cfg=dict(),
    train_cfg=dict(),
    type='ESRGAN')
model_wrapper_cfg = dict(type='MMSeparateDistributedDataParallel')
optim_wrapper = dict(
    constructor='MultiOptimWrapperConstructor',
    discriminator=dict(
        optimizer=dict(betas=(
            0.9,
            0.999,
        ), lr=0.0001, type='Adam'),
        type='OptimWrapper'),
    generator=dict(
        optimizer=dict(betas=(
            0.9,
            0.999,
        ), lr=0.0001, type='Adam'),
        type='OptimWrapper'))
param_scheduler = dict(
    by_epoch=False,
    gamma=0.5,
    milestones=[
        50000,
        100000,
        200000,
        300000,
    ],
    type='MultiStepLR')
pretrain_generator_url = 'https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_psnr_x4c64b23g32_1x16_1000k_div2k_20200420-bf5c993c.pth'
resume = False
save_dir = './work_dirs/'
scale = 4
set14_data_root = 'data/Set14'
set14_dataloader = dict(
    dataset=dict(
        data_prefix=dict(gt='GTmod12', img='LRbicx4'),
        data_root='data/Set14',
        metainfo=dict(dataset_type='set14', task_name='sisr'),
        pipeline=[
            dict(
                channel_order='rgb',
                color_type='color',
                imdecode_backend='cv2',
                key='img',
                type='LoadImageFromFile'),
            dict(
                channel_order='rgb',
                color_type='color',
                imdecode_backend='cv2',
                key='gt',
                type='LoadImageFromFile'),
            dict(type='PackInputs'),
        ],
        type='BasicImageDataset'),
    drop_last=False,
    num_workers=4,
    persistent_workers=False,
    sampler=dict(shuffle=False, type='DefaultSampler'))
set14_evaluator = dict(
    metrics=[
        dict(crop_border=4, prefix='Set14', type='PSNR'),
        dict(crop_border=4, prefix='Set14', type='SSIM'),
    ],
    type='Evaluator')
set5_data_root = 'data/Set5'
set5_dataloader = dict(
    dataset=dict(
        data_prefix=dict(gt='GTmod12', img='LRbicx4'),
        data_root='data/Set5',
        metainfo=dict(dataset_type='set5', task_name='sisr'),
        pipeline=[
            dict(
                channel_order='rgb',
                color_type='color',
                imdecode_backend='cv2',
                key='img',
                type='LoadImageFromFile'),
            dict(
                channel_order='rgb',
                color_type='color',
                imdecode_backend='cv2',
                key='gt',
                type='LoadImageFromFile'),
            dict(type='PackInputs'),
        ],
        type='BasicImageDataset'),
    drop_last=False,
    num_workers=4,
    persistent_workers=False,
    sampler=dict(shuffle=False, type='DefaultSampler'))
set5_evaluator = dict(
    metrics=[
        dict(crop_border=4, prefix='Set5', type='PSNR'),
        dict(crop_border=4, prefix='Set5', type='SSIM'),
    ],
    type='Evaluator')
test_cfg = dict(type='MultiTestLoop')
test_dataloader = [
    dict(
        dataset=dict(
            data_prefix=dict(gt='GTmod12', img='LRbicx4'),
            data_root='data/Set5',
            metainfo=dict(dataset_type='set5', task_name='sisr'),
            pipeline=[
                dict(
                    channel_order='rgb',
                    color_type='color',
                    imdecode_backend='cv2',
                    key='img',
                    type='LoadImageFromFile'),
                dict(
                    channel_order='rgb',
                    color_type='color',
                    imdecode_backend='cv2',
                    key='gt',
                    type='LoadImageFromFile'),
                dict(type='PackInputs'),
            ],
            type='BasicImageDataset'),
        drop_last=False,
        num_workers=4,
        persistent_workers=False,
        sampler=dict(shuffle=False, type='DefaultSampler')),
    dict(
        dataset=dict(
            data_prefix=dict(gt='GTmod12', img='LRbicx4'),
            data_root='data/Set14',
            metainfo=dict(dataset_type='set14', task_name='sisr'),
            pipeline=[
                dict(
                    channel_order='rgb',
                    color_type='color',
                    imdecode_backend='cv2',
                    key='img',
                    type='LoadImageFromFile'),
                dict(
                    channel_order='rgb',
                    color_type='color',
                    imdecode_backend='cv2',
                    key='gt',
                    type='LoadImageFromFile'),
                dict(type='PackInputs'),
            ],
            type='BasicImageDataset'),
        drop_last=False,
        num_workers=4,
        persistent_workers=False,
        sampler=dict(shuffle=False, type='DefaultSampler')),
    dict(
        dataset=dict(
            ann_file='meta_info_DIV2K100sub_GT.txt',
            data_prefix=dict(
                gt='DIV2K_train_HR_sub', img='DIV2K_train_LR_bicubic/X4_sub'),
            data_root='data/DIV2K',
            metainfo=dict(dataset_type='div2k', task_name='sisr'),
            pipeline=[
                dict(
                    channel_order='rgb',
                    color_type='color',
                    imdecode_backend='cv2',
                    key='img',
                    type='LoadImageFromFile'),
                dict(
                    channel_order='rgb',
                    color_type='color',
                    imdecode_backend='cv2',
                    key='gt',
                    type='LoadImageFromFile'),
                dict(type='PackInputs'),
            ],
            type='BasicImageDataset'),
        drop_last=False,
        num_workers=4,
        persistent_workers=False,
        sampler=dict(shuffle=False, type='DefaultSampler')),
]
test_evaluator = [
    dict(
        metrics=[
            dict(crop_border=4, prefix='Set5', type='PSNR'),
            dict(crop_border=4, prefix='Set5', type='SSIM'),
        ],
        type='Evaluator'),
    dict(
        metrics=[
            dict(crop_border=4, prefix='Set14', type='PSNR'),
            dict(crop_border=4, prefix='Set14', type='SSIM'),
        ],
        type='Evaluator'),
    dict(
        metrics=[
            dict(crop_border=4, prefix='DIV2K', type='PSNR'),
            dict(crop_border=4, prefix='DIV2K', type='SSIM'),
        ],
        type='Evaluator'),
]
test_pipeline = [
    dict(
        channel_order='rgb',
        color_type='color',
        imdecode_backend='cv2',
        key='img',
        type='LoadImageFromFile'),
    dict(
        channel_order='rgb',
        color_type='color',
        imdecode_backend='cv2',
        key='gt',
        type='LoadImageFromFile'),
    dict(type='PackInputs'),
]
train_cfg = dict(
    max_iters=400000, type='IterBasedTrainLoop', val_interval=5000)
train_dataloader = dict(
    batch_size=16,
    dataset=dict(
        ann_file='meta_info_DIV2K800sub_GT.txt',
        data_prefix=dict(
            gt='DIV2K_train_HR_sub', img='DIV2K_train_LR_bicubic/X4_sub'),
        data_root='data/DIV2K',
        filename_tmpl=dict(gt='{}', img='{}'),
        metainfo=dict(dataset_type='div2k', task_name='sisr'),
        pipeline=[
            dict(
                channel_order='rgb',
                color_type='color',
                key='img',
                type='LoadImageFromFile'),
            dict(
                channel_order='rgb',
                color_type='color',
                key='gt',
                type='LoadImageFromFile'),
            dict(dictionary=dict(scale=4), type='SetValues'),
            dict(gt_patch_size=128, type='PairedRandomCrop'),
            dict(
                direction='horizontal',
                flip_ratio=0.5,
                keys=[
                    'img',
                    'gt',
                ],
                type='Flip'),
            dict(
                direction='vertical',
                flip_ratio=0.5,
                keys=[
                    'img',
                    'gt',
                ],
                type='Flip'),
            dict(
                keys=[
                    'img',
                    'gt',
                ],
                transpose_ratio=0.5,
                type='RandomTransposeHW'),
            dict(type='PackInputs'),
        ],
        type='BasicImageDataset'),
    num_workers=8,
    persistent_workers=False,
    sampler=dict(shuffle=True, type='InfiniteSampler'))
train_pipeline = [
    dict(
        channel_order='rgb',
        color_type='color',
        key='img',
        type='LoadImageFromFile'),
    dict(
        channel_order='rgb',
        color_type='color',
        key='gt',
        type='LoadImageFromFile'),
    dict(dictionary=dict(scale=4), type='SetValues'),
    dict(gt_patch_size=128, type='PairedRandomCrop'),
    dict(
        direction='horizontal',
        flip_ratio=0.5,
        keys=[
            'img',
            'gt',
        ],
        type='Flip'),
    dict(
        direction='vertical',
        flip_ratio=0.5,
        keys=[
            'img',
            'gt',
        ],
        type='Flip'),
    dict(keys=[
        'img',
        'gt',
    ], transpose_ratio=0.5, type='RandomTransposeHW'),
    dict(type='PackInputs'),
]
val_cfg = dict(type='MultiValLoop')
val_dataloader = dict(
    dataset=dict(
        data_prefix=dict(gt='GTmod12', img='LRbicx4'),
        data_root='data/Set14',
        metainfo=dict(dataset_type='set14', task_name='sisr'),
        pipeline=[
            dict(
                channel_order='rgb',
                color_type='color',
                key='img',
                type='LoadImageFromFile'),
            dict(
                channel_order='rgb',
                color_type='color',
                key='gt',
                type='LoadImageFromFile'),
            dict(type='PackInputs'),
        ],
        type='BasicImageDataset'),
    drop_last=False,
    num_workers=4,
    persistent_workers=False,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(
    metrics=[
        dict(type='MAE'),
        dict(crop_border=4, type='PSNR'),
        dict(crop_border=4, type='SSIM'),
    ],
    type='Evaluator')
val_pipeline = [
    dict(
        channel_order='rgb',
        color_type='color',
        key='img',
        type='LoadImageFromFile'),
    dict(
        channel_order='rgb',
        color_type='color',
        key='gt',
        type='LoadImageFromFile'),
    dict(type='PackInputs'),
]
vis_backends = [
    dict(type='LocalVisBackend'),
]
visualizer = dict(
    bgr2rgb=True,
    fn_key='gt_path',
    img_keys=[
        'gt_img',
        'input',
        'pred_img',
    ],
    type='ConcatImageVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
    ])
work_dir = './work_dirs/esrgan_x4c64b23g32_1xb16-400k_div2k'