zl程序教程

您现在的位置是:首页 >  数据库

当前栏目

pytorch-TensorFlow-加载直接的数据集

2023-03-14 22:41:52 时间

在我刚刚学习深度学习的时候,就只会用现有的数据集。当我想训练直接的模型的时候,却不知道该怎样弄,但时是花了两天在网上寻找教程,可是都不太适合新手学习,所以今天我就来总结一下pytorch里面加载自己的数据集的方法。


方法一:利用torch.utils.data.TensorDataset,也是我认为最简单的方法

from torch.utils.data import TensorDataset,DataLoader
x =       #对应你加载的数据   是tensor类型
y =        #对应你数据的标签  是tensor类型
data_set = TensorDataset(x,y)   
data_loader = DataLoader(data_set,batch_size = 100,shuffle = True,drop_last = True) #生成可迭代的迭代器

#训练时
for index (img,label) in enumerate(data_loaer):
    #训练代码

方法二:重写Dataset类

from torch.utils.data import Dataset,DataLoader
class dataset(Dataset):  #继承Dataset类
    def __init__(self,root,transform=None):  #root为你数据的路径  下面以图片为例
    imgs = os.listdir(root)    #返回root里面所用文件的名称
    self.img = [os.path.join(root,img) for img in imgs]  #生成每张照片的路径
    self.transform = transform  #对图片的一些处理
    def __getitem__(self,index):
        img_path = self.img(index)
        label = 0 if 'woman' in img_path.split('/')[-1] else 1 #自己根据文件名来设置标签,例如这里是文件名包不包含woman。
        data =  Image.open(img_path) #读取图片,在PIL库中
        data = self.transform(data)   #对图片进行装换,一定要转换成tensor
        return data,label    #返回图片和其标签
    def __len__(self):
        return len(self.img)   #返回数据的大小

data = dataset(root)
data_loader = DataLoader(data,batch_size=100,shuffle=True,drop_last= True)
                            #batch_size每一次迭代数据的大小               shuffle对图片是否打散      drop_last对最后的数据如果不满足batch_size的大小就舍弃
        

方法三:利用ImageFolder对图片进行读取

ImageFolder是对整个文件夹进行对取,每个文件夹的内容,会自动被归为一类。

from torchvision.datasets import ImageFolder
data = ImageFolder(root)  #root的根目录放保存每一类的文件夹
data_loader = DataLoader(data,batch_size=100,shuffle=True,drop_last= True)

基于这三种简单易行的方法,你可以很方便的根据你数据的存放的形式进行构造自己的数据集。

下面介绍一下tf 2.0的构造自己的数据集,先看一下代码吧,也不难。

import tensorflow as tf
x =     #数据。numpy类型 
y =     #标签。numpy类型
# x,y tensor类型的我还没尝试过你可以去转换成tensor类型的试试。
data = tf.data.Dataset.slices((x,y))
data_load = data.repeat().shuffle(5000).batch(128).prefetch(1)

for step , (x,y) in enumerate(data_load.take(2000),1):
    

这种方法与pytorch的第一种方法类似。

prefetch(x) ,表示预先准备下x次迭代的数据,提高效率。

batch , 表示每一次迭代的数据大小,比如图片的话就是每一次迭代128张图片的数据。

shuffle(x) ,表示打乱数据的次序,x表示打乱的次数,每一次迭代算一次。

take(x) ,表示训练的epoch数,在pytorch里面需要在外面在嵌套一次for循环来设置训练的epochs数。


Thank for your reading !!!

公众号:FPGA之旅