Preparing Data for YOLO-World

Overview

For pre-training YOLO-World, we adopt several datasets as listed in the below table:

Dataset Directory

We put all data into the data directory, such as:

├── coco
│   ├── annotations
│   ├── lvis
│   ├── train2017
│   ├── val2017
├── flickr
│   ├── annotations
│   └── images
├── mixed_grounding
│   ├── annotations
│   ├── images
├── mixed_grounding
│   ├── annotations
│   ├── images
├── objects365v1
│   ├── annotations
│   ├── train
│   ├── val

NOTE: We strongly suggest that you check the directories or paths in the dataset part of the config file, especially for the values ann_file, data_root, and data_prefix.

We provide the annotations of the pre-training data in the below table:

Acknowledgement: We sincerely thank GLIP and mdetr for providing the annotation files for pre-training.

Dataset Class

For training YOLO-World, we mainly adopt two kinds of dataset classs:

1. MultiModalDataset

MultiModalDataset is a simple wrapper for pre-defined Dataset Class, such as Objects365 or COCO, which add the texts (category texts) into the dataset instance for formatting input texts.

Text JSON

The json file is formatted as follows:

[
    ['A_1','A_2'],
    ['B'],
    ['C_1', 'C_2', 'C_3'],
    ...
]

We have provided the text json for LVIS, COCO, and Objects365

2. YOLOv5MixedGroundingDataset

The YOLOv5MixedGroundingDataset extends the COCO dataset by supporting loading texts/captions from the json file. It’s desgined for MixedGrounding or Flickr30K with text tokens for each object.

🔥 Custom Datasets

For custom dataset, we suggest the users convert the annotation files according to the usage. Note that, converting the annotations to the standard COCO format is basically required.

  1. Large vocabulary, grounding, referring: you can follow the annotation format as the MixedGrounding dataset, which adds caption and tokens_positive for assigning the text for each object. The texts can be a category or a noun phrases.

  2. Custom vocabulary (fixed): you can adopt the MultiModalDataset wrapper as the Objects365 and create a text json for your custom categories.

Fine-tuning YOLO-World

Fine-tuning YOLO-World is easy and we provide the samples for COCO object detection as a simple guidance.

Fine-tuning Requirements

Fine-tuning YOLO-World is cheap:

  • it does not require 32 GPUs for multi-node distributed training. 8 GPUs or even 1 GPU is enough.

  • it does not require the long schedule, e.g., 300 epochs or 500 epochs for training YOLOv5 or YOLOv8. 80 epochs or fewer is enough considering that we provide the good pre-trained weights.

Data Preparation

The fine-tuning dataset should have the similar format as the that of the pre-training dataset.
We suggest you refer to docs/data for more details about how to build the datasets:

  • if you fine-tune YOLO-World for close-set / custom vocabulary object detection, using MultiModalDataset with a text json is preferred.

  • if you fine-tune YOLO-World for open-vocabulary detection with rich texts or grounding tasks, using MixedGroundingDataset is preferred.

Hyper-parameters and Config

Please refer to the config for fine-tuning YOLO-World-L on COCO for more details.

  1. Basic config file:

If the fine-tuning dataset contains mask annotations:

_base_ = ('../../third_party/mmyolo/configs/yolov8/yolov8_l_mask-refine_syncbn_fast_8xb16-500e_coco.py')

If the fine-tuning dataset doesn’t contain mask annotations:

_base_ = ('../../third_party/mmyolo/configs/yolov8/yolov8_l_syncbn_fast_8xb16-500e_coco.py')
  1. Training Schemes:

Reducing the epochs and adjusting the learning rate

max_epochs = 80
base_lr = 2e-4
weight_decay = 0.05
train_batch_size_per_gpu = 16
close_mosaic_epochs=10

train_cfg = dict(
    max_epochs=max_epochs,
    val_interval=5,
    dynamic_intervals=[((max_epochs - close_mosaic_epochs),
                        _base_.val_interval_stage2)])

  1. Datasets:
coco_train_dataset = dict(
    _delete_=True,
    type='MultiModalDataset',
    dataset=dict(
        type='YOLOv5CocoDataset',
        data_root='data/coco',
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=False, min_size=32)),
    class_text_path='data/texts/coco_class_texts.json',
    pipeline=train_pipeline)
Finetuning without RepVL-PAN or Text Encoder 🚀

For further efficiency and simplicity, we can fine-tune an efficient version of YOLO-World without RepVL-PAN and the text encoder.
The efficient version of YOLO-World has the similar architecture or layers with the orignial YOLOv8 but we provide the pre-trained weights on large-scale datasets.
The pre-trained YOLO-World has strong generalization capabilities and is more robust compared to YOLOv8 trained on the COCO dataset.

You can refer to the config for Efficient YOLO-World for more details.

The efficient YOLO-World adopts EfficientCSPLayerWithTwoConv and the text encoder can be removed during inference or exporting models.


model = dict(
    type='YOLOWorldDetector',
    mm_neck=True,
    neck=dict(type='YOLOWorldPAFPN',
              guide_channels=text_channels,
              embed_channels=neck_embed_channels,
              num_heads=neck_num_heads,
              block_cfg=dict(type='EfficientCSPLayerWithTwoConv')))

Launch Fine-tuning!

It’s easy:

./dist_train.sh <path/to/config> <NUM_GPUS> --amp

COCO Fine-tuning

Update Notes

We provide the details for important updates of YOLO-World in this note.

Model Architecture

[2024-2-29]: YOLO-World-v2:

  1. We remove the I-PoolingAttention: though it improves the performance for zero-shot LVIS evaluation, it affects the inference speeds after exporting YOLO-World to ONNX or TensorRT. Considering the trade-off, we remove the I-PoolingAttention in the newest version.
  2. We replace the L2-Norm in the contrastive head with the BatchNorm. The L2-Norm contains complex operations, such as reduce, which is time-consuming for deployment. However, the BatchNorm can be fused into the convolution, which is much more efficient and also improves the zero-shot performance.

