快速运行 MMagic
pip install controlnet_aux diffusers click av einops face-alignment facexlib lpips mediapipe resize_right transformers accelerate
import sys
from pathlib import Path

# The repository root is taken to be two levels above the current working
# directory; its `src` and `tests` folders are appended to the import path.
root_dir = Path(".").resolve().parents[1]
sys.path += [f"{root_dir}/{name}" for name in ("src", "tests")]

# Imported only after sys.path has been extended
# (presumably set_env lives in one of the folders added above; confirm).
from set_env import temp_dir

# Create the `output` folder under temp_dir unless it already exists.
(temp_dir / "output").mkdir(exist_ok=True)
从文本生成图像:
from mmagic.apis import MMagicInferencer

# Text-to-image: feed a prompt to the 'stable_diffusion' inferencer and have
# the result written to `result_out_dir`.
text_prompts = 'A panda is having dinner at KFC'
# NOTE(review): this path is relative to the current working directory, not to
# the temp_dir/"output" folder created during setup; confirm that is intended.
result_out_dir = 'output/sd_res.png'

sd_inferencer = MMagicInferencer(model_name='stable_diffusion')
sd_inferencer.infer(text=text_prompts, result_out_dir=result_out_dir)
MMagic 的超分辨率:
from mmagic.apis import MMagicInferencer

# Super-resolution with ESRGAN: the config file and checkpoint are passed
# explicitly, overriding the defaults associated with the 'esrgan' model name.
config = (
    f'{temp_dir}/mmagic/configs/esrgan/'
    'esrgan_x4c64b23g32_1xb16-400k_div2k.py'
)
checkpoint = (
    'https://download.openmmlab.com/mmediting/restorers/esrgan/'
    'esrgan_x4c64b23g32_1x16_400k_div2k_20200508-f8ccaf3b.pth'
)
editor = MMagicInferencer('esrgan', model_config=config, model_ckpt=checkpoint)

