zl程序教程

您现在的位置是:首页 >  后端

当前栏目

pytorch 从Dataset类中获取数据

PyTorch Dataset 获取数据 类中
2023-09-14 09:11:21 时间

转自:https://www.jianshu.com/p/4818a1a4b5bd

1.介绍

Dataset类是为torch.utils.data.DataLoader做准备,支持两种类型的访问

* map-style datasets #__getitem__()
* iterable-style datasets  #__iter__()
(1) 
print("trainDataset 的类型:", type(trainDataset))
>>> trainDataset 的类型: <class 'torchvision.datasets.mnist.MNIST'>

(2)
print("trainDataset 的长度:", len(trainDataset))
>>> trainDataset 的长度: 60000

(3)
print("trainDataset[0] 的类型:", type(trainDataset[0]))
print("trainDataset[0] 的长度:", len(trainDataset[0]))
>>>  trainDataset[0] 的类型: <class 'tuple'>
     trainDataset[0] 的长度: 2

(4)
print("trainDataset[0][0] 的类型:", type(trainDataset[0][0]))
print("trainDataset[0][0] 的形状:", trainDataset[0][0].shape)
>>>  trainDataset[0][0] 的类型: <class 'torch.Tensor'>
     trainDataset[0][0] 的形状: torch.Size([1, 28, 28])

(5)
print("trainDataset[0][1] 的类型:", type(trainDataset[0][1]))
print("trainDataset[0][1] :", trainDataset[0][1])
>>>  trainDataset[0][1] 的类型: <class 'int'>
     trainDataset[0][1] : 5

从上述代码可以看到,能够通过一些方法去访问。