zl程序教程

您现在的位置是:首页 >  Java

当前栏目

想抓住 AI 绘画风口?绝不能错过的重大算法库升级!

2023-02-18 16:49:15 时间

近年来,AI 绘画技术呈现井喷式的爆发,给学术界和工业界都带来了新的无限可能。我们知道,在 OpenMMLab 大家族中,和 AI 绘画最为相关的两大算法库是 MMEditing 和 MMGeneration。

其中,MMEditing 聚焦于图像和视频的质量恢复和增强,提供了世界领先的知名算法;而 MMGeneration 则致力于提供一系列充满想象力的生成模型(如 16 种不同的生成模型)。

利用 MMEditing 复刻经典爱情故事

利用 MMGeneration 生成以假乱真的图片

今天,你是否还在纠结应该使用 MMEditing 和 MMGeneration 的哪个算法库呢?不用纠结!在全新的 MMEditing 1.0 升级中,两个算法库已经合二为一!MMEditing 1.0 现已支持 MMGeneration 中的全量算法和模型!

我们希望这样的一个升级,能够让生成模型和底层视觉任务碰撞出灵感的火花,让 AI 绘画的研发更加容易。大家所关心的 diffusion model 系列也已经在路上了哟~

接下来我们将介绍新版本的新特性,以及一个快速上手 MMEditing 1.x 的实战教程。

全新版本的新特性

更加全面

拥抱升级,全新的 MMEditing 1.0 变得更加全面了。

MMEditing 1.0 拥有更全面的算法和工具

首先,与 0.x 版本相比,新版本中支持的算法数量、与训练模型文件的数量都几乎翻倍,相信这么多种算法和任务,总有一款适合你!

其次,在全新的 MMEditing 1.0 中涵盖了更多类型的损失函数和评价指标,希望这些能够推动底层视觉任务和生成模型基准的建设。小伙伴们比较 SOTA 算法的时候再也不用跑五六个算法库啦,MMEditing 1.0 一个就能搞定!

最后,MMEditing 1.0 还支持了更丰富的可视化后台,包括经典的 tensorboard 和 wandb。

更加简洁

虽然我们的功能更加强大了,但是我们的架构设计却变得更加简洁了。我们知道,一个深度学习框架中有几个重要的组成部分:数据、模型、评测器、优化器。其中,我们需要用评测器来对训练得到的模型进行性能评估,帮助我们在训练过程中保存性能最好的模型;同时,我们使用优化器来对模型的参数进行优化更新;当然,我们还需要一个参数调度器,来对优化器中的学习率、动量等超参数进行动态调整。

在全新的训练引擎 MMEngine 中,不同的组件能够被统一在一个执行器中有条不紊地进行调度的灵魂,就是钩子(Hook)。钩子的使用使得 MMEditing 能够更自由灵活地在训练、测试、推理的流程中插入、拓展、执行自定义的逻辑。上述所有组件都可以调用日志管理模块和可视化器进行结构化和非结构化的存储与展示。

MMEditing 使用配置文件来对训练、测试等过程进行设置,使用注册器来管理算法库中具有相同功能的模块。同时,MMEditing 1.0 使用 MMEngine 中提供的文件读写和分布式通信组件。下图展示了 MMEditing 1.0 整体框架设计:

目前 MMEditing 1.0 已经提供了全面的教程来帮助大家使用我们的新版本啦。

具体可以查看:https://mmediting.readthedocs.io/en/1.x/

MMEditing 1.0 提供了三大类教程

经典 MMEditing 1.0 实践:

从零支持 BasicVSR

在本小节中,主要通过从零支持 BasicVSR 这个经典视频超分模型,来带大家了解 MMEditing 1.0 源码目录结构以及配置文件的一些变化。

步骤 0. 灵活的配置文件设计

与 MMEditing 0.x (算法库的现主分支)相比,MMEditing 1.0 同样支持在配置文件中来指定模型训练、测试等过程中相关组件的参数,同时还支持了包括 python、json 以及 yaml 格式的配置文件。此外,最大的亮点则是支持了项目内、跨项目配置文件的继承。