.\YOLO-World\image_demo.py

# 版权声明
# 导入必要的库
import os
import cv2
import argparse
import os.path as osp

import torch
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmengine.runner.amp import autocast
from mmengine.dataset import Compose
from mmengine.utils import ProgressBar
from mmyolo.registry import RUNNERS

# 定义BOUNDING_BOX_ANNOTATOR对象
BOUNDING_BOX_ANNOTATOR = None
# 定义LABEL_ANNOTATOR对象
LABEL_ANNOTATOR = None

# 解析命令行参数
def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-World Demo')
    # 添加命令行参数
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('image', help='image path, include image file or dir.')
    parser.add_argument(
        'text',
        help='text prompts, including categories separated by a comma or a txt file with each line as a prompt.'
    )
    parser.add_argument('--topk',
                        default=100,
                        type=int,
                        help='keep topk predictions.')
    parser.add_argument('--threshold',
                        default=0.0,
                        type=float,
                        help='confidence score threshold for predictions.')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='device used for inference.')
    parser.add_argument('--show',
                        action='store_true',
                        help='show the detection results.')
    parser.add_argument('--annotation',
                        action='store_true',
                        help='save the annotated detection results as yolo text format.')
    parser.add_argument('--amp',
                        action='store_true',
                        help='use mixed precision for inference.')
    # 添加一个名为'--output-dir'的命令行参数,用于指定保存输出的目录,默认为'demo_outputs'
    parser.add_argument('--output-dir',
                        default='demo_outputs',
                        help='the directory to save outputs')
    # 添加一个名为'--cfg-options'的命令行参数,用于覆盖配置文件中的一些设置,支持键值对形式的参数
    # 如果要覆盖的值是列表,则应该以 key="[a,b]" 或 key=a,b 的格式提供
    # 还支持嵌套列表/元组值,例如 key="[(a,b),(c,d)]"
    # 注意引号是必要的,不允许有空格
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # 解析命令行参数
    args = parser.parse_args()
    # 返回解析后的参数
    return args
# 推断检测器,运行模型进行推断
def inference_detector(runner,
                       image_path,
                       texts,
                       max_dets,
                       score_thr,
                       output_dir,
                       use_amp=False,
                       show=False,
                       annotation=False):
    # 创建包含图像信息的字典
    data_info = dict(img_id=0, img_path=image_path, texts=texts)
    # 运行数据处理管道
    data_info = runner.pipeline(data_info)
    # 创建包含数据批次信息的字典
    data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                      data_samples=[data_info['data_samples']])

    # 使用自动混合精度和禁用梯度计算
    with autocast(enabled=use_amp), torch.no_grad():
        # 运行模型的测试步骤
        output = runner.model.test_step(data_batch)[0]
        pred_instances = output.pred_instances
        # 通过设置阈值过滤预测实例
        pred_instances = pred_instances[
            pred_instances.scores.float() > score_thr]
    # 如果预测实例数量超过最大检测数
    if len(pred_instances.scores) > max_dets:
        # 选择得分最高的前 max_dets 个预测实例
        indices = pred_instances.scores.float().topk(max_dets)[1]
        pred_instances = pred_instances[indices]

    # 将预测实例转换为 numpy 数组
    pred_instances = pred_instances.cpu().numpy()
    # 定义检测对象
    detections = None

    # 为每个检测结果添加标签
    labels = [
        f"{texts[class_id][0]} {confidence:0.2f}" for class_id, confidence in
        zip(detections.class_id, detections.confidence)
    ]

    # 读取图像
    image = cv2.imread(image_path)
    anno_image = image.copy()
    # 在图像上绘制边界框
    image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
    # 在图像上添加标签
    image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
    # 将标记后的图像保存到输出目录
    cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image)
    # 如果有注释
    if annotation:
        # 创建空字典用于存储图像和注释
        images_dict = {}
        annotations_dict = {}

        # 将图像路径的基本名称作为键,注释图像作为值存储在图像字典中
        images_dict[osp.basename(image_path)] = anno_image
        # 将图像路径的基本名称作为键,检测结果作为值存储在注释字典中
        annotations_dict[osp.basename(image_path)] = detections
        
        # 创建一个名为ANNOTATIONS_DIRECTORY的目录,如果目录已存在则不创建
        ANNOTATIONS_DIRECTORY =  os.makedirs(r"./annotations", exist_ok=True)

        # 设置最小图像面积百分比
        MIN_IMAGE_AREA_PERCENTAGE = 0.002
        # 设置最大图像面积百分比
        MAX_IMAGE_AREA_PERCENTAGE = 0.80
        # 设置近似百分比
        APPROXIMATION_PERCENTAGE = 0.75
        
        # 创建一个DetectionDataset对象,传入类别、图像字典和注释字典,然后转换为YOLO格式
        sv.DetectionDataset(
            classes=texts,
            images=images_dict,
            annotations=annotations_dict
        ).as_yolo(
            annotations_directory_path=ANNOTATIONS_DIRECTORY,
            min_image_area_percentage=MIN_IMAGE_AREA_PERCENTAGE,
            max_image_area_percentage=MAX_IMAGE_AREA_PERCENTAGE,
            approximation_percentage=APPROXIMATION_PERCENTAGE
        )

    # 如果需要展示图像
    if show:
        # 在窗口中展示图像,提供窗口名称
        cv2.imshow('Image', image)
        # 等待按键输入,0表示一直等待
        k = cv2.waitKey(0)
        # 如果按下ESC键(ASCII码为27),关闭所有窗口
        if k == 27:
            cv2.destroyAllWindows()
