zl程序教程

您现在的位置是:首页 >  大数据

当前栏目

Pytorch制作数据集

数据PyTorch 制作
2023-06-13 09:12:32 时间

pytorch中制作数据集是要基于Dataset类来进行

首先查看一下Dataset的官方教程

如图,Dataset是一个抽象类,只能被继承,不能被实例化,我们要构建自己的数据集类时需要继承Dataset类,并且所有的子类需要重写Dataset中的__getitem__和__len__函数,前者是我们构建数据集的重点,而后者只是返回数据集的长度。

需要读取的数据存放在名为dataset的文件夹下,文件结构如图:

数据就是.jpg的图片,标签是文件夹名ants,ants下的所有图片都是关于蚂蚁的图片,另有文件夹bees,与ants类似。

from torch.utils.data import Dataset#import Dataset类
from PIL import Image#图像处理
import os#操作系统相关库,用来根据路径读取数据

class Mydata(Dataset):#我们读取数据的类要继承Dataset类
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(root_dir,label_dir)#将两个路径合并为一个
        self.img_path=os.listdir(self.path)#获取所有图片的文件名列表

    def __getitem__(self, idx):#返回数据标签对的函数
        img_name=self.img_path[idx]#通过idx(索引)访问数据,最终实例化Mydata后可直接##通过索引访问数据-标签对
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)#合并路##径
        img=Image.open(img_item_path)#获得图片对象
        label=self.label_dir
        return img,label#最终返回数据-标签对即可
    def __len__(self):
        return len(self.img_path)

root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=Mydata(root_dir,ants_label_dir)#蚂蚁数据集
bees_dataset=Mydata(root_dir,bees_label_dir)#蜜蜂数据集
train_dataset=ants_dataset+bees_dataset#合并两个数据集
img,label=ants_dataset[0]#通过索引读取数据对
img.show()#打印图片
print(label)#打印label