# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from argparse import ArgumentParser

import cv2
import mmcv

from mmtrack.apis import inference_sot, init_model


def main():
    parser = ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('--input', help='input video file')
    parser.add_argument('--output', help='output video file (mp4 format)')
    parser.add_argument('--checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show visualizations.')
    parser.add_argument(
        '--color', default=(0, 255, 0), help='Color of tracked bbox lines.')
    parser.add_argument(
        '--thickness', default=3, type=int, help='Thickness of bbox lines.')
    parser.add_argument('--fps', type=int, help='FPS of the output video')
    parser.add_argument('--gt_bbox_file', help='The path of gt_bbox file')
    args = parser.parse_args()

    # load images
    if osp.isdir(args.input):
        imgs = sorted(
            filter(lambda x: x.endswith(('.jpg', '.png', '.jpeg')),
                   os.listdir(args.input)),
            key=lambda x: int(x.split('.')[0]))
        IN_VIDEO = False
    else:
        imgs = mmcv.VideoReader(args.input)
        IN_VIDEO = True

    OUT_VIDEO = False
    # define output
    if args.output is not None:
        if args.output.endswith('.mp4'):
            OUT_VIDEO = True
            out_dir = tempfile.TemporaryDirectory()
            out_path = out_dir.name
            _out = args.output.rsplit(os.sep, 1)
            if len(_out) > 1:
                os.makedirs(_out[0], exist_ok=True)
        else:
            out_path = args.output
            os.makedirs(out_path, exist_ok=True)
    fps = args.fps
    if args.show or OUT_VIDEO:
        if fps is None and IN_VIDEO:
            fps = imgs.fps
        if not fps:
            raise ValueError('Please set the FPS for the output video.')
        fps = int(fps)

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)

    prog_bar = mmcv.ProgressBar(len(imgs))
    # test and show/save the images
    for i, img in enumerate(imgs):
        if isinstance(img, str):
            img_path = osp.join(args.input, img)
            img = mmcv.imread(img_path)
        if i == 0:
            if args.gt_bbox_file is not None:
                bboxes = mmcv.list_from_file(args.gt_bbox_file)
                init_bbox = list(map(float, bboxes[0].split(',')))
            else:
                init_bbox = list(cv2.selectROI(args.input, img, False, False))

            # convert (x1, y1, w, h) to (x1, y1, x2, y2)
            init_bbox[2] += init_bbox[0]
            init_bbox[3] += init_bbox[1]

        result = inference_sot(model, img, init_bbox, frame_id=i)
        if args.output is not None:
            if IN_VIDEO or OUT_VIDEO:
                out_file = osp.join(out_path, f'{i:06d}.jpg')
            else:
                out_file = osp.join(out_path, img_path.rsplit(os.sep, 1)[-1])
        else:
            out_file = None
        model.show_result(
            img,
            result,
            show=args.show,
            wait_time=int(1000. / fps) if fps else 0,
            out_file=out_file,
            thickness=args.thickness)
        prog_bar.update()

    if args.output and OUT_VIDEO:
        print(
            f'\nmaking the output video at {args.output} with a FPS of {fps}')
        mmcv.frames2video(out_path, args.output, fps=fps, fourcc='mp4v')
        out_dir.cleanup()


if __name__ == '__main__':
    main()

全部源码如下,下面开始解析。

代码分析

这个demo_sot.py是在mmtracking项目文件夹下面的demo文件夹下的演示代码

引用库

import os
import os.path as osp
import tempfile
from argparse import ArgumentParser

import cv2
import mmcv

from mmtrack.apis import inference_sot, init_model

配置参数

parser = ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('--input', help='input video file')
    parser.add_argument('--output', help='output video file (mp4 format)')
    parser.add_argument('--checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show visualizations.')
    parser.add_argument(
        '--color', default=(0, 255, 0), help='Color of tracked bbox lines.')
    parser.add_argument(
        '--thickness', default=3, type=int, help='Thickness of bbox lines.')
    parser.add_argument('--fps', type=int, help='FPS of the output video')
    parser.add_argument('--gt_bbox_file', help='The path of gt_bbox file')
    args = parser.parse_args()

使用 argparse 库解析命令行参数。这些参数包括:

config:配置文件
–input:输入视频文件或图像目录
–output:输出视频文件(mp4 格式)
–checkpoint:检查点文件
–device:用于推理的设备(默认为 ‘cuda:0’)
–show:是否显示可视化结果
–color:跟踪边界框线条的颜色
–thickness:边界框线条的粗细
–fps:输出视频的帧率
–gt_bbox_file:ground truth 边界框文件的路径

加载输入