if __name__ == '__main__':
    # 解析命令行参数
    args = parse_args()

    # 加载配置文件
    cfg = Config.fromfile(args.config)
    # 如果有额外的配置选项,则合并到配置文件中
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # 设置工作目录为当前目录下的 work_dirs 文件夹中,使用配置文件名作为子目录名
    cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])

    # 加载模型检查点
    cfg.load_from = args.checkpoint

    # 根据配置文件中是否包含 runner_type 字段来选择不同的 Runner 类型
    if 'runner_type' not in cfg:
        runner = Runner.from_cfg(cfg)
    else:
        runner = RUNNERS.build(cfg)

    # 加载文本数据
    if args.text.endswith('.txt'):
        with open(args.text) as f:
            lines = f.readlines()
        # 将文本数据转换为列表形式
        texts = [[t.rstrip('\r\n')] for t in lines] + [[' ']]
    else:
        # 将命令行参数中的文本数据转换为列表形式
        texts = [[t.strip()] for t in args.text.split(',')] + [[' ']]

    # 设置输出目录
    output_dir = args.output_dir
    # 如果输出目录不存在,则创建
    if not osp.exists(output_dir):
        os.mkdir(output_dir)

    # 在运行之前调用钩子函数
    runner.call_hook('before_run')
    # 加载或恢复模型
    runner.load_or_resume()
    # 获取数据处理流程
    pipeline = cfg.test_dataloader.dataset.pipeline
    runner.pipeline = Compose(pipeline)
    # 设置模型为评估模式
    runner.model.eval()

    # 检查输入的图像路径是否为文件夹
    if not osp.isfile(args.image):
        # 获取文件夹中所有以 .png 或 .jpg 结尾的图像文件路径
        images = [
            osp.join(args.image, img) for img in os.listdir(args.image)
            if img.endswith('.png') or img.endswith('.jpg')
        ]
    else:
        # 将输入的图像路径转换为列表形式
        images = [args.image]

    # 创建进度条对象,用于显示处理进度
    progress_bar = ProgressBar(len(images))
    # 遍历每张图像进行目标检测
    for image_path in images:
        # 调用目标检测函数进行推理
        inference_detector(runner,
                           image_path,
                           texts,
                           args.topk,
                           args.threshold,
                           output_dir=output_dir,
                           use_amp=args.amp,
                           show=args.show,
                           annotation=args.annotation)
        # 更新进度条
        progress_bar.update()

.\YOLO-World\tools\test.py

# 版权声明
# 导入必要的库
import argparse
import os
import os.path as osp

# 导入自定义模块
from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.evaluator import DumpResults
from mmengine.runner import Runner

# 导入自定义模块
from mmyolo.registry import RUNNERS
from mmyolo.utils import is_metainfo_lower

# 定义解析命令行参数的函数
def parse_args():
    # 创建 ArgumentParser 对象,设置描述信息
    parser = argparse.ArgumentParser(
        description='MMYOLO test (and eval) a model')
    # 添加命令行参数
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--out',
        type=str,
        help='output result file (must be a .pkl file) in pickle format')
    parser.add_argument(
        '--json-prefix',
        type=str,
        help='the prefix of the output json file without perform evaluation, '
        'which is useful when you want to format the result to a specific '
        'format and submit it to the test server')
    parser.add_argument(
        '--tta',
        action='store_true',
        help='Whether to use test time augmentation')
    parser.add_argument(
        '--show', action='store_true', help='show prediction results')
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    parser.add_argument(
        '--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    parser.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')
    # 添加一个命令行参数,用于覆盖配置文件中的一些设置,参数为字典类型,使用自定义的DictAction处理
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    
    # 添加一个命令行参数,用于指定作业启动器的类型,可选值为['none', 'pytorch', 'slurm', 'mpi'],默认为'none'
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    
    # 添加一个命令行参数,用于指定本地进程的排名,默认为0
    parser.add_argument('--local_rank', type=int, default=0)
    
    # 解析命令行参数并返回结果
    args = parser.parse_args()
    
    # 如果环境变量中没有'LOCAL_RANK',则将命令行参数中的local_rank值赋给'LOCAL_RANK'
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    
    # 返回解析后的命令行参数
    return args
