MMCV Study Notes — Basics Part 2 (Runner)

1. Key Features

The purpose of Runner is to provide users with unified management of the training pipeline, while supporting flexible, configurable customization (implemented via Hooks). Its main features are:

  • It ships with EpochBasedRunner and IterBasedRunner for epoch-based and iteration-based training respectively, and also supports user-defined custom Runners.
  • It supports custom workflows so that different states can be switched freely during training; currently two workflow modes are supported: training (train) and validation (val).
  • Combined with various hook functions (Hook), it offers flexible extensibility: by injecting different types of Hooks, extra functionality can be added to the training process in an elegant way (see the sketch after this list).
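
As a quick taste of the Hook mechanism mentioned in the last bullet, here is a minimal sketch of a custom hook; the class name MyToyHook is made up for illustration, while HOOKS, Hook, and after_train_epoch come from mmcv.runner:

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class MyToyHook(Hook):
    """Print a short message at the end of every training epoch."""

    def after_train_epoch(self, runner):
        print(f'[MyToyHook] finished epoch {runner.epoch}')

Such a hook can then be attached to a runner, e.g. via runner.register_hook(MyToyHook()), without touching the training loop itself.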

2. EpochBasedRunner & IterBasedRunner

As the name suggests, EpochBasedRunner is a Runner that iterates on an epoch basis. Below we implement a simple example to demonstrate how its workflow control works.

from torch import nn


class ToyRunner(nn.Module):

    def __init__(self):
        super().__init__()

    def train(self, data_loader, **kwargs):
        # a "train epoch": just print which data loader would be consumed
        print(data_loader)

    def val(self, data_loader, **kwargs):
        # a "val epoch": just print which data loader would be consumed
        print(data_loader)
    
    def run(self):
        # training epochs
        max_epochs = 3
        curr_epoch = 0
        # denotes 2 epochs for training and 1 epoch for validation
        workflow = [("train", 2), ("val", 1)]
        data_loaders = ["dl_train", "dl_val"]
        # the condition to stop training
        while curr_epoch < max_epochs: 
            # workflow
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                epoch_func = getattr(self, mode)
                for _ in range(epochs):
                    if mode == 'train' and curr_epoch >= max_epochs:
                        break
                    epoch_func(f'data_loader: {data_loaders[i]}, epoch={curr_epoch}')
                    if mode == 'train':
                        # validation doesn't affect curr_epoch
                        curr_epoch += 1

Now let's run the ToyRunner:

runner = ToyRunner()
runner.run()
"""
Output:
data_loader: dl_train, epoch=0
data_loader: dl_train, epoch=1
data_loader: dl_val, epoch=2
data_loader: dl_train, epoch=2
data_loader: dl_val, epoch=3
"""

The logic above is quite simple; a few points are worth noting:

  • The workflow here means training for 2 epochs and then validating for 1 epoch. Its length is 2 and must match the length of data_loaders: the workflow consists of two flows, train and val, so the two data loaders dl_train and dl_val are needed to supply their data.
  • max_epochs counts training epochs, so curr_epoch is incremented only when mode is train.
  • The inner loop breaks only when mode is train and curr_epoch >= max_epochs, which guarantees that the last training epoch is always followed by a validation epoch.
  • IterBasedRunner works on the same principle (a toy sketch follows this list), and both runners inherit from the BaseRunner base class, so we won't repeat the details here; interested readers can check the source code on GitHub.
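
For intuition, a toy sketch of the analogous iteration-based loop might look like the following; this mirrors the ToyRunner above and is not IterBasedRunner's actual implementation:

class ToyIterRunner:

    def train(self, data_loader, **kwargs):
        print(f'train on {data_loader}')

    def val(self, data_loader, **kwargs):
        print(f'val on {data_loader}')

    def run(self):
        max_iters = 4
        curr_iter = 0
        # 2 training iterations followed by 1 validation iteration, repeated
        workflow = [('train', 2), ('val', 1)]
        data_loaders = ['dl_train', 'dl_val']
        while curr_iter < max_iters:
            for i, (mode, iters) in enumerate(workflow):
                iter_func = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and curr_iter >= max_iters:
                        break
                    iter_func(f'{data_loaders[i]}, iter={curr_iter}')
                    if mode == 'train':
                        # only training iterations advance the counter
                        curr_iter += 1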

3. A Simple Example

3.1 Tool Functions

Next, through a simple example, we will use the Runner class provided by mmcv in the way mmcv expects. First, we define some utility functions for building the data loaders:

import platform
import random
from functools import partial

import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import digit_version
from torch.utils.data import DataLoader, IterableDataset


if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))


def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=False,
                     shuffle=True,
                     seed=None,
                     drop_last=False,
                     pin_memory=True,
                     persistent_workers=True,
                     **kwargs):
    """Build PyTorch DataLoader.
    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.
    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: False.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int | None): Seed to be used. Default: None.
        drop_last (bool): Whether to drop the last incomplete batch in epoch.
            Default: False
        pin_memory (bool): Whether to use pin_memory in DataLoader.
            Default: True
        persistent_workers (bool): If True, the data loader will not shut down
            the worker processes after a dataset has been consumed once.
            This keeps the worker Dataset instances alive. The argument only
            takes effect for PyTorch >= 1.7.0 (the trimmed code below only
            passes it when PyTorch >= 1.8.0). Default: True.
        kwargs: any keyword argument to be used to initialize DataLoader
    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if dist and not isinstance(dataset, IterableDataset):
        # distributed samplers are not supported in this trimmed-down demo;
        # fail loudly instead of leaving batch_size/num_workers undefined
        raise NotImplementedError(
            'distributed training is not supported in this demo')
    elif dist:
        sampler = None
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    if digit_version(torch.__version__) >= digit_version('1.8.0'):
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
            pin_memory=pin_memory,
            shuffle=shuffle,
            worker_init_fn=init_fn,
            drop_last=drop_last,
            persistent_workers=persistent_workers,
            **kwargs)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
            pin_memory=pin_memory,
            shuffle=shuffle,
            worker_init_fn=init_fn,
            drop_last=drop_last,
            **kwargs)

    return data_loader


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Worker init func for dataloader.
    The seed of each worker equals num_workers * rank + worker_id + user_seed.
    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of current process.
        seed (int): The random seed to use.
    """

    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)

The utility functions above come from MMSegmentation; I trimmed them slightly here to make the demo easy to run.

3.2 Build Model

Next, we write a simple model that meets Runner's requirements:

from torch import nn
from mmcv.runner import BaseModule


class ToyModel(BaseModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 2)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        out = self.backbone(x)
        return out

    def train_step(self, data_batch, optimizer, **kwargs):
        labels, imgs = data_batch
        preds = self(imgs)
        loss = self.criterion(preds, labels)
        log_vars = dict(train_loss=loss.item())
        num_samples = len(imgs)
        outputs = dict(loss=loss,
                       preds=preds,
                       log_vars=log_vars,
                       num_samples=num_samples)
        return outputs

    def val_step(self, data_batch, optimizer, **kwargs):
        labels, imgs = data_batch
        preds = self(imgs)
        loss = self.criterion(preds, labels)
        log_vars = dict(val_loss=loss.item())
        num_samples = len(imgs)
        outputs = dict(log_vars=log_vars,
                       num_samples=num_samples)
        return outputs

A few points are worth noting about the model code above:

  • When implementing our own Module under the mmcv framework, we need to inherit from BaseModule in mmcv.runner. It mainly provides three attributes/methods: 1) init_cfg, the config that controls model initialization; 2) init_weights, the parameter-initialization function, which records parameter-initialization information; 3) _params_init_info, a defaultdict<nn.Parameter, dict> used to track parameter-initialization information; this attribute is created only while init_weights is executing and is deleted once all parameters have been initialized.
  • As required by Runner, our model must implement the two methods train_step and val_step.
  • Both train_step and val_step return a dict. The loss returned by train_step is what the optimizer hook later uses for backpropagation, while log_vars and num_samples are the variables the logger hook needs for its output (see the sketch after this list).
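
To make the last two points concrete, here is a simplified sketch of how the dict returned by train_step is typically consumed; it is loosely modeled on mmcv's OptimizerHook, and toy_train_iter as well as the runner argument are stand-ins rather than actual Runner internals:

def toy_train_iter(runner, model, data_batch):
    # the runner calls train_step and keeps the returned dict
    outputs = model.train_step(data_batch, runner.optimizer)
    runner.outputs = outputs
    # roughly what an optimizer hook does in after_train_iter
    runner.optimizer.zero_grad()
    runner.outputs['loss'].backward()  # 'loss' drives backpropagation
    runner.optimizer.step()
    # a logger hook later reads outputs['log_vars'] and outputs['num_samples']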

3.3 Build Dataset

Next is a simple dataset class with only one category (every label is 0):

class ToyDataset(torch.utils.data.Dataset):

    def __init__(self, data) -> None:
        super().__init__()
        self.data = data

    def __getitem__(self, idx):
        # every sample gets label 0 (single-category toy dataset)
        return 0, self.data[idx]
    
    def __len__(self):
        return len(self.data)
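
As a quick sanity check (added here for illustration), each sample is a (label, feature) pair whose label is always 0:

ds = ToyDataset(torch.rand(5, 10))
print(len(ds))  # 5
label, feature = ds[0]
print(label, feature.shape)  # 0 torch.Size([10])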

3.4 Running Script

Now we write a script to run mmcv's Runner and see the result. The initialization code is as follows:

from mmcv import ConfigDict
from mmcv.runner import build_optimizer

# initialize
model = ToyModel()  # model
cfg = ConfigDict(data=dict(samples_per_gpu=2, workers_per_gpu=1),
                 workflow=[('train', 1), ('val', 1)],
                 optimizer=dict(type='SGD',
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=0.0001),
                 lr_config=dict(policy='step', step=[100, 150]),
                 log_config=dict(interval=1,
                                 hooks=[
                                     dict(type='TextLoggerHook',
                                          by_epoch=True),
                                 ]),
                 runner=dict(type='EpochBasedRunner', max_epochs=3))  # config
optimizer = build_optimizer(model, cfg.optimizer)  # optimizer
ds_train = ToyDataset(torch.rand(5, 10))  # training set
ds_val = ToyDataset(torch.rand(3, 10))  # validation set
datasets = [ds_train, ds_val]  # dataset
# data_loaders
data_loaders = [
    build_dataloader(ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu)
    for ds in datasets
]

Then we build the runner, register the hooks, and the model can be trained with one call:

from mmcv.runner import build_runner
from mmcv.utils import get_logger


# get logger
logger = get_logger(name='toyproj', log_file='logger.log')
# initialize runner
runner = build_runner(cfg.runner,
                      default_args=dict(model=model,
                                        batch_processor=None,
                                        optimizer=optimizer,
                                        work_dir='./',
                                        logger=logger))
# register default hooks necessary for training
runner.register_training_hooks(
    # config of the learning rate schedule
    lr_config=cfg.lr_config,
    # config of logging
    log_config=cfg.log_config)
# start running
runner.run(data_loaders, cfg.workflow)

'''
Output:
2022-11-23 11:55:50,539 - toyproj - INFO - workflow: [('train', 1), ('val', 1)], max: 3 epochs
2022-11-23 11:55:55,541 - toyproj - INFO - Epoch [1][1/3]	lr: 1.000e-01, eta: 0:00:39, time: 4.998, data_time: 4.995, memory: 0, train_loss: 0.6041
2022-11-23 11:55:55,548 - toyproj - INFO - Epoch [1][2/3]	lr: 1.000e-01, eta: 0:00:17, time: 0.009, data_time: 0.009, memory: 0, train_loss: 0.5542
2022-11-23 11:55:55,552 - toyproj - INFO - Epoch [1][3/3]	lr: 1.000e-01, eta: 0:00:10, time: 0.003, data_time: 0.003, memory: 0, train_loss: 0.5444
2022-11-23 11:55:57,674 - toyproj - INFO - Epoch(val) [1][2]	val_loss: 0.3617
2022-11-23 11:55:59,746 - toyproj - INFO - Epoch [2][1/3]	lr: 1.000e-01, eta: 0:00:08, time: 2.067, data_time: 2.066, memory: 0, train_loss: 0.5542
2022-11-23 11:55:59,752 - toyproj - INFO - Epoch [2][2/3]	lr: 1.000e-01, eta: 0:00:05, time: 0.007, data_time: 0.007, memory: 0, train_loss: 0.6041
2022-11-23 11:55:59,756 - toyproj - INFO - Epoch [2][3/3]	lr: 1.000e-01, eta: 0:00:03, time: 0.004, data_time: 0.004, memory: 0, train_loss: 0.5444
2022-11-23 11:56:01,918 - toyproj - INFO - Epoch(val) [2][2]	val_loss: 0.3617
2022-11-23 11:56:03,991 - toyproj - INFO - Epoch [3][1/3]	lr: 1.000e-01, eta: 0:00:02, time: 2.068, data_time: 2.067, memory: 0, train_loss: 0.6041
2022-11-23 11:56:03,998 - toyproj - INFO - Epoch [3][2/3]	lr: 1.000e-01, eta: 0:00:01, time: 0.008, data_time: 0.007, memory: 0, train_loss: 0.5897
2022-11-23 11:56:04,002 - toyproj - INFO - Epoch [3][3/3]	lr: 1.000e-01, eta: 0:00:00, time: 0.004, data_time: 0.004, memory: 0, train_loss: 0.4735
2022-11-23 11:56:06,181 - toyproj - INFO - Epoch(val) [3][2]	val_loss: 0.3617

'''
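
One observation of my own (worth double-checking against your mmcv version): the script above only registers the LR and logger hooks, so no OptimizerHook ever calls loss.backward() and optimizer.step(), and the weights are never updated, which matches val_loss staying at 0.3617 in every epoch. To actually optimize the parameters, an optimizer_config can also be passed when registering the training hooks, for example:

runner.register_training_hooks(
    lr_config=cfg.lr_config,
    # registers an OptimizerHook that performs loss.backward() and optimizer.step()
    optimizer_config=dict(grad_clip=None),
    log_config=cfg.log_config)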