# load images
    if osp.isdir(args.input):
        imgs = sorted(
            filter(lambda x: x.endswith(('.jpg', '.png', '.jpeg')),
                   os.listdir(args.input)),
            key=lambda x: int(x.split('.')[0]))
        IN_VIDEO = False
    else:
        imgs = mmcv.VideoReader(args.input)
        IN_VIDEO = True

    OUT_VIDEO = False

首先判断输入的参数args.input是否是一个目录,
如果是一个目录,os.listdir(args.input)就列出该目录下面所有的文件和子目录,然后使用filter(lambda x: x.endswith((‘.jpg’, ‘.png’, ‘.jpeg’)), …)过滤出以.jpg, .png, 或 .jpeg 结尾的文件
lambda x:A,y表示对y使用以x:A的函数,这里就是对于os.listdir(args.input)列出的所有输入目录下面的子文件,使用ilter(lambda x: x.endswith((‘.jpg’, ‘.png’, ‘.jpeg’)),
sorted(…, key=lambda x: int(x.split(‘.’)[0])): 对过滤后的图像文件名进行排序。排序的依据是文件名(不包括扩展名)转化为整数后的值。这意味着如果文件名为 “img10.jpg”, “img2.jpg”, “img1.jpg”,它们会被正确地按数字顺序排序。
IN_VIDEO = False: 设置 IN_VIDEO 为 False,表示输入的不是视频。

如果不是一个目录,即一个视频文件,

mmcv.VideoReader(args.input): 使用 mmcv(一个常用于计算机视觉任务的库)的 VideoReader 函数读取视频文件。这个函数会返回一个迭代器,每次迭代都会返回一个视频帧。
IN_VIDEO = True: 设置 IN_VIDEO 为 True,表示输入的是视频。

OUT_VIDEO = False: 设置 OUT_VIDEO 为 False。这行代码表示在后续代码中,程序有可能会输出或保存为一个视频文件,但目前还没有设置要输出视频。

配置输出

# define output
    if args.output is not None:
        if args.output.endswith('.mp4'):
            OUT_VIDEO = True
            out_dir = tempfile.TemporaryDirectory()
            out_path = out_dir.name
            _out = args.output.rsplit(os.sep, 1)
            if len(_out) > 1:
                os.makedirs(_out[0], exist_ok=True)
        else:
            out_path = args.output
            os.makedirs(out_path, exist_ok=True)
    fps = args.fps
    if args.show or OUT_VIDEO:
        if fps is None and IN_VIDEO:
            fps = imgs.fps
        if not fps:
            raise ValueError('Please set the FPS for the output video.')
        fps = int(fps)

if args.output is not None:如果参数设置的输出不为空
if args.output.endwith(‘.mp4’) 如果用户希望输出一个MP4视频文件
OUT_VIDEO = True 设置全局变量为True,表示要输出视频
out_dir = tempfile.TemporaryDirectory(): 创建一个临时目录,并将其路径存储在out_dir变量中。
out_path = out_dir.name: 获取临时目录的路径,并将其存储在out_path变量中。
_out = args.output.rsplit(os.sep, 1): 使用rsplit方法将args.output字符串从右边开始分割,分割符为os.sep(这通常是文件路径中的分隔符,如/或\),并且只分割一次。结果存储在_out变量中。

代码获取args.fps的值,即每秒帧数,用于后续的视频输出。
如果args.show或OUT_VIDEO为True(即用户希望显示或输出视频),代码会检查fps的值。
如果fps为None且IN_VIDEO为True(似乎是一个未在代码段中定义的变量,可能表示输入也是一个视频),则使用输入视频的fps。
如果fps仍然为None或False,则抛出一个ValueError,要求用户设置输出视频的FPS。
最后,将fps转换为整数。

初始化模型

# build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    prog_bar = mmcv.ProgressBar(len(imgs))

调用了一个 init_model 的函数,并将结果赋值给变量 model。这个函数来自
mmtracki.apis.inference.py

args.config: 是模型的配置文件路径,通常包含模型的结构、优化器设置、训练参数等。

args.checkpoint: 是模型的检查点文件路径,通常包含模型的权重或其他训练过程中的状态。

device=args.device: 指定模型应该在哪个设备上运行,例如 CPU 或 GPU。args.device 可能是一个字符串,如 “cpu” 或 “cuda:0”。

我们首先来看这个init_model函数

def init_model(config,
               checkpoint=None,
               device='cuda:0',
               cfg_options=None,
               verbose_init_params=False):
    """Initialize a model from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. Default as None.
        cfg_options (dict, optional): Options to override some settings in
            the used config. Default to None.
        verbose_init_params (bool, optional): Whether to print the information
            of initialized parameters to the console. Default to False.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    if 'detector' in config.model:
        config.model.detector.pretrained = None
    model = build_model(config.model)

    if not verbose_init_params:
        # Creating a temporary file to record the information of initialized
        # parameters. If not, the information of initialized parameters will be
        # printed to the console because of the call of
        # `mmcv.runner.BaseModule.init_weights`.
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        file_handler = logging.FileHandler(tmp_file.name, mode='w')
        model.logger.addHandler(file_handler)
        # We need call `init_weights()` to load pretained weights in MOT
        # task.
        model.init_weights()
        file_handler.close()
        model.logger.removeHandler(file_handler)
        tmp_file.close()
        os.remove(tmp_file.name)
    else:
        # We need call `init_weights()` to load pretained weights in MOT task.
        model.init_weights()

    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
    if not hasattr(model, 'CLASSES'):
        if hasattr(model, 'detector') and hasattr(model.detector, 'CLASSES'):
            model.CLASSES = model.detector.CLASSES
        else:
            print("Warning: The model doesn't have classes")
            model.CLASSES = None
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model

函数的输入

函数接受5个输入:
①config:这个是这是配置文件的路径或配置对象。是一个str字符串或者 mmcv.Config。mmcv.Config是一个对象,用于管理和解析配置文件。如果传入的是字符串,则它应该是指向配置文件的路径。

②checkpoint (str, optional): 检查点(checkpoint)路径,这是一个可选参数。检查点通常包含模型的权重和可能的其他状态信息。如果在初始化模型时要加载预训练的权重,则会用到这个参数。

③device (str, optional): 设备字符串,指定模型应该在哪个设备上运行。默认值是’cuda:0’,意味着模型将在第一个GPU上运行。如果要在CPU上运行,可以传入’cpu’。

④cfg_options (dict, optional): 一个字典,用于覆盖配置文件中的某些设置。这是一个可选参数,默认值为None。

⑤verbose_init_params (bool, optional): 一个布尔值,决定是否将初始化的参数信息打印到控制台。默认值为False。

检查config

接下来

if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

检查 config 的类型:
如果 config 是一个字符串(str),那么它很可能是一个配置文件的路径。这种情况下,代码使用 mmcv.Config.fromfile(config) 来从这个路径加载配置文件,并将其解析为 mmcv.Config 对象。
如果 config 不是一个 mmcv.Config 对象,那么代码会抛出一个 TypeError,提示 config 必须是一个文件名或 Config 对象。

if cfg_options is not None:
        config.merge_from_dict(cfg_options)

合并配置选项:
如果 cfg_options 不为 None,代码会将其作为一个字典合并到 config 中。这意味着 cfg_options 中的任何设置都会覆盖 config 中的相应设置。

if 'detector' in config.model:
        config.model.detector.pretrained = None

处理预训练模型:

如果 config.model 包含一个 ‘detector’ 键,代码会将其 pretrained 属性设置为 None。这通常意味着在构建模型时不会使用预训练的权重。

model = build_model(config.model)

构建模型:

最后,代码使用 build_model(config.model) 来根据 config.model 中的配置构建一个模型。这里假设 build_model 是一个已经定义好的函数,它可以根据提供的配置信息来创建和返回一个模型对象。

模型初始化信息

if not verbose_init_params:
        # Creating a temporary file to record the information of initialized
        # parameters. If not, the information of initialized parameters will be
        # printed to the console because of the call of
        # `mmcv.runner.BaseModule.init_weights`.
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        file_handler = logging.FileHandler(tmp_file.name, mode='w')
        model.logger.addHandler(file_handler)
        # We need call `init_weights()` to load pretained weights in MOT
        # task.
        model.init_weights()
        file_handler.close()
        model.logger.removeHandler(file_handler)
        tmp_file.close()
        os.remove(tmp_file.name)
    else:
        # We need call `init_weights()` to load pretained weights in MOT task.
        model.init_weights()

判断verbose_init_params的值:如果verbose_init_params为False,则执行以下的代码块。
创建临时文件:使用tempfile.NamedTemporaryFile创建一个临时文件,并且设置delete=False,这意味着当文件关闭后,它不会被自动删除。
设置文件处理器:为model.logger添加一个文件处理器,这样model.logger输出的日志信息就会被写入到之前创建的临时文件中。
初始化模型权重:无论verbose_init_params的值如何,都会执行此行代码来加载预训练的权重。
关闭文件处理器和临时文件
删除临时文件
如果verbose_init_params为True:如果verbose_init_params为True,则不会创建临时文件,而是直接调用model.init_weights()来加载预训练的权重。在这种情况下,由于mmcv.runner.BaseModule.init_weights的调用,初始化的参数信息将被直接打印到控制台。

检查checkpoint

if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
    if not hasattr(model, 'CLASSES'):
        if hasattr(model, 'detector') and hasattr(model.detector, 'CLASSES'):
            model.CLASSES = model.detector.CLASSES
        else:
            print("Warning: The model doesn't have classes")
            model.CLASSES = None

检查检查点是否为空:检查checkpoint变量是否不为None。如果checkpoint有一个有效的值(即不是None),则进入该if语句块。
加载检查点:如果checkpoint不为空,这行代码会调用load_checkpoint函数,尝试加载检查点。load_checkpoint函数接受三个参数:model(模型对象)、checkpoint(检查点路径或对象)和map_location=‘cpu’(指定模型应加载到CPU上)。
从检查点中提取类别信息:这部分代码首先检查checkpoint字典中是否有一个meta键,并且meta字典中是否有一个CLASSES键。如果两者都存在,那么它会将CLASSES信息从检查点中提取出来,并赋值给模型的CLASSES属性。
检查模型是否有CLASSES属性:这行代码检查模型对象model是否没有CLASSES属性。如果没有,它会进入该if语句块。
从模型的detector属性中提取CLASSES:这部分代码首先检查模型对象model是否有一个detector属性,并且detector属性是否有一个CLASSES属性。如果两者都存在,那么它会将CLASSES信息从detector中提取出来,并赋值给模型的CLASSES属性。
警告:模型没有类别:如果上述所有条件都不满足(即模型及其detector属性都没有CLASSES属性),则会打印一条警告消息,并将模型的CLASSES属性设置为None。

这段代码的主要目的是从检查点或模型的detector属性中加载CLASSES信息,并确保模型具有CLASSES属性。如果模型或其detector属性中没有CLASSES信息,则会发出警告并将CLASSES设置为None。

加载模型

model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model

将config对象赋值给model的cfg属性,将模型移动到指定的设备上,将模型设置为评估模式,返回了配置并移动到指定设备的模型。

测试、展示和保存图像视频

# test and show/save the images
    for i, img in enumerate(imgs):
        if isinstance(img, str):
            img_path = osp.join(args.input, img)
            img = mmcv.imread(img_path)
        if i == 0:
            if args.gt_bbox_file is not None:
                bboxes = mmcv.list_from_file(args.gt_bbox_file)
                init_bbox = list(map(float, bboxes[0].split(',')))
            else:
                init_bbox = list(cv2.selectROI(args.input, img, False, False))

            # convert (x1, y1, w, h) to (x1, y1, x2, y2)
            init_bbox[2] += init_bbox[0]
            init_bbox[3] += init_bbox[1]

        result = inference_sot(model, img, init_bbox, frame_id=i)
        if args.output is not None:
            if IN_VIDEO or OUT_VIDEO:
                out_file = osp.join(out_path, f'{i:06d}.jpg')
            else:
                out_file = osp.join(out_path, img_path.rsplit(os.sep, 1)[-1])
        else:
            out_file = None
        model.show_result(
            img,
            result,
            show=args.show,
            wait_time=int(1000. / fps) if fps else 0,
            out_file=out_file,
            thickness=args.thickness)
        prog_bar.update()

    if args.output and OUT_VIDEO:
        print(
            f'\nmaking the output video at {args.output} with a FPS of {fps}')
        mmcv.frames2video(out_path, args.output, fps=fps, fourcc='mp4v')
        out_dir.cleanup()

遍历图像序列:for i, img in enumerate(imgs): 遍历 imgs 列表(或其他可迭代对象),i 是索引,img 是每个图像或图像路径。

处理图像路径:如果 img 是一个字符串(可能是文件路径),代码将其转换为绝对路径,并使用 mmcv.imread 读取图像。

初始化边界框:
如果 args.gt_bbox_file 存在,则从文件中读取边界框坐标。
如果不存在,使用 cv2.selectROI 手动选择图像上的感兴趣区域(ROI)作为初始边界框。
将边界框从 (x1, y1, w, h) 格式转换为 (x1, y1, x2, y2) 格式。

执行模型推理:调用 inference_sot(model, img, init_bbox, frame_id=i) 函数,对图像执行目标跟踪推理,并返回结果。

处理输出结果:

根据 args.output 和其他条件确定输出文件的路径。
使用 model.show_result 显示结果,可以选择是否显示图像、设置等待时间、保存输出文件等。
更新进度条(prog_bar.update())。

生成输出视频:如果 args.output 存在且 OUT_VIDEO 为真,将输出目录中的帧转换为视频文件,并清理输出目录。

03-14 09:18