def main():
    # 解析命令行参数
    args = parse_args()

    # 加载配置文件
    cfg = Config.fromfile(args.config)
    # 用 cfg.key 的值替换 ${key}
    # cfg = replace_cfg_vals(cfg)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        # 根据命令行参数更新配置
        cfg.merge_from_dict(args.cfg_options)

    # 确定工作目录的优先级:CLI > 配置文件中的段 > 文件名
    if args.work_dir is not None:
        # 如果 args.work_dir 不为 None,则根据 CLI 参数更新配置
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # 如果 cfg.work_dir 为 None,则使用配置文件名作为默认工作目录
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # 加载模型参数
    cfg.load_from = args.checkpoint

    if args.show or args.show_dir:
        # 触发可视化钩子
        cfg = trigger_visualization_hook(cfg, args)

    if args.deploy:
        # 添加部署钩子
        cfg.custom_hooks.append(dict(type='SwitchToDeployHook'))

    # 将 `format_only` 和 `outfile_prefix` 添加到配置中
    if args.json_prefix is not None:
        cfg_json = {
            'test_evaluator.format_only': True,
            'test_evaluator.outfile_prefix': args.json_prefix
        }
        cfg.merge_from_dict(cfg_json)

    # 确定自定义元信息字段是否全部为小写
    is_metainfo_lower(cfg)
    # 如果启用了测试时间增强(TTA),则需要检查配置中是否包含必要的参数
    if args.tta:
        # 检查配置中是否包含 tta_model 和 tta_pipeline,否则无法使用 TTA
        assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' \
                                   " Can't use tta !"
        assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` ' \
                                      "in config. Can't use tta !"

        # 将 tta_model 合并到 model 配置中
        cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
        test_data_cfg = cfg.test_dataloader.dataset
        while 'dataset' in test_data_cfg:
            test_data_cfg = test_data_cfg['dataset']

        # batch_shapes_cfg 会强制控制输出图像的大小,与 TTA 不兼容
        if 'batch_shapes_cfg' in test_data_cfg:
            test_data_cfg.batch_shapes_cfg = None
        test_data_cfg.pipeline = cfg.tta_pipeline

    # 根据配置构建 Runner 对象
    if 'runner_type' not in cfg:
        # 构建默认的 Runner
        runner = Runner.from_cfg(cfg)
    else:
        # 从注册表中构建自定义的 Runner,如果配置中设置了 runner_type
        runner = RUNNERS.build(cfg)

    # 添加 `DumpResults` 虚拟指标
    if args.out is not None:
        # 确保输出文件是 pkl 或 pickle 格式
        assert args.out.endswith(('.pkl', '.pickle')), \
            'The dump file must be a pkl file.'
        runner.test_evaluator.metrics.append(
            DumpResults(out_file_path=args.out))

    # 开始测试
    runner.test()
# 如果当前脚本被直接执行,则调用主函数
if __name__ == '__main__':
    main()

.\YOLO-World\tools\train.py

# 导入必要的库和模块
import argparse  # 用于解析命令行参数
import logging  # 用于记录日志
import os  # 用于操作系统相关功能
import os.path as osp  # 用于操作文件路径

from mmengine.config import Config, DictAction  # 导入Config和DictAction类
from mmengine.logging import print_log  # 导入print_log函数
from mmengine.runner import Runner  # 导入Runner类

from mmyolo.registry import RUNNERS  # 导入RUNNERS变量
from mmyolo.utils import is_metainfo_lower  # 导入is_metainfo_lower函数

# 解析命令行参数
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')  # 创建参数解析器
    parser.add_argument('config', help='train config file path')  # 添加必需的参数
    parser.add_argument('--work-dir', help='the dir to save logs and models')  # 添加可选参数
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')  # 添加可选参数
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If specify checkpoint path, resume from it, while if not '
        'specify, try to auto resume from the latest checkpoint '
        'in the work directory.')  # 添加可选参数
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')  # 添加可选参数
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')  # 添加可选参数
    parser.add_argument('--local_rank', type=int, default=0)  # 添加可选参数
    args = parser.parse_args()  # 解析命令行参数
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)  # 设置环境变量LOCAL_RANK为args.local_rank的值

    return args  # 返回解析后的参数

def main():
    args = parse_args()  # 解析命令行参数并保存到args变量中

    # 加载配置文件
    cfg = Config.fromfile(args.config)  # 从配置文件路径args.config中加载配置信息
    # 用cfg.key的值替换${key}的占位符
    # 设置配置文件中的 launcher 为命令行参数中指定的 launcher
    cfg.launcher = args.launcher
    # 如果命令行参数中指定了 cfg_options,则将其合并到配置文件中
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # 确定工作目录的优先级:CLI > 文件中的段 > 文件名
    if args.work_dir is not None:
        # 如果命令行参数中指定了 work_dir,则更新配置文件中的 work_dir
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # 如果配置文件中的 work_dir 为 None,则根据配置文件名设置默认的 work_dir
        if args.config.startswith('projects/'):
            config = args.config[len('projects/'):]
            config = config.replace('/configs/', '/')
            cfg.work_dir = osp.join('./work_dirs', osp.splitext(config)[0])
        else:
            cfg.work_dir = osp.join('./work_dirs',
                                    osp.splitext(osp.basename(args.config))[0])

    # 启用自动混合精度训练
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'

    # 确定恢复训练的优先级:resume from > auto_resume
    if args.resume == 'auto':
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume

    # 确定自定义元信息字段是否全部为小写
    is_metainfo_lower(cfg)

    # 从配置文件构建 runner
    # 如果配置中没有指定 'runner_type'
    if 'runner_type' not in cfg:
        # 构建默认的运行器
        runner = Runner.from_cfg(cfg)
    else:
        # 从注册表中构建定制的运行器
        # 如果配置中设置了 'runner_type'
        runner = RUNNERS.build(cfg)

    # 开始训练
    runner.train()
# 如果当前脚本被直接执行,则调用主函数
if __name__ == '__main__':
    main()

.\YOLO-World\yolo_world\datasets\mm_dataset.py

# 导入所需的模块和类
import copy
import json
import logging
from typing import Callable, List, Union

from mmengine.logging import print_log
from mmengine.dataset.base_dataset import (
        BaseDataset, Compose, force_full_init)
from mmyolo.registry import DATASETS

# 注册MultiModalDataset类到DATASETS
@DATASETS.register_module()
class MultiModalDataset:
    """Multi-modal dataset."""

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # 初始化dataset属性
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset must be a dict or a BaseDataset, '
                f'but got {dataset}')

        # 加载类别文本文件
        if class_text_path is not None:
            self.class_texts = json.load(open(class_text_path, 'r'))
            # ori_classes = self.dataset.metainfo['classes']
            # assert len(ori_classes) == len(self.class_texts), \
            #     ('The number of classes in the dataset and the class text'
            #      'file must be the same.')
        else:
            self.class_texts = None

        # 设置测试模式
        self.test_mode = test_mode
        # 获取数据集的元信息
        self._metainfo = self.dataset.metainfo
        # 初始化数据处理pipeline
        self.pipeline = Compose(pipeline)

        # 标记是否已完全初始化
        self._fully_initialized = False
        # 如果不是延迟初始化,则进行完全初始化
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        # 返回元信息的深拷贝
        return copy.deepcopy(self._metainfo)

    def full_init(self) -> None:
        """``full_init`` dataset."""
        # 如果已经完全初始化,则直接返回
        if self._fully_initialized:
            return

        # 对数据集进行完全初始化
        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    # 根据索引获取数据信息,返回一个字典
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        # 通过数据集对象获取指定索引的数据信息
        data_info = self.dataset.get_data_info(idx)
        # 如果类别文本不为空,则将其添加到数据信息字典中
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        return data_info

    # 根据索引获取数据
    def __getitem__(self, idx):
        # 如果数据集未完全初始化,则打印警告信息并手动调用`full_init`方法以加快速度
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        # 获取数据信息
        data_info = self.get_data_info(idx)

        # 如果数据集具有'test_mode'属性且不为测试模式,则将数据集信息添加到数据信息字典中
        if hasattr(self.dataset, 'test_mode') and not self.dataset.test_mode:
            data_info['dataset'] = self
        # 如果不是测试模式,则将数据集信息添加到数据信息字典中
        elif not self.test_mode:
            data_info['dataset'] = self
        # 返回经过管道处理后的数据信息
        return self.pipeline(data_info)

    # 返回数据集的长度
    @force_full_init
    def __len__(self) -> int:
        return self._ori_len
# 注册 MultiModalMixedDataset 类到 DATASETS 模块
@DATASETS.register_module()
class MultiModalMixedDataset(MultiModalDataset):
    """Multi-modal Mixed dataset.
    mix "detection dataset" and "caption dataset"
    Args:
        dataset_type (str): dataset type, 'detection' or 'caption'
    """

    # 初始化方法,接受多种参数,包括 dataset、class_text_path、dataset_type、test_mode、pipeline 和 lazy_init
    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 dataset_type: str = 'detection',
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # 设置 dataset_type 属性
        self.dataset_type = dataset_type
        # 调用父类的初始化方法
        super().__init__(dataset,
                         class_text_path,
                         test_mode,
                         pipeline,
                         lazy_init)

    # 强制完全初始化装饰器,用于 get_data_info 方法
    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        # 调用 dataset 的 get_data_info 方法获取数据信息
        data_info = self.dataset.get_data_info(idx)
        # 如果 class_texts 不为空,则更新 data_info 中的 'texts' 字段
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        # 根据 dataset_type 设置 data_info 中的 'is_detection' 字段
        data_info['is_detection'] = 1 \
            if self.dataset_type == 'detection' else 0
        return data_info

.\YOLO-World\yolo_world\datasets\transformers\mm_mix_img_transforms.py

# 导入必要的库和模块
import collections
import copy
from abc import ABCMeta, abstractmethod
from typing import Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmdet.structures.bbox import autocast_box_type
from mmengine.dataset import BaseDataset
from mmengine.dataset.base_dataset import Compose
from numpy import random
from mmyolo.registry import TRANSFORMS

# 定义一个抽象基类,用于多模态多图像混合变换
class BaseMultiModalMixImageTransform(BaseTransform, metaclass=ABCMeta):
    """A Base Transform of Multimodal multiple images mixed.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup.

    Cached mosaic transform will random select images from the cache
    and combine them into one output image if use_cached is True.

    Args:
        pre_transform(Sequence[str]): Sequence of transform object or
            config dict to be composed. Defaults to None.
        prob(float): The transformation probability. Defaults to 1.0.
        use_cached (bool): Whether to use cache. Defaults to False.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 40.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch`, but results is still None, then the
            iteration is terminated and raise the error. Defaults to 15.
    """
    # 初始化函数,设置各种参数
    def __init__(self,
                 pre_transform: Optional[Sequence[str]] = None,  # 预处理转换序列的可选参数
                 prob: float = 1.0,  # 概率参数,默认为1.0
                 use_cached: bool = False,  # 是否使用缓存的布尔值,默认为False
                 max_cached_images: int = 40,  # 最大缓存图像数量,默认为40
                 random_pop: bool = True,  # 是否随机弹出的布尔值,默认为True
                 max_refetch: int = 15):  # 最大重新获取次数,默认为15
    
        # 设置最大重新获取次数
        self.max_refetch = max_refetch
        # 设置概率参数
        self.prob = prob
    
        # 设置是否使用缓存的布尔值
        self.use_cached = use_cached
        # 设置最大缓存图像数量
        self.max_cached_images = max_cached_images
        # 设置是否随机弹出的布尔值
        self.random_pop = random_pop
        # 初始化结果缓存列表
        self.results_cache = []
    
        # 如果预处理转换序列为None,则将预处理转换设置为None,否则使用Compose函数创建预处理转换
        if pre_transform is None:
            self.pre_transform = None
        else:
            self.pre_transform = Compose(pre_transform)
    
    @abstractmethod
    def get_indexes(self, dataset: Union[BaseDataset,
                                         list]) -> Union[list, int]:
        """Call function to collect indexes.
    
        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.
    
        Returns:
            list or int: indexes.
        """
        pass
    
    @abstractmethod
    def mix_img_transform(self, results: dict) -> dict:
        """Mixed image data transformation.
    
        Args:
            results (dict): Result dict.
    
        Returns:
            results (dict): Updated result dict.
        """
        pass
    # 更新标签文本内容
    def _update_label_text(self, results: dict) -> dict:
        """Update label text."""
        # 如果结果中没有文本信息,则直接返回结果
        if 'texts' not in results:
            return results

        # 将所有文本信息合并并去重
        mix_texts = sum(
            [results['texts']] +
            [x['texts'] for x in results['mix_results']], [])
        mix_texts = list({tuple(x) for x in mix_texts})
        # 创建文本到索引的映射
        text2id = {text: i for i, text in enumerate(mix_texts)}

        # 更新结果中的标签文本
        for res in [results] + results['mix_results']:
            for i, label in enumerate(res['gt_bboxes_labels']):
                text = res['texts'][label]
                updated_id = text2id[tuple(text)]
                res['gt_bboxes_labels'][i] = updated_id
            res['texts'] = mix_texts
        # 返回更新后的结果
        return results

    # 装饰器,用于自动转换框类型
    @autocast_box_type()
# 注册多模态马赛克数据增强类到TRANSFORMS中
@TRANSFORMS.register_module()
class MultiModalMosaic(BaseMultiModalMixImageTransform):
    """Mosaic augmentation.

    给定4个图像,马赛克变换将它们合并成一个输出图像。输出图像由每个子图像的部分组成。

    .. code:: text

                        马赛克变换
                           center_x
                +------------------------------+
                |       pad        |           |
                |      +-----------+    pad    |
                |      |           |           |
                |      |  image1   +-----------+
                |      |           |           |
                |      |           |   image2  |
     center_y   |----+-+-----------+-----------+
                |    |   cropped   |           |
                |pad |   image3    |   image4  |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

     马赛克变换步骤如下:

         1. 选择4个图像的交叉点作为马赛克中心
         2. 根据索引获取左上角图像,并从自定义数据集中随机采样另外3个图像
         3. 如果图像大于马赛克块,则将子图像裁剪

    必需键:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (可选)
    - gt_bboxes_labels (np.int64) (可选)
    - gt_ignore_flags (bool) (可选)
    - mix_results (List[dict])

    修改后的键:

    - img
    - img_shape
    - gt_bboxes (可选)
    - gt_bboxes_labels (可选)
    - gt_ignore_flags (可选)
    Args:
        img_scale (Sequence[int]): Image size after mosaic pipeline of single
            image. The shape order should be (width, height).
            Defaults to (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. Defaults to (0.5, 1.5).
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. In some dataset like MOT17, the gt bboxes
            are allowed to cross the border of images. Therefore, we don't
            need to clip the gt bboxes in these cases. Defaults to True.
        pad_val (int): Pad value. Defaults to 114.
        pre_transform(Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
        use_cached (bool): Whether to use cache. Defaults to False.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 40.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch`, but results is still None, then the
            iteration is terminated and raise the error. Defaults to 15.
    """
    # 初始化函数,设置数据增强的参数
    def __init__(self,
                 img_scale: Tuple[int, int] = (640, 640),  # 设置图像缩放的大小,默认为(640, 640)
                 center_ratio_range: Tuple[float, float] = (0.5, 1.5),  # 设置中心比例范围,默认为(0.5, 1.5)
                 bbox_clip_border: bool = True,  # 是否裁剪边界框,默认为True
                 pad_val: float = 114.0,  # 设置填充值,默认为114.0
                 pre_transform: Sequence[dict] = None,  # 预处理变换序列,默认为None
                 prob: float = 1.0,  # 数据增强的概率,默认为1.0
                 use_cached: bool = False,  # 是否使用缓存,默认为False
                 max_cached_images: int = 40,  # 最大缓存图像数量,默认为40
                 random_pop: bool = True,  # 是否随机弹出,默认为True
                 max_refetch: int = 15):  # 最大重新获取次数,默认为15
        assert isinstance(img_scale, tuple)  # 断言img_scale是元组类型
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \  # 断言概率在[0,1]范围内
                                 f'got {prob}.'
        if use_cached:
            assert max_cached_images >= 4, 'The length of cache must >= 4, ' \  # 断言缓存长度大于等于4
                                           f'but got {max_cached_images}.'
    
        # 调用父类的初始化函数
        super().__init__(
            pre_transform=pre_transform,
            prob=prob,
            use_cached=use_cached,
            max_cached_images=max_cached_images,
            random_pop=random_pop,
            max_refetch=max_refetch)
    
        # 设置参数值
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.bbox_clip_border = bbox_clip_border
        self.pad_val = pad_val
    
    # 获取数据集的索引
    def get_indexes(self, dataset: Union[BaseDataset, list]) -> list:
        """Call function to collect indexes.
    
        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.
    
        Returns:
            list: indexes.
        """
        # 随机生成3个索引
        indexes = [random.randint(0, len(dataset)) for _ in range(3)]
        return indexes
    
    # 返回对象的字符串表示形式
    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '  # 添加图像缩放参数
        repr_str += f'center_ratio_range={self.center_ratio_range}, '  # 添加中心比例范围参数
        repr_str += f'pad_val={self.pad_val}, '  # 添加填充值参数
        repr_str += f'prob={self.prob})'  # 添加概率参数
        return repr_str
# 注册 MultiModalMosaic9 类到 TRANSFORMS 模块
@TRANSFORMS.register_module()
class MultiModalMosaic9(BaseMultiModalMixImageTransform):
    """Mosaic9 augmentation.

    给定9个图像,mosaic 变换将它们合并成一个输出图像。输出图像由每个子图像的部分组成。

    .. code:: text

                +-------------------------------+------------+
                | pad           |      pad      |            |
                |    +----------+               |            |
                |    |          +---------------+  top_right |
                |    |          |      top      |   image2   |
                |    | top_left |     image1    |            |
                |    |  image8  o--------+------+--------+---+
                |    |          |        |               |   |
                +----+----------+        |     right     |pad|
                |               | center |     image3    |   |
                |     left      | image0 +---------------+---|
                |    image7     |        |               |   |
            +---+-----------+---+--------+               |   |
            |   |  cropped  |            |  bottom_right |pad|
            |   |bottom_left|            |    image4     |   |
            |   |  image6   |   bottom   |               |   |
            +---|-----------+   image5   +---------------+---|
                |    pad    |            |        pad        |
                +-----------+------------+-------------------+

     Mosaic 变换步骤如下:

         1. 根据索引获取中心图像,并从自定义数据集中随机采样另外8个图像。
         2. 在 Mosaic 后随机偏移图像

    需要的键:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (可选)
    - gt_bboxes_labels (np.int64) (可选)
    - gt_ignore_flags (bool) (可选)
    - mix_results (List[dict])

    修改的键:

    - img
    - img_shape
    - gt_bboxes (可选)
    # gt_bboxes_labels (可选):真实边界框标签,用于指定对象的类别
    # gt_ignore_flags (可选):真实边界框忽略标志,用于指定是否忽略某些对象
    
    Args:
        img_scale (Sequence[int]): 单个图像经过马赛克管道后的图像大小。形状顺序应为(宽度,高度)。
            默认为(640,640)。
        bbox_clip_border (bool, optional): 是否裁剪超出图像边界的对象。在某些数据集中,如MOT17,允许gt边界框越过图像边界。
            因此,在这些情况下,我们不需要裁剪gt边界框。默认为True。
        pad_val (int): 填充值。默认为114。
        pre_transform(Sequence[dict]): 要组合的转换对象或配置字典序列。
        prob (float): 应用此转换的概率。默认为1.0。
        use_cached (bool): 是否使用缓存。默认为False。
        max_cached_images (int): 缓存的最大长度。缓存越大,此转换的随机性越强。一般来说,为每个图像提供5个缓存足以保证随机性。默认为50。
        random_pop (bool): 当缓存已满时是否随机弹出一个结果。如果设置为False,则使用FIFO弹出方法。默认为True。
        max_refetch (int): 从管道获取有效结果的最大重试次数。如果迭代次数大于`max_refetch`,但结果仍为None,则终止迭代并引发错误。默认为15。
    # 初始化函数,设置默认参数和属性
    def __init__(self,
                 img_scale: Tuple[int, int] = (640, 640),  # 设置图像缩放尺寸,默认为(640, 640)
                 bbox_clip_border: bool = True,  # 是否裁剪边界框,默认为True
                 pad_val: Union[float, int] = 114.0,  # 设置填充值,默认为114.0
                 pre_transform: Sequence[dict] = None,  # 预处理变换序列,默认为None
                 prob: float = 1.0,  # 概率值,默认为1.0
                 use_cached: bool = False,  # 是否使用缓存,默认为False
                 max_cached_images: int = 50,  # 最大缓存图像数量,默认为50
                 random_pop: bool = True,  # 是否随机弹出,默认为True
                 max_refetch: int = 15):  # 最大重新获取次数,默认为15
        assert isinstance(img_scale, tuple)  # 断言img_scale为元组类型
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \  # 断言概率值在[0,1]范围内
                                 f'got {prob}.'
        if use_cached:
            assert max_cached_images >= 9, 'The length of cache must >= 9, ' \  # 如果使用缓存,断言最大缓存图像数量大于等于9
                                           f'but got {max_cached_images}.'
    
        super().__init__(  # 调用父类的初始化函数
            pre_transform=pre_transform,
            prob=prob,
            use_cached=use_cached,
            max_cached_images=max_cached_images,
            random_pop=random_pop,
            max_refetch=max_refetch)
    
        self.img_scale = img_scale  # 设置img_scale属性
        self.bbox_clip_border = bbox_clip_border  # 设置bbox_clip_border属性
        self.pad_val = pad_val  # 设置pad_val属性
    
        # 中间变量
        self._current_img_shape = [0, 0]  # 当前图像形状
        self._center_img_shape = [0, 0]  # 中心图像形状
        self._previous_img_shape = [0, 0]  # 上一个图像形状
    
    # 获取索引函数,返回一个包含8个随机索引的列表
    def get_indexes(self, dataset: Union[BaseDataset, list]) -> list:
        """Call function to collect indexes.
    
        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.
    
        Returns:
            list: indexes.
        """
        indexes = [random.randint(0, len(dataset)) for _ in range(8)]  # 生成8个随机索引
        return indexes
    
    # 返回对象的字符串表示形式
    def __repr__(self) -> str:
        repr_str = self.__class__.__name__  # 获取类名
        repr_str += f'(img_scale={self.img_scale}, '  # 添加img_scale属性
        repr_str += f'pad_val={self.pad_val}, '  # 添加pad_val属性
        repr_str += f'prob={self.prob})'  # 添加prob属性
        return repr_str  # 返回字符串表示形式
# 注册 YOLOv5MultiModalMixUp 类到 TRANSFORMS 模块中
@TRANSFORMS.register_module()
class YOLOv5MultiModalMixUp(BaseMultiModalMixImageTransform):
    """MixUp data augmentation for YOLOv5.

    .. code:: text

    The mixup transform steps are as follows:

        1. Another random image is picked by dataset.
        2. Randomly obtain the fusion ratio from the beta distribution,
            then fuse the target
        of the original image and mixup image through this ratio.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - mix_results (List[dict])


    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)


    Args:
        alpha (float): parameter of beta distribution to get mixup ratio.
            Defaults to 32.
        beta (float):  parameter of beta distribution to get mixup ratio.
            Defaults to 32.
        pre_transform (Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
        use_cached (bool): Whether to use cache. Defaults to False.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 20.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        max_refetch (int): The maximum number of iterations. If the number of
            iterations is greater than `max_refetch`, but gt_bbox is still
            empty, then the iteration is terminated. Defaults to 15.
    """
    # 初始化函数,设置默认参数值
    def __init__(self,
                 alpha: float = 32.0,
                 beta: float = 32.0,
                 pre_transform: Sequence[dict] = None,
                 prob: float = 1.0,
                 use_cached: bool = False,
                 max_cached_images: int = 20,
                 random_pop: bool = True,
                 max_refetch: int = 15):
        # 如果使用缓存,确保缓存长度大于等于2
        if use_cached:
            assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
                                           f'but got {max_cached_images}.'
        # 调用父类的初始化函数
        super().__init__(
            pre_transform=pre_transform,
            prob=prob,
            use_cached=use_cached,
            max_cached_images=max_cached_images,
            random_pop=random_pop,
            max_refetch=max_refetch)
        # 设置 alpha 和 beta 参数
        self.alpha = alpha
        self.beta = beta

    # 获取索引函数,返回随机索引
    def get_indexes(self, dataset: Union[BaseDataset, list]) -> int:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.

        Returns:
            int: indexes.
        """
        # 返回一个随机索引,范围为 [0, 数据集长度)
        return random.randint(0, len(dataset))
    def mix_img_transform(self, results: dict) -> dict:
        """YOLOv5 MixUp transform function.

        Args:
            results (dict): Result dict

        Returns:
            results (dict): Updated result dict.
        """
        # 确保结果字典中包含'mix_results'键
        assert 'mix_results' in results

        # 从'mix_results'中获取第一个结果字典
        retrieve_results = results['mix_results'][0]
        # 获取原始图像和混合图像
        retrieve_img = retrieve_results['img']
        ori_img = results['img']
        # 确保原始图像和混合图像的形状相同
        assert ori_img.shape == retrieve_img.shape

        # 从 beta 分布中随机获取融合比例,大约为0.5
        ratio = np.random.beta(self.alpha, self.beta)
        mixup_img = (ori_img * ratio + retrieve_img * (1 - ratio))

        # 获取混合图像的 ground truth 边界框、标签和忽略标志
        retrieve_gt_bboxes = retrieve_results['gt_bboxes']
        retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
        retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']

        # 合并原始图像和混合图像的 ground truth 边界框、标签和忽略标志
        mixup_gt_bboxes = retrieve_gt_bboxes.cat(
            (results['gt_bboxes'], retrieve_gt_bboxes), dim=0)
        mixup_gt_bboxes_labels = np.concatenate(
            (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
        mixup_gt_ignore_flags = np.concatenate(
            (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
        
        # 如果结果字典中包含'gt_masks'键
        if 'gt_masks' in results:
            # 确保'retrieve_results'中也包含'gt_masks'键
            assert 'gt_masks' in retrieve_results
            # 合并原始图像和混合图像的 ground truth masks
            mixup_gt_masks = results['gt_masks'].cat(
                [results['gt_masks'], retrieve_results['gt_masks']])
            results['gt_masks'] = mixup_gt_masks

        # 更新结果字典中的图像、图像形状、ground truth 边界框、标签和忽略标志
        results['img'] = mixup_img.astype(np.uint8)
        results['img_shape'] = mixup_img.shape
        results['gt_bboxes'] = mixup_gt_bboxes
        results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
        results['gt_ignore_flags'] = mixup_gt_ignore_flags

        return results
# 注册 YOLOXMultiModalMixUp 类到 TRANSFORMS 模块中
@TRANSFORMS.register_module()
class YOLOXMultiModalMixUp(BaseMultiModalMixImageTransform):
    """MixUp data augmentation for YOLOX.

    .. code:: text

                         mixup transform
                +---------------+--------------+
                | mixup image   |              |
                |      +--------|--------+     |
                |      |        |        |     |
                +---------------+        |     |
                |      |                 |     |
                |      |      image      |     |
                |      |                 |     |
                |      |                 |     |
                |      +-----------------+     |
                |             pad              |
                +------------------------------+

    The mixup transform steps are as follows:

        1. Another random image is picked by dataset and embedded in
           the top left patch(after padding and resizing)
        2. The target of mixup transform is the weighted average of mixup
           image and origin image.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - mix_results (List[dict])


    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)
    Args:
        img_scale (Sequence[int]): Image output size after mixup pipeline.
            The shape order should be (width, height). Defaults to (640, 640).
        ratio_range (Sequence[float]): Scale ratio of mixup image.
            Defaults to (0.5, 1.5).
        flip_ratio (float): Horizontal flip ratio of mixup image.
            Defaults to 0.5.
        pad_val (int): Pad value. Defaults to 114.
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. In some dataset like MOT17, the gt bboxes
            are allowed to cross the border of images. Therefore, we don't
            need to clip the gt bboxes in these cases. Defaults to True.
        pre_transform(Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
        use_cached (bool): Whether to use cache. Defaults to False.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 20.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        max_refetch (int): The maximum number of iterations. If the number of
            iterations is greater than `max_refetch`, but gt_bbox is still
            empty, then the iteration is terminated. Defaults to 15.
    """
    # 初始化函数,设置默认参数值
    def __init__(self,
                 img_scale: Tuple[int, int] = (640, 640),
                 ratio_range: Tuple[float, float] = (0.5, 1.5),
                 flip_ratio: float = 0.5,
                 pad_val: float = 114.0,
                 bbox_clip_border: bool = True,
                 pre_transform: Sequence[dict] = None,
                 prob: float = 1.0,
                 use_cached: bool = False,
                 max_cached_images: int = 20,
                 random_pop: bool = True,
                 max_refetch: int = 15):
        # 断言img_scale是元组类型
        assert isinstance(img_scale, tuple)
        # 如果使用缓存,确保最大缓存图片数量大于等于2
        if use_cached:
            assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
                                           f'but got {max_cached_images}.'
        # 调用父类的初始化函数
        super().__init__(
            pre_transform=pre_transform,
            prob=prob,
            use_cached=use_cached,
            max_cached_images=max_cached_images,
            random_pop=random_pop,
            max_refetch=max_refetch)
        # 设置各个参数的值
        self.img_scale = img_scale
        self.ratio_range = ratio_range
        self.flip_ratio = flip_ratio
        self.pad_val = pad_val
        self.bbox_clip_border = bbox_clip_border

    # 获取索引的函数
    def get_indexes(self, dataset: Union[BaseDataset, list]) -> int:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.

        Returns:
            int: indexes.
        """
        # 返回一个随机索引
        return random.randint(0, len(dataset))

    # 返回对象的字符串表示形式
    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'flip_ratio={self.flip_ratio}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'max_refetch={self.max_refetch}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
03-11 00:36