【Pytorch学习笔记】11.取Dataset的子集、给Dataset打乱顺序的方法(使用Subset、random_split)
2023-09-27 14:19:57 时间
(pytorch版本:1.2)
我们在使用Dataset定义好数据集后,在处理数据集时经常会碰到这些问题:如何把Dataset拆分成两个子集(如用于指定训练集和测试集、k折交叉验证等)?如何进行随机拆分?如何打乱一个Dataset内数据的顺序?
Dataset取子集、拆分
使用 torch.utils.data.Subset() 可对数据集取子集。
传入一个Dataset,一个序列切片indices,即可得到一个子集。
1.我们可以传入一个range():
indices = range(18353) # 取标号为第0个到第18352个数据
sub_imgs = torch.utils.data.Subset(imgs, indices)
len(imgs), len(sub_imgs)
2.可以取区间:
indices = range(18353, 27153) # 取标号为第18353个到第27152个数据
sub_imgs = torch.utils.data.Subset(imgs, indices)
len(imgs), len(sub_imgs)
3.可以传入一个List。有List就可以用列表生成式:
indices = [x for x in range(1234)]
sub_imgs = torch.utils.data.Subset(imgs, indices)
len(imgs), len(sub_imgs)
打乱Dataset内数据的顺序
我们可以直接传入一个乱序的index就可以达到数据集乱序的目的:
from torch import randperm
lenth = randperm(len(Leaf_dataset_train)).tolist() # 生成乱序的索引
rand_train = torch.utils.data.Subset(imgs, lenth)
# 显示一下第一张图片、原标号
X = rand_train[0]
plt.imshow(torch.transpose(X[0],0,2)), lenth[0]
我们在打乱顺序后就可以取子集对数据集进行k折交叉验证等行为。
随机拆分Dataset
使用 torch.utils.data.random_split() 可直接对数据集进行拆分,随机分成多份。
可以传入一个List,注意传入的List序列中包含每个子集的大小(数量),且这几个数的和必须等于传入Dataset的长度。
示例:
# 这里Leaf_dataset_train的大小必须等于 17000+1353
train_set, test_set = torch.utils.data.random_split(Leaf_dataset_train, [17000, 1353])
print(len(train_set), len(test_set))
相关文章
- 【学习笔记21】JavaScript数组的基本方法
- w5500驱动使用方法调试笔记
- 机器学习笔记之密度聚类——DBSCAN方法(Python代码实现)
- 机器学习笔记之核方法(二)正定核函数的充要性证明
- 机器学习笔记之马尔可夫链蒙特卡洛方法(一)蒙特卡洛方法介绍
- 《工作流管理——模型、方法和系统》笔记2:Petri网对工作流建模
- 统计学习方法笔记 -- 朴素贝叶斯
- 统计学习方法笔记 -- 隐马尔可夫模型
- 汇川使用笔记6:伺服轴控功能块封装方法示意
- PopupWindow 常用方法学习笔记
- Thinkphp学习笔记5-URL生成U方法
- 《Java学习笔记》:日期类常用方法全归纳,值得收藏。
- Java_jdbc 基础笔记之八 数据库连接(写一个查询Student对象的方法)
- Java面向对象基础知识笔记:方法、构造方法、方法重载、继承、多态、抽象类、接口、静态字段与静态方法、包、作用域、classpath与jar、模块依赖关系
- 知识图谱入门学习笔记(四)-知识抽取之问题和方法
- 今日笔记:持续集成、面向对象设计方法