举个例子,在 basicvsr_2xb4_reds4.py 的配置文件中,我们可以通过在 _base_ 字段里继承 ../_base_/default_runtime.py 以及 ../_base_/datasets/basicvsr_test_config.py 文件来重复利用这两个配置文件中的相关设置。

有了这样的新特性,我们在多个不同数据集上进行同一个模型的验证实验时,就不需要在配置文件中重复写模型相关的配置啦,只需要在 _base_ 字段里继承即可,如下面代码所示:

# basicvsr_2xb4_reds4.py 

 _base_ = [
    '../_base_/default_runtime.py',
    '../_base_/datasets/basicvsr_test_config.py'
]
# default_runtime.py

log_level = 'INFO'
log_processor = dict(type='LogProcessor', window_size=100, by_epoch=False)
load_from = None
resume = False

步骤 1. 准备数据集

实现从零支持 BasicVSR 的第一步就是准备数据集。在 MMEditing 1.0 中,数据集的源码实现主要在 mmedit/datasets 的目录结构下:

mmedit
├── datasets
│   ├── basic_image_dataset.py
│   ├── ...
│   ├── basic_frames_dataset.py
│  

在这里,我们对数据集基类进行了重新的设计,支持了包括图像、视频、带标签图像、成对以及非成对图像对等 35+ 种研究数据集。只要数据集的格式符合 OpenMMLab 2.0 数据集的规范,即可在配置文件中指定对应的数据集路径进行数据集的加载。如下所示,是 basicvsr 的一个训练数据加载器的配置:

# basicvsr_2xb4_reds4.py 
  train_dataloader = dict(
    num_workers=6,
    batch_size=4,
    persistent_workers=False,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type='BasicFramesDataset',
        metainfo=dict(dataset_type='reds_reds4', task_name='vsr'),
        data_root=data_root,
        data_prefix=dict(img='train_sharp_bicubic/X4', gt='train_sharp'),
        ann_file='meta_info_reds4_train.txt',
        depth=1,
        num_input_frames=15,
        pipeline=train_pipeline))

在下面的这个配置中,我们可以看到 pipeline=train_pipeline,而 train_pipeline 的值实际上是一个字典的序列,指定了数据加载和预处理的流水线操作。