# Low-quality input image taken from the MMagic checkout under temp_dir.
img_path = temp_dir / 'mmagic' / 'tests' / 'data' / 'image' / 'lq' / 'baboon_x4.png'
# The result path is relative to the current working directory.
result_out_dir = 'images/output.png'
output = editor.infer(img=img_path, result_out_dir=result_out_dir)
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
from torch.distributed.optim import \
/media/pc/data/lxw/ai/torch-book/tests/.temp/tasks/mmagic/mmagic/apis/mmagic_inferencer.py:186: UserWarning: esrgan's default config is overridden by /media/pc/data/lxw/ai/torch-book/tests/.temp/tasks/mmagic/configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py
warnings.warn(
/media/pc/data/lxw/ai/torch-book/tests/.temp/tasks/mmagic/mmagic/apis/mmagic_inferencer.py:193: UserWarning: esrgan's default checkpoint is overridden by https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k_20200508-f8ccaf3b.pth
warnings.warn(
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k_20200508-f8ccaf3b.pth
11/21 13:52:38 - mmengine - WARNING - Failed to search registry with scope "mmagic" in the "function" registry tree. As a workaround, the current "function" registry in "mmengine" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmagic" is a correct scope, or whether the registry is initialized.
11/21 13:52:38 - mmengine - WARNING - Cannot find key 'gt_img' in data sample.
/media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/mmengine/visualization/visualizer.py:196: UserWarning: Failed to add <class 'mmengine.visualization.vis_backend.LocalVisBackend'>, please provide the `save_dir` argument.
warnings.warn(f'Failed to add {vis_backend.__class__}, '
import numpy as np
from PIL import Image

# Visual comparison: the input image (resized to the result's size) on the
# left, the super-resolved result on the right.
im1 = Image.open(img_path)
im2 = Image.open(result_out_dir)
# Concatenating along axis 1 places the two images next to each other
# horizontally; the trailing bare expression is what the notebook displays.
Image.fromarray(
    np.concatenate(
        [np.asarray(im1.resize(im2.size)), np.asarray(im2)],
        axis=1,
    )
)
# IPython shell escape: run MMagic's print_config.py tool on the ESRGAN config path held in `config`.
!python {temp_dir}/mmagic/tools/analysis_tools/print_config.py {config}
Config:
custom_hooks = [
dict(interval=1, type='BasicVisualizationHook'),
]
data_root = 'data'
dataset_type = 'BasicImageDataset'
default_hooks = dict(
checkpoint=dict(
by_epoch=False,
interval=5000,
max_keep_ckpts=10,
out_dir='./work_dirs/',
rule='greater',
save_best='PSNR',
save_optimizer=True,
type='CheckpointHook'),
logger=dict(interval=100, type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
sampler_seed=dict(type='DistSamplerSeedHook'),
timer=dict(type='IterTimerHook'))
default_scope = 'mmagic'
div2k_data_root = 'data/DIV2K'
div2k_dataloader = dict(
dataset=dict(
ann_file='meta_info_DIV2K100sub_GT.txt',
data_prefix=dict(
gt='DIV2K_train_HR_sub', img='DIV2K_train_LR_bicubic/X4_sub'),
data_root='data/DIV2K',
metainfo=dict(dataset_type='div2k', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
drop_last=False,
num_workers=4,
persistent_workers=False,
sampler=dict(shuffle=False, type='DefaultSampler'))
div2k_evaluator = dict(
metrics=[
dict(crop_border=4, prefix='DIV2K', type='PSNR'),
dict(crop_border=4, prefix='DIV2K', type='SSIM'),
],
type='Evaluator')
env_cfg = dict(
cudnn_benchmark=False,
dist_cfg=dict(backend='nccl'),
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=4))
experiment_name = 'esrgan_x4c64b23g32_1xb16-400k_div2k'
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=False, type='LogProcessor', window_size=100)
model = dict(
data_preprocessor=dict(
mean=[
0.0,
0.0,
0.0,
],
std=[
255.0,
255.0,
255.0,
],
type='DataPreprocessor'),
discriminator=dict(in_channels=3, mid_channels=64, type='ModifiedVGG'),
gan_loss=dict(
fake_label_val=0,
gan_type='vanilla',
loss_weight=0.005,
real_label_val=1.0,
type='GANLoss'),
generator=dict(
growth_channels=32,
in_channels=3,
init_cfg=dict(
checkpoint=
'https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_psnr_x4c64b23g32_1x16_1000k_div2k_20200420-bf5c993c.pth',
prefix='generator.',
type='Pretrained'),
mid_channels=64,
num_blocks=23,
out_channels=3,
type='RRDBNet',
upscale_factor=4),
perceptual_loss=dict(
layer_weights=dict({'34': 1.0}),
norm_img=False,
perceptual_weight=1.0,
style_weight=0,
type='PerceptualLoss',
vgg_type='vgg19'),
pixel_loss=dict(loss_weight=0.01, reduction='mean', type='L1Loss'),
test_cfg=dict(),
train_cfg=dict(),
type='ESRGAN')
model_wrapper_cfg = dict(type='MMSeparateDistributedDataParallel')
optim_wrapper = dict(
constructor='MultiOptimWrapperConstructor',
discriminator=dict(
optimizer=dict(betas=(
0.9,
0.999,
), lr=0.0001, type='Adam'),
type='OptimWrapper'),
generator=dict(
optimizer=dict(betas=(
0.9,
0.999,
), lr=0.0001, type='Adam'),
type='OptimWrapper'))
param_scheduler = dict(
by_epoch=False,
gamma=0.5,
milestones=[
50000,
100000,
200000,
300000,
],
type='MultiStepLR')
pretrain_generator_url = 'https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_psnr_x4c64b23g32_1x16_1000k_div2k_20200420-bf5c993c.pth'
resume = False
save_dir = './work_dirs/'
scale = 4
set14_data_root = 'data/Set14'
set14_dataloader = dict(
dataset=dict(
data_prefix=dict(gt='GTmod12', img='LRbicx4'),
data_root='data/Set14',
metainfo=dict(dataset_type='set14', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
drop_last=False,
num_workers=4,
persistent_workers=False,
sampler=dict(shuffle=False, type='DefaultSampler'))
set14_evaluator = dict(
metrics=[
dict(crop_border=4, prefix='Set14', type='PSNR'),
dict(crop_border=4, prefix='Set14', type='SSIM'),
],
type='Evaluator')
set5_data_root = 'data/Set5'
set5_dataloader = dict(
dataset=dict(
data_prefix=dict(gt='GTmod12', img='LRbicx4'),
data_root='data/Set5',
metainfo=dict(dataset_type='set5', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
drop_last=False,
num_workers=4,
persistent_workers=False,
sampler=dict(shuffle=False, type='DefaultSampler'))
set5_evaluator = dict(
metrics=[
dict(crop_border=4, prefix='Set5', type='PSNR'),
dict(crop_border=4, prefix='Set5', type='SSIM'),
],
type='Evaluator')
test_cfg = dict(type='MultiTestLoop')
test_dataloader = [
dict(
dataset=dict(
data_prefix=dict(gt='GTmod12', img='LRbicx4'),
data_root='data/Set5',
metainfo=dict(dataset_type='set5', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
drop_last=False,
num_workers=4,
persistent_workers=False,
sampler=dict(shuffle=False, type='DefaultSampler')),
dict(
dataset=dict(
data_prefix=dict(gt='GTmod12', img='LRbicx4'),
data_root='data/Set14',
metainfo=dict(dataset_type='set14', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
drop_last=False,
num_workers=4,
persistent_workers=False,
sampler=dict(shuffle=False, type='DefaultSampler')),
dict(
dataset=dict(
ann_file='meta_info_DIV2K100sub_GT.txt',
data_prefix=dict(
gt='DIV2K_train_HR_sub', img='DIV2K_train_LR_bicubic/X4_sub'),
data_root='data/DIV2K',
metainfo=dict(dataset_type='div2k', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
drop_last=False,
num_workers=4,
persistent_workers=False,
sampler=dict(shuffle=False, type='DefaultSampler')),
]
test_evaluator = [
dict(
metrics=[
dict(crop_border=4, prefix='Set5', type='PSNR'),
dict(crop_border=4, prefix='Set5', type='SSIM'),
],
type='Evaluator'),
dict(
metrics=[
dict(crop_border=4, prefix='Set14', type='PSNR'),
dict(crop_border=4, prefix='Set14', type='SSIM'),
],
type='Evaluator'),
dict(
metrics=[
dict(crop_border=4, prefix='DIV2K', type='PSNR'),
dict(crop_border=4, prefix='DIV2K', type='SSIM'),
],
type='Evaluator'),
]
test_pipeline = [
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
imdecode_backend='cv2',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
]
train_cfg = dict(
max_iters=400000, type='IterBasedTrainLoop', val_interval=5000)
train_dataloader = dict(
batch_size=16,
dataset=dict(
ann_file='meta_info_DIV2K800sub_GT.txt',
data_prefix=dict(
gt='DIV2K_train_HR_sub', img='DIV2K_train_LR_bicubic/X4_sub'),
data_root='data/DIV2K',
filename_tmpl=dict(gt='{}', img='{}'),
metainfo=dict(dataset_type='div2k', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
key='gt',
type='LoadImageFromFile'),
dict(dictionary=dict(scale=4), type='SetValues'),
dict(gt_patch_size=128, type='PairedRandomCrop'),
dict(
direction='horizontal',
flip_ratio=0.5,
keys=[
'img',
'gt',
],
type='Flip'),
dict(
direction='vertical',
flip_ratio=0.5,
keys=[
'img',
'gt',
],
type='Flip'),
dict(
keys=[
'img',
'gt',
],
transpose_ratio=0.5,
type='RandomTransposeHW'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
num_workers=8,
persistent_workers=False,
sampler=dict(shuffle=True, type='InfiniteSampler'))
train_pipeline = [
dict(
channel_order='rgb',
color_type='color',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
key='gt',
type='LoadImageFromFile'),
dict(dictionary=dict(scale=4), type='SetValues'),
dict(gt_patch_size=128, type='PairedRandomCrop'),
dict(
direction='horizontal',
flip_ratio=0.5,
keys=[
'img',
'gt',
],
type='Flip'),
dict(
direction='vertical',
flip_ratio=0.5,
keys=[
'img',
'gt',
],
type='Flip'),
dict(keys=[
'img',
'gt',
], transpose_ratio=0.5, type='RandomTransposeHW'),
dict(type='PackInputs'),
]
val_cfg = dict(type='MultiValLoop')
val_dataloader = dict(
dataset=dict(
data_prefix=dict(gt='GTmod12', img='LRbicx4'),
data_root='data/Set14',
metainfo=dict(dataset_type='set14', task_name='sisr'),
pipeline=[
dict(
channel_order='rgb',
color_type='color',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
],
type='BasicImageDataset'),
drop_last=False,
num_workers=4,
persistent_workers=False,
sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(
metrics=[
dict(type='MAE'),
dict(crop_border=4, type='PSNR'),
dict(crop_border=4, type='SSIM'),
],
type='Evaluator')
val_pipeline = [
dict(
channel_order='rgb',
color_type='color',
key='img',
type='LoadImageFromFile'),
dict(
channel_order='rgb',
color_type='color',
key='gt',
type='LoadImageFromFile'),
dict(type='PackInputs'),
]
vis_backends = [
dict(type='LocalVisBackend'),
]
visualizer = dict(
bgr2rgb=True,
fn_key='gt_path',
img_keys=[
'gt_img',
'input',
'pred_img',
],
type='ConcatImageVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
])
work_dir = './work_dirs/esrgan_x4c64b23g32_1xb16-400k_div2k'