zl程序教程

您现在的位置是:首页 >  其它

当前栏目

torchvision dataset和torch.utils.data.DataLoader

Data Dataset torch utils torchvision
2023-09-11 14:22:51 时间

一、torchvision主要包括一下几个包:

vision.datasets : 几个常用视觉数据集,可以下载和加载,这里主要的高级用法就是可以看源码如何自己写自己的Dataset的子类
vision.models : 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
vision.transforms : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到tensor ,numpy 数组到tensor , tensor 到 图像等。
vision.utils : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个mini-batch的图像可以产生一个图像格网。

1.torchvision.dataset()用法

CIFAR
dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)

参数说明: - root : cifar-10-batches-py 的根目录

- train : True = 训练集, False = 测试集

- download : True = 从互联上下载数据,并将其放在root目录下。如果数据集已经下载,什么都不干。

返回类型:tuple,返回参数:(image,target)其中target应是目标类的类索引。

 


举例:

from torch.utils.data.sampler import Sampler
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
 
dataset = MyDataset() # 第一步:构建 Dataset 对象
dataloader = DataLoader(dataset) # 第二步:通过Dataloader来构建迭代对象
 
num_epoches = 100
for epoch in range(num_epoches):
    for i, data in enumerate(dataloader):
        # 训练代码

二、 torch.utils.data.DataLoader

torch.utils.data.DataLoader(dataset, batch_size=1,shuffle=False,sampler=None,
                            batch_sampler=None, num_workers=0, collate_fn=None,
                            pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None)
 
 
 
    Arguments:
        dataset (Dataset): 是一个DataSet对象,表示需要加载的数据集.
        batch_size (int, optional): 每一个batch加载多少组样本,即指定batch_size,默认是 1 
        shuffle (bool, optional): 布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是False
------------------------------------------------------------------------------------
        sampler (Sampler, optional): 自定义从数据集中抽取样本的策略,如果指定这个参数,那么shuffle必须为False
        batch_sampler (Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥)
------------------------------------------------------------------------------------
        num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
        collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数(这个还不是很懂)
        pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
------------------------------------------------------------------------------------
        drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了,如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
------------------------------------------------------------------------------------
        timeout (numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
 
        worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``)

 DataLoader的源码:https://blog.csdn.net/u014380165/article/details/79058479

class DataLoader(object):
"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id as input, after seeding and before data
            loading. (default: None)
"""

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

先看看__init__中的几个重要的输入:

1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象

2、batch_size,根据具体情况设置即可。

3、shuffle,一般在训练数据中会采用。

4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。

5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。

6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。

7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。

8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。

9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

在__init__中,RandomSampler类表示随机采样且不重复,所以起到的就是shuffle的作用。

BatchSampler类则是把batch size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。

这两个采样类都是定义在sampler.py脚本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。

以上这些都是初始化的时候进行的,当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如:

train_data=torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
...

就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。

DataLoaderIter类源码如下:

class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

 

https://blog.csdn.net/u014380165/article/details/79058479

这个链接超级详细,实在是看不下去了

 

关于继承dataset类单独写一个博客