# basicvsr_2xb4_reds4.py 
 train_pipeline = [
    dict(type='GenerateSegmentIndices', interval_list=[1]),
    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
    dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
    dict(type='SetValues', dictionary=dict(scale=scale)),
    dict(type='PairedRandomCrop', gt_patch_size=256),
    dict(type='Flip',keys=['img', 'gt’], flip_ratio=0.5, direction='horizontal'),
    dict(type='Flip', keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'),
    dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
    dict(type='PackEditInputs')
]

目前,MMEditing 1.0 支持对图像像素、图像大小、视频长度等 64 种 数据加载和预处理的操作。具体的实现源码可以在 mmedit/datasets/transforms 中找到。

如下面代码所示,是数据集以及数据加载和预处理的源码目录结构:

mmedit
├── datasets
│   ├── basic_image_dataset.py
│   ├── ...
│   ├── basic_frames_dataset.py
│   └── transforms
│       ├── aug_frames.py
│       ├── ...
│       └── aug_shape.py

步骤 2. 准备模型和网络结构

在 MMEditing 1.0 中,模型和网络结构的实现主要在 mmedit/models 下。

首先,我们在 mmedit/models/base_models 下实现了一些基础模型,定义了基础模型的训练逻辑,如 GAN,one-stage,two-stage 等模型;其次,我们在 mmedit/models/base_archs 的目录下支持了一些经典的网络结构,如 vgg,ASPP 等;最后,我们将每个算法的核心实现放在 mmedit/models/editors 下。

具体来说,我们可以在mmedit/models/editors 的目录下创建一个叫做 BasicVSR 的文件目录,并在这个目录下实现 basicvsr 的网络结构和模型。

同时,MMEditing 1.0 在 mmedit/models/losses 下支持了 26 种不同的损失函数。mmedit/models/data_preprocessors 中则主要实现了将加载的数据放到计算芯片 (如 GPU)、数据归一化等处理,源码目录结构如下:

 mmedit
├── datasets
├── models
│   ├── base_models
│   │   ├── ...
│   │   └── base_gan.py
│   ├── base_archs
│   │   ├── ...
│   │   └── resnet.py
│   ├── editors
│   │   ├── ... 
│   │   └── basicvsr
│   │       ├── basicvsr_net.py
│   │       └── basicvsr.py
│   ├── losses
│   │   ├── gan_loss.py
│   │   ├── ... 
│   │   └── clip_loss.py
│   └── data_preprocessors
│       ├── ...
│       └── data_processor.py

在下面的这个配置文件中,我们只需要在 model 这个字段中指定具体的模型构建所需要的参数、损失函数构建所需要的参数以及数据预处理的参数。

# basicvsr_2xb4_reds4.py 

model = dict(
    type='BasicVSR',
    generator=dict(
        type='BasicVSRNet',
        mid_channels=64,
        num_blocks=30,
        spynet_pretrained='https://download.openmmlab.com/mmediting/restorers/'
        'basicvsr/spynet_20210409-c6c1bd09.pth'),
    pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
    train_cfg=dict(fix_iter=5000),
    data_preprocessor=dict(
        type='EditDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        input_view=(1, -1, 1, 1),
        output_view=(1, -1, 1, 1),
    ))

步骤 3. 准备评测器

准备好了数据和模型之后,我们就可以准备评测器了。评测器可以帮助我们在训练过程中对模型性能进行监督,从而帮助我们保存训练过程中性能表现最好的一组模型参数。如下是评测器的源码目录结构:

mmedit
├── datasets
├── models
├── evaluation
│   ├── evaluator.py
│   ├── functional
│   │   └── gaussian_funcs.py
│   └── metrics
│       ├── fid.py
│       ├── ...
│       └── ssim.py

在对应的配置文件中,我们只需要提供一个字典的序列就可以监控相关的性能了。在下面这个例子中,我们通过在 val_evaluator 这个字段中指定 PSNR 和 SSIM,就可以轻松地在训练过程中检测 BasicVSR 的 PSNR 和 SSIM 在验证集上的表现了。目前,MMEditing 1.x 使用统一的评测接口,提供了多达 20 种不同的评价指标。

# basicvsr_2xb4_reds4.py 

val_evaluator = [
    dict(type='PSNR'),
    dict(type='SSIM'),
]

在 MMEditing 中 和 MMGeneration 0.x 中,如下图所示,我们往往需要多个脚本才能完成不同任务的评测需求。

而在 MMEditing 1.x 中,我们只需要一个评测脚本就能搞定全量算法的评测。同时,MMEditing 1.x 还支持了更快的分布式评测,自动缓存特征用于重复计算,以及批量采样存图(针对 GAN 等模型)的功能。

MMEditing 1.0 的评测脚本支持更多功能

步骤 4. 准备训练配置

最后,我们在 mmedit/engine 中实现训练相关的组件,如:钩子(hook)用以实现一些自定义的逻辑,优化器(optimizers),参数调度器(schedulers),以及执行器(runner)。如下所示是对应的源码目录结构:

mmedit
├── datasets
├── models
├── evaluation
├── engine
│   ├── hooks
│   │   ├── ...
│   │   └── visualization_hook.py
│   ├── optimizers
│   │   ├── ...
│   │   └── multi_optimizer_constructor.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── reduce_lr_scheduler.py
│   └── runner

在对应的配置文件中,我们可以在 default_hooks 中指定一些 MMEditing 1.0 已经支持的钩子,例如下面的这个例子中, IterTimerHook 用来统计训练过程中不同步骤的时间开销,用来帮助我们更好地定位训练的速度瓶颈;LoggerHook 则用来设置训练日志的打印,CheckpointHook 则可以帮助我们在训练过程中自动保存模型参数。

# basicvsr_2xb4_reds4.py 

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,
        out_dir=save_dir,
        by_epoch=False,
        max_keep_ckpts=10,
        save_best='PSNR',
        rule='greater',
    ),
    sampler_seed=dict(type='DistSamplerSeedHook'),
)

如下所示,我们可以在 optim_wrapper 中则可以设置优化器的参数。

# basicvsr_2xb4_reds4.py 

optim_wrapper = dict(
    constructor='DefaultOptimWrapperConstructor',
    type='OptimWrapper',
    optimizer=dict(type='Adam', lr=2e-4, betas=(0.9, 0.99)),
    paramwise_cfg=dict(custom_keys={'spynet': dict(lr_mult=0.125)}))

步骤 5. 启动训练

python tools/train.py configs/basicvsr/basicvsr_2xb4_reds4.py

最后一切准备就绪,我们只需要上面这行指令,就可以启动整个训练啦!

当你看到屏幕上打印出下列的这些日志,就说明我们成功了!

------------------------------------------------------------
- Exp name: basicvsr_2xb4_reds4_20220822_204628
08/22 20:49:05 - mmengine - [INFO] - Iter(train) [100/300000]  lr: 2.5000e-05  eta: 4 days, 5:06:47  time: 1.2138  data_time: 0.1251  memory: 10217  loss: 0.0348
08/22 20:50:26 - mmengine - [INFO] - Iter(train) [200/300000]  lr: 2.5000e-05  eta: 3 days, 12:13:50  time: 0.8091  data_time: 0.0131  memory: 10217  loss: 0.0318
08/22 20:51:47 - mmengine - [INFO] - Iter(train) [300/300000]  lr: 2.5000e-05  eta: 3 days, 6:39:12  time: 0.8115  data_time: 0.0140  memory: 10217  loss: 0.0293
08/22 20:53:07 - mmengine - [INFO] - Iter(train) [400/300000]  lr: 2.5000e-05  eta: 3 days, 3:42:29  time: 0.8045  data_time: 0.0153  memory: 10217  loss: 0.0279
08/22 20:54:29 - mmengine - [INFO] - Iter(train) [500/300000]  lr: 2.5000e-05  eta: 3 days, 2:04:00  time: 0.8126  data_time: 0.0139  memory: 10217  loss: 0.0269
08/22 20:55:50 - mmengine - [INFO] - Iter(train) [600/300000]  lr: 2.5000e-05  eta: 3 days, 0:57:22  time: 0.8119  data_time: 0.0136  memory: 10217  loss: 0.0266
08/22 20:57:11 - mmengine - [INFO] - Iter(train) [700/300000]  lr: 2.5000e-05  eta: 3 days, 0:09:15  time: 0.8118  data_time: 0.0146  memory: 10217  loss: 0.0271
08/22 20:58:32 - mmengine - [INFO] - Iter(train) [800/300000]  lr: 2.5000e-05  eta: 2 days, 23:30:09  time: 0.8075  data_time: 0.0141  memory: 10217  loss: 0.0260
08/22 20:59:53 - mmengine - [INFO] - Iter(train) [900/300000]  lr: 2.5000e-05  eta: 2 days, 23:02:37  time: 0.8132  data_time: 0.0154  memory: 10217  loss: 0.0265
08/22 21:01:18 - mmengine - [INFO] - Exp name: basicvsr_2xb4_reds4_20220822_204628
08/22 21:01:18 - mmengine - [INFO] - Iter(train) [1000/300000]  lr: 2.5000e-05  eta: 2 days, 22:59:48  time: 0.8523  data_time: 0.0174  memory: 10217  loss: 0.0257
08/22 21:02:46 - mmengine - [INFO] - Iter(train) [1100/300000]  lr: 2.5000e-05  eta: 2 days, 23:09:22  time: 0.8791  data_time: 0.0195  memory: 10217  loss: 0.0253
08/22 21:04:09 - mmengine - [INFO] - Iter(train) [1200/300000]  lr: 2.5000e-05  eta: 2 days, 22:54:59  time: 0.8258  data_time: 0.0165  memory: 10217  loss: 0.0256
08/22 21:05:30 - mmengine - [INFO] - Iter(train) [1300/300000]  lr: 2.5000e-05  eta: 2 days, 22:38:49  time: 0.8159  data_time: 0.0164  memory: 10217  loss: 0.0247
08/22 21:06:52 - mmengine - [INFO] - Iter(train) [1400/300000]  lr: 2.5000e-05  eta: 2 days, 22:23:46  time: 0.8131  data_time: 0.0172  memory: 10217  loss: 0.0252

步骤 6. 推理和部署

除了模型的训练和测试,我们还提供了好玩的推理 API 让用户可以更简单地玩转 MMEditing 1.0 中各种各样的不同任务。这部分的源码实现可以查看 mmedit/apis下的代码:

mmedit
├── datasets
├── models
├── evaluation
├── engine
├── visualization
├── utils
├── structures
└── apis
    ├── gan_inference.py
    ├── ...
    └── restoration_inference.py

这里我们给出了一个 10 行代码体验图像超分的例子作为参考:

# demo.py 

import mmcv
from mmedit.apis import init_model, restoration_inference
from mmedit.engine.misc import tensor2img

config = 'configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py'
checkpoint = 'https://download.openmmlab.com/mmediting/restorers/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k_20200508-f8ccaf3b.pth'
img_path = 'tests/data/image/lq/baboon_x4.png'

model = init_model(config, checkpoint)
output = restoration_inference(model, img_path)
output = tensor2img(output)
mmcv.imwrite(output, 'output.png')

目前, MMDeploy 已经支持了 MMEditing 的部署啦。通过 MMDeploy,可以轻松将 MMEditing 1.0 中的算法(主要为图像超分算法)进行模型转换,并转化成使用于不同设备和推理引擎上的 SDK。

总结

在 AI 绘画时代,想要创造充满想象力、足以以假乱真的图像和视频内容,需要生成模型和图像视频质量增强技术强强联手,爆发无限可能。

今天我们介绍了 MMEditing 和 MMGeneration 的重大升级新特性,并提供了一个实战教程,希望能够帮助大家在 AI 绘画方面的研发。diffusion model 系列也已经在路上了哟,诚邀社区用户加入我们,在 GitHub 或者社区用户群联系我们,一起变得更强!

想了解更多 MMEditing 1.0 相关内容,欢迎查看直播回放视频~

开发维护计划

MMEditing 目前有两个分支,分别是 master 和 1.x (https://github.com/open-mmlab/mmediting/tree/1.x ) 分支,它们会经历以下三个阶段:

阶段

时间

分支

说明

公测期

2022/9/1 - 2022.12.31

公测版代码发布在 1.x 分支; 默认主分支 master 仍对应 0.x 版本

master 和 1.x 分支正常进行迭代

兼容期

2023/1/1 - 2023.12.31

切换默认主分支 master 为 1.x 版本; 0.x 分支对应 0.x 版本

保持对旧版本 0.x 的维护和开发,响应用户需求,但尽量不引进破坏旧版本兼容性的改动; master 分支正常进行迭代

维护期

2024/1/1 - 待定

默认主分支 master 为 1.x 版本; 0.x 分支对应 0.x 版本

0.x 分支进入维护阶段,不再进行新功能支持; master 分支正常进行迭代

我们非常鼓励社区小伙伴们能够在公测期就可以尽早尝试 MMEditing 1.0,然后能在兼容期很好地过渡到 MMEditing 1.0,在这个过程中,我们一定也会帮助到大家。

  • 如果在使用过程中,你有任何的疑问,可以来 Github 的Discussions 页面来进行讨论:https://github.com/open-mmlab/mmediting/discussions
  • 如果你碰到了 bug 或者想要吐槽文档中遇到的问题,可以提出你的 issue:https://github.com/open-mmlab/mmediting/issues
  • 如果你支持了新功能或者修复了 bug,那就更不要犹豫啦,快来提交你的修改,让你的代码在 MMEditing 中闪闪发光:https://github.com/open-mmlab/mmediting/pulls