pytorch-TensorFlow-加载直接的数据集
2023-03-14 22:41:52 时间
在我刚刚学习深度学习的时候,就只会用现有的数据集。当我想训练直接的模型的时候,却不知道该怎样弄,但时是花了两天在网上寻找教程,可是都不太适合新手学习,所以今天我就来总结一下pytorch里面加载自己的数据集的方法。
方法一:利用torch.utils.data.TensorDataset,也是我认为最简单的方法
from torch.utils.data import TensorDataset,DataLoader x = #对应你加载的数据 是tensor类型 y = #对应你数据的标签 是tensor类型 data_set = TensorDataset(x,y) data_loader = DataLoader(data_set,batch_size = 100,shuffle = True,drop_last = True) #生成可迭代的迭代器 #训练时 for index (img,label) in enumerate(data_loaer): #训练代码
方法二:重写Dataset类
from torch.utils.data import Dataset,DataLoader class dataset(Dataset): #继承Dataset类 def __init__(self,root,transform=None): #root为你数据的路径 下面以图片为例 imgs = os.listdir(root) #返回root里面所用文件的名称 self.img = [os.path.join(root,img) for img in imgs] #生成每张照片的路径 self.transform = transform #对图片的一些处理 def __getitem__(self,index): img_path = self.img(index) label = 0 if 'woman' in img_path.split('/')[-1] else 1 #自己根据文件名来设置标签,例如这里是文件名包不包含woman。 data = Image.open(img_path) #读取图片,在PIL库中 data = self.transform(data) #对图片进行装换,一定要转换成tensor return data,label #返回图片和其标签 def __len__(self): return len(self.img) #返回数据的大小 data = dataset(root) data_loader = DataLoader(data,batch_size=100,shuffle=True,drop_last= True) #batch_size每一次迭代数据的大小 shuffle对图片是否打散 drop_last对最后的数据如果不满足batch_size的大小就舍弃
方法三:利用ImageFolder对图片进行读取
ImageFolder是对整个文件夹进行对取,每个文件夹的内容,会自动被归为一类。
from torchvision.datasets import ImageFolder data = ImageFolder(root) #root的根目录放保存每一类的文件夹 data_loader = DataLoader(data,batch_size=100,shuffle=True,drop_last= True)
基于这三种简单易行的方法,你可以很方便的根据你数据的存放的形式进行构造自己的数据集。
下面介绍一下tf 2.0的构造自己的数据集,先看一下代码吧,也不难。
import tensorflow as tf x = #数据。numpy类型 y = #标签。numpy类型 # x,y tensor类型的我还没尝试过你可以去转换成tensor类型的试试。 data = tf.data.Dataset.slices((x,y)) data_load = data.repeat().shuffle(5000).batch(128).prefetch(1) for step , (x,y) in enumerate(data_load.take(2000),1):
这种方法与pytorch的第一种方法类似。
prefetch(x) ,表示预先准备下x次迭代的数据,提高效率。
batch , 表示每一次迭代的数据大小,比如图片的话就是每一次迭代128张图片的数据。
shuffle(x) ,表示打乱数据的次序,x表示打乱的次数,每一次迭代算一次。
take(x) ,表示训练的epoch数,在pytorch里面需要在外面在嵌套一次for循环来设置训练的epochs数。
Thank for your reading !!!
公众号:FPGA之旅
相关文章
- 每年14PB数据存储需求,海量交通安全数据如何安放?
- Oracle提议将G1作为Java 9的默认垃圾收集器
- 从携程瘫痪事件看运维的85条军规
- 详谈Struts+Hibernate+Spring三大框架
- Java Hibernate 之连接池详解
- PHP开发者常犯的10个MySQL错误
- 为什么程序员不会最简单的100到1的输出?
- 敏捷数据分析方法论革命来袭
- 不懂编程和数据库也能创建表单
- 发布&订阅的消息系统 Kafka的深度解析
- 2015年薪酬大涨的15个IT岗位
- Guava并发:ListenableFuture与RateLimiter示例
- 20个最热门的IT技术职位及薪资
- Oracle宣布更多的Java 9 新特性
- 一张图看懂硅谷科技公司的男女比例
- Ztree + PHP 无限极节点 递归查找节点法
- 这根本不是 BASH 的 BUG,是特征!
- Java9 2016年发布,新特性抢先看
- 通过Redis实现RPC远程方法调用
- Spring高级事务管理难点剖析