zl程序教程

您现在的位置是:首页 >  云平台

当前栏目

行人重识别02-07:fast-reid(BoT)-pytorch编程规范(fast-reid为例)4-迭代器构建,数据加载-2

识别编程数据PyTorch迭代 构建 加载 规范
2023-09-14 09:13:07 时间

以下链接是个人关于fast-reid(BoT行人重识别) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号 海量资源。 \color{blue}{ 海量资源}。 海量资源

行人重识别02-00:fast-reid(BoT)-目录-史上最新无死角讲解

极度推荐的商业级项目: \color{red}{极度推荐的商业级项目:} 极度推荐的商业级项目:这是本人落地的行为分析项目,主要包含(1.行人检测,2.行人追踪,3.行为识别三大模块):行为分析(商用级别)00-目录-史上最新无死角讲解

前言

在上篇博客中,我们留下了一个疑问,那就是在fastreid\data\build.py文件

def build_reid_train_loader(cfg):
	......
    for d in cfg.DATASETS.NAMES:
        # 根据数据集名称,创建对应数据集迭代的类,本人调试为类 fastreid.data.datasets.market1501.Market1501对象
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        train_items.extend(dataset.train)

中的 dataset.train 是如何获得的。在这篇博客中,我会为打大家解答这个疑问。

Market1501

首先,我们已经知道 dataset 为 fastreid.data.datasets.market1501.Market150 创建的对象,那么我们就找到这个Market150类:

@DATASET_REGISTRY.register()
class Market1501(ImageDataset):
    """Market1501.

    Reference:
        Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    URL: `<http://www.liangzheng.org/Project/project_reid.html>`_

    Dataset statistics:
        - identities: 1501 (+1 for background).
        - images: 12936 (train) + 3368 (query) + 15913 (gallery).
    """
    _junk_pids = [0, -1]
    dataset_dir = '' # 数据集目录,默认为空
    dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' # 数据集现在地址
    dataset_name = "market1501" # 数据集名称

    def __init__(self, root='datasets', market1501_500k=False, **kwargs):
        # self.root = osp.abspath(osp.expanduser(root))
        # 数据集根目录
        self.root = root
        # 拼接数据集路径
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        # allow alternative directory structure  # 允许替换目录结构
        self.data_dir = self.dataset_dir

        # 获得数据集目录
        data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15')
        # 判断是否为一个目录。如果是则复制给self.data_dir
        if osp.isdir(data_dir):
            self.data_dir = data_dir
        # 否则给出警告
        else:
            warnings.warn('The current data structure is deprecated. Please '
                          'put data folders such as "bounding_box_train" under '
                          '"Market-1501-v15.09.15".')

        # 训练数据的目录,“bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像
        self.train_dir = osp.join(self.data_dir, 'bounding_box_train')

        # 为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像
        self.query_dir = osp.join(self.data_dir, 'query')

        #“bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图
        # (可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)
        self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')

        # 额外画廊的数据,可以理解为为bounding_box_test添加额外的数据
        self.extra_gallery_dir = osp.join(self.data_dir, 'images')

        # 设置self.market1501_500k标志,设置了该标志位,这表示为 self.gallery_dir 添加额外的 self.extra_gallery_dir数据
        self.market1501_500k = market1501_500k

        # 需要的文件目录
        required_files = [
            self.data_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir,
        ]
        # 如果设置了 market1501_500k=Ture,则添加 extra_gallery_dir 到 required_files 之中
        if self.market1501_500k:
            required_files.append(self.extra_gallery_dir)

        # 在运行之前对 required_files 进行检查
        self.check_before_run(required_files)

        # 对训练目录进行处理
        train = self.process_dir(self.train_dir)
        # 对 query 目录数据进行处理
        query = self.process_dir(self.query_dir, is_train=False)
        # 对 gallery 目录进行处理
        gallery = self.process_dir(self.gallery_dir, is_train=False)

        # 如果设置market1501_500k=Ture,则对extra_gallery_dir进行处理,并且添加到gallery之中
        if self.market1501_500k:
            gallery += self.process_dir(self.extra_gallery_dir, is_train=False)

        # 调用父类的初始化函数
        super(Market1501, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, is_train=True):
        # 获得目录下的所有jpg图片名称
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        # 用于正则表达式
        pattern = re.compile(r'([-\d]+)_c(\d)')

        data = []
        # 对每一张图片进行处理
        # 以 0001_c1s1_000151_01.jpg 为例
        # 1) 0001 表示每个人的标签编号,从0001到1501;
        # 2) c1 表示第一个摄像头(camera1),共有6个摄像头;
        # 3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;
        # 4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;
        # 5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框
        for img_path in img_paths:
            # 获得图片对应的pid(身份ID)以及摄像头编号
            pid, camid = map(int, pattern.search(img_path).groups())
            # -1表示身份不在1051中,则忽略不进行处理
            if pid == -1:
                continue  # junk images are just ignored
            # 检测id是否正确,不正确则报错
            assert 0 <= pid <= 1501  # pid == 0 means background
            # 检测摄像头标号是否正确,不正确则报错
            assert 1 <= camid <= 6
            # 摄像头的标号默认从0开始
            camid -= 1  # index starts from 0

            # 如果进行训练,则在前面加上'market1501',因为默认支持多个数据集训练,
            # 所以每个身份ID都会加上对应的数据集前缀
            if is_train:
                pid = self.dataset_name + "_" + str(pid)
            # 把单张名称路径,身份ID,以及摄像头标号放置到data中
            data.append((img_path, pid, camid))

        return data

上面的代码,注释已经比较详细了,就不进行讲解了。大家或许奇怪为什么没有看到__getitem__函数,其实在上篇博客中已经进行讲解了。他是在 class CommDataset(Dataset) 中实现的。从上面可以看到 Market1501 继承于ImageDataset。

ImageDataset

class ImageDataset(Dataset):
    """
    ImageDataset是一个基类,所有的数据迭代器都应该继承于他,__getitem__返回
    img,pid,camid,以及img_path。其中img包含了(channel, height, width)信息。
    每次训练迭代返数据形状为batch_size, channel, height, width
    A base class representing ImageDataset.
    All other image datasets should subclass it.
    ``__getitem__`` returns an image given index.
    It will return ``img``, ``pid``, ``camid`` and ``img_path``
    where ``img`` has shape (channel, height, width). As a result,
    data in each batch has shape (batch_size, channel, height, width).
    """

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def show_train(self):
        """
        打印训练数据相关信息
        """
        # 对训练数据进行进行解析,获得id身份总数目,以及摄像头标号总数目
        num_train_pids, num_train_cams = self.parse_data(self.train)

        headers = ['subset', '# ids', '# images', '# cameras']
        csv_results = [['train', num_train_pids, len(self.train), num_train_cams]]

        # tabulate it,进行信息打印
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

    def show_test(self):
        # 获得需要gallery,query的ID以及摄像头数目。然后进行打印显示
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        headers = ['subset', '# ids', '# images', '# cameras']
        csv_results = [
            ['query', num_query_pids, len(self.query), num_query_cams],
            ['gallery', num_gallery_pids, len(self.gallery), num_gallery_cams],
        ]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

ImageDataset主要实现了两个打印信息的函数,其实都无关要紧。

结语

到这里为止,我们已经翻译了论文,知道数据是如何获取的,并且熟悉了训练总体过程,那么剩下的,就是对网络总体结构的解析了。

在这里插入图片描述