pytorch学习:准备自己的图片数据
2023-09-11 14:17:15 时间
图片数据一般有两种情况:
1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。
2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。
针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:
一、所有图片放在一个文件夹内
这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。
先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:
import torch import torchvision import matplotlib.pyplot as plt from skimage import io mnist_test= torchvision.datasets.MNIST( './mnist', train=False, download=True ) print('test set:', len(mnist_test)) f=open('mnist_test.txt','w') for i,(img,label) in enumerate(mnist_test): img_path="./mnist_test/"+str(i)+".jpg" io.imsave(img_path,img) f.write(img_path+' '+str(label)+'\n') f.close()
经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:
前期工作就装备好了,接着就进入正题了:
from torchvision import transforms, utils from torch.utils.data import Dataset, DataLoader import matplotlib.pyplot as plt from PIL import Image def default_loader(path): return Image.open(path).convert('RGB') class MyDataset(Dataset): def __init__(self, txt, transform=None, target_transform=None, loader=default_loader): fh = open(txt, 'r') imgs = [] for line in fh: line = line.strip('\n') line = line.rstrip() words = line.split() imgs.append((words[0],int(words[1]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader def __getitem__(self, index): fn, label = self.imgs[index] img = self.loader(fn) if self.transform is not None: img = self.transform(img) return img,label def __len__(self): return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor()) data_loader = DataLoader(train_data, batch_size=100,shuffle=True) print(len(data_loader)) def show_batch(imgs): grid = utils.make_grid(imgs) plt.imshow(grid.numpy().transpose((1, 2, 0))) plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader): if(i<4): print(i, batch_x.size(),batch_y.size()) show_batch(batch_x) plt.axis('off') plt.show()
自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。
二、不同类别的图片放在不同的文件夹内
同样先准备数据,这里以flowers数据集为例,下载:
http://download.tensorflow.org/example_images/flower_photos.tgz
花总共有五类,分别放在5个文件夹下。大致如下图:
我的路径是d:/flowers/.
数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder
import torch import torchvision from torchvision import transforms, utils import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower', transform=transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor()]) ) print(len(img_data)) data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True) print(len(data_loader)) def show_batch(imgs): grid = utils.make_grid(imgs,nrow=5) plt.imshow(grid.numpy().transpose((1, 2, 0))) plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader): if(i<4): print(i, batch_x.size(), batch_y.size()) show_batch(batch_x) plt.axis('off') plt.show()
就是这样。
相关文章
- 极客时间-左耳听风-程序员攻略-机器学习和人工智能
- Knockout.Js官网学习(加载或保存JSON数据)
- 大话机器学习之数据预处理与数据筛选
- 前端学习 -- Html&Css -- 条件Hack 和属性Hack
- 机器学习数学系列(2):微分选讲
- Opencv学习笔记 DNN模块概述
- Javascript 学习 笔记二
- SAP Analytics Cloud学习笔记(一):从CSV文件导入数据到Analytics Cloud里创建模型和Story
- Jerry 2016年5月20日到5月23日的学习笔记
- ML与Regularization:正则化理论即bias-variance tradeoff(权值衰减/提前终止/数据扩增/Dropout/融合技术)在机器学习中的简介、常用方法、案例应用之详细攻略
- CV:计算机视觉技最强学习路线之CV简介(传统视觉技术/相关概念)、早期/中期/近期应用领域(偏具体应用)、经典CNN架构(偏具体算法)概述、常用工具/库/框架/产品、环境安装、常用数据集、编程技巧
- DL:关于深度学习常用数据集中训练好的权重文件(Deeplab v3、MobileNet、InceptionV3、VGG系列、ResNet、Mask R-CNN )下载地址集合(持续更新)
- 已解决(机器学习中查看数据信息报错)AttributeError: target_names
- 为什么学习Python?数据给你八大理由
- 机器学习案例(六):Python 大数据进行信用卡欺诈检测(完整源码和分析)
- 【大数据&人工智能AI】每个现代数据科学家都必须阅读的 6 篇论文:该领域的每个人都熟悉深度学习的一些最重要的现代基础知识的列表
- Vue学习之--------深入理解Vuex之多组件共享数据(2022/9/4)
- Vue3学习笔记(一)——MVC与vue3概要、模板、数据绑定与综合示例
- RequireJS 学习
- 机器学习中的数据不平衡问题----通过随机采样比例大的类别使得训练集中大类的个数与小类相当,或者模型中加入惩罚项
- 动手学习数据分析(五)——数据建模及模型评估
- 【深度学习】Pytorch面试题:什么是 PyTorch?PyTorch 的基本要素是什么?Conv1d、Conv2d 和 Conv3d 有什么区别?
- 结合Java和机器学习技术,如何驾驭大数据提升业务效率和竞争力?
- 【Pytorch】第 1 章 :强化学习和 PyTorch 入门
- 深度学习笔记:在小数据集上从头训练卷积神经网络
- 【深度学习】语义分割实验:Unet网络/MSRC2数据集
- 对于机器学习保险行业问答开放数据集DeepQA-1的详细注解(三)