Pytorch 重写Dataloader
2023-09-27 14:25:31 时间
这是一个官网的例子:torch.nn入门。
一般而言,我们会根据自己的数据需求继承Dataset(from torch.utils.data import Dataset, DataLoader)重写数据读取函数。或者利用TensorDataset更加简洁实现读取数据。
抑或利用 torchvision里面的ImageFolder
也可管理数据。这几种方法已经可以实现数据读取了,而DataLoader的作用是更加全面管理批量数据:
下面进入正题,MNIST数据利用CNN时需要转换为二维数据,所以需要对初始的线性数据进行转换。一般,可以读取先行数据后在模型中进行view来实现:
class Lambda(nn.Module): def __init__(self, func): super().__init__() self.func = func def forward(self, x): return self.func(x) def preprocess(x): return x.view(-1, 1, 28, 28) model = nn.Sequential( Lambda(preprocess), nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.AvgPool2d(4), Lambda(lambda x: x.view(x.size(0), -1)), )
文中给出另一种解决方案:重写DateLoader:将数据处理移到生成器里面
def get_data(train_ds, valid_ds, bs): return ( DataLoader(train_ds, batch_size=bs, shuffle=True), DataLoader(valid_ds, batch_size=bs * 2), ) def preprocess(x, y): return x.view(-1, 1, 28, 28), y class WrappedDataLoader: def __init__(self, dl, func): self.dl = dl self.func = func def __len__(self): return len(self.dl) def __iter__(self): batches = iter(self.dl) for b in batches: yield (self.func(*b)) train_dl, valid_dl = get_data(train_ds, valid_ds, bs) train_dl = WrappedDataLoader(train_dl, preprocess) valid_dl = WrappedDataLoader(valid_dl, preprocess)
模型就可以写成这样:
model = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), Lambda(lambda x: x.view(x.size(0), -1)), )
相关文章
- 更快、更 Pythonic 的 PyTorch 2.0 | 非常值得期待
- Pytorch nn.Module的基本使用
- Google 的 colab 支持 pytorch 框架训练(教程)
- Pytorch中torch.Tensor.scatter_用法
- [轻笔记]Pytorch语义分割输出转换为图像显示
- 利用Pytorch进行CNN详细剖析
- PyTorch-Adam优化算法原理,公式,应用
- Pytorch入门随手记
- pytorch 模型的断点训练
- TensorFlow 系列案例(4)及Pytorch 学习(3)实现K-Means聚类算法
- Pytorch 学习(4): Pytorch中Torch 工具包的数学操作汇总速查