zl程序教程

您现在的位置是:首页 >  工具

当前栏目

深度学习笔记:Tensorflow BatchDataset应用示例¶

应用笔记学习 示例 深度 Tensorflow
2023-09-14 09:15:01 时间

目录

1. 前言

2. 将MNIST数据集转换为BatchDataset

2.1 加载并转换为BatchDataset

2.2 TypeError: 'BatchDataset' object is not subscriptable

2.3 AttributeError: 'BatchDataset' object has no attribute 'make_one_shot_iterator'

2.4 TF2.x 处理方法: 直接对BatchDataset进行iteration

3. cats-vs-dogs数据集例

3.1 数据下载

3.2 基于image_dataset_from_directory()将整个数据集读入


1. 前言

        对于小型玩具数据集(toy dataset),可以一次性完全加载进来进行处理。但是在现实中,深度学习往往都要涉及到超大的数据集,比如说像ImageNet这样的数据集你不可能一次性加载进来。即便你的深度学习服务器的内存巨大可以一次性加载,那也不是一个有效利用内存的好做法。tensorflow.data.Dataset模块提供了针对这种情况的有效的处理方法。本文简要介绍基于tensorflow.data.Dataset模块对大数据集的一些处理方法以及一些常见的问题及其解决方案。

2. 将MNIST数据集转换为BatchDataset

        以下以tensorflow内置的MNIST数据集为例进行说明。MNIST数据集相对来说比较小,一般来说其实是可以一次性的加载进来进行处理。这里首先为例来说明BatchDataset处理的一些方面。然后再下一个例子中再考虑一个非内置的存放在硬盘中的一个“大”数据集的处理例,会涉及更多一些的元素。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
print(tf.__version__)

        本例子基于tensorflow 2.5版本在JupyterNotebook中运行验证过。如果版本不一样的话,可能运行结果会有所不同,特别是如果用TF1.x的话。 

2.1 加载并转换为BatchDataset

        以下我们先加载MNIST数据集(如上所述,MNIST其实是可以直接加载的,这里为了简单起见以它为例,先回避从硬盘读取大数据集的处理),然后用tf.data.Dataset模块中的from_tensor_sclies()函数将训练集和测试集转换为BatchDataset.

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

TRAIN_BUF = 1280
BATCH_SIZE = 64

batched_train = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
batched_test = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)

print(type(train_images), type(test_images))
print(train_images.shape)
print(batched_train)
print(batched_test)
<class 'numpy.ndarray'> <class 'numpy.ndarray'>
(60000, 28, 28)
<BatchDataset shapes: (None, 28, 28), types: tf.uint8>
<BatchDataset shapes: (None, 28, 28), types: tf.uint8>

        从以上打印结果可以看出,原始的load_data()输出的train_images和test_images的类型为普通的numpy多维矩阵,而经过from_tensor_slices变换后生成的batched_train和batched_test的类型则是BatchedDataset。接下来我们看看怎么确认batched_train/test的内容呢?

2.2 TypeError: 'BatchDataset' object is not subscriptable

        取train_dataset的第一个元素看看。。。 

train_dataset[0]

 

        Oops! 发生了什么?BatchDataset对象不是像普通的numpy多维矩阵那样通过下标访问!为什么呢?

        从以上的形状(shape)打印结果来看,其实我们已经可以看出差异,train_images的第一个维度的大小是60000,表示train_images是实实在在地包含了60000个样本的(在深度学习中,作为一个惯例,一个数据集张量的第一个轴(axis)表示样本轴。具体说明可以参考<<Francois Chollet: Deep Learning with Python>>).而batched_train的样本轴是None!

        那如何才能从batched_train/test中恢复出正确的数据样本呢?答案是把它们转换成迭代器(iterator),事实上它们被使用时就是先被转换成迭代器然后再从中逐个Batch地提取的,只不过大部分时候这些都是被封装进行后续编译训练处理中,对用户不可见(或者说不用用户操心)而已。

2.3 AttributeError: 'BatchDataset' object has no attribute 'make_one_shot_iterator'

         在网上搜一搜,很多人告诉如下这样把BatchDataset转换成迭代器:

batched_train_iter = batched_train.make_one_shot_iterator()

        燃鹅。。。得到的是这样的:

         相信很多人都碰到过这个错误。这个代码调用方式在网上随便一查到处都是,为什么会出错呢?原因在于tensorflow从1.x升级到2.x时对代码做了很多修改,以上调用方式是属于TF1.x的,在TF2.x时代行不通了。在TF2.x,make_one_shot_iterator被打入了'冷宫'--被挪到了tensorflow.compat.v1.data中去了。说的直白一点就是,这个函数已经被deprecated了,但是为了保持对TF1.x的后向兼容,所以暂时设了一个'冷宫'存放。。。潜台词是能不用就不要再用了,说不定哪天就完全被删除了。但是下面我们还是用它先试试看。

import matplotlib.pyplot as plt
from PIL import Image

batched_train_iter = tf.compat.v1.data.make_one_shot_iterator(batched_train)

next_batch = batched_train_iter.get_next()
print(type(next_batch), next_batch.shape)
plt.imshow(next_batch[0])
<class 'tensorflow.python.framework.ops.EagerTensor'> (64, 28, 28)

        从上面的运行结果可以看出,先用make_one_shot_iterator变成iterator,然后再用get_next()函数就能获得下一个batch。next_batch现在就是一个常规的张量(tensorflow.python.framework.ops.EagerTensor),可以通过下标引用了。

        在TF1.x中还有一个与make_one_shot_iterator相近的一个函数make_initializable_iterator(),后者却并不在tf.compat.v1.data中,如果运行以下代码的话会报错:

        batched_train_iter = tf.compat.v1.data.make_initializable_iterator(batched_train)

        难道被彻底删除了,还是说挪到了别的什么位置?Anyway,反正它们应该在TF2.x是不受待见的,就此忘掉它们吧。

        那在TF2.x中,更直接自然的处理方式应该是什么呢?

2.4 TF2.x 处理方法: 直接对BatchDataset进行iteration

        简单直接。。。

k = 0
for batch in batched_train:    
    if k == 1:
        plt.imshow(next_batch[0])
    k += 1
print('Totally, there are {0} batches'.format(k))    
    Totally, there are 938 batches(MNIST中总共有60000张图片,每个batch为64张)

        或者你也可以先创建一个显式的迭代器,然后再对该迭代器进行迭代。如下所示: 

batched_train_iter = batched_train.__iter__()
k = 0
for next_batch in batched_train_iter:    
    if k == 1:
        plt.imshow(next_batch[0])
    k += 1
print('Totally, there are {0} batches'.format(k))   

3. cats-vs-dogs数据集例

        如前所述,MNIST这种玩具数据集本身内置于Tensorflow中不需要特意地转换为BatchDataset,直接使用即可。但是在实际应用中通常是存放在硬盘中,而且可能也大到不允许一次性地读入,这个时候应该怎么办呢?下面我们以另一个经典的数据集cats-vs-dogs数据集来说明一个进阶处理方法。

3.1 数据下载

        猫狗数据集出自kaggle竞赛,可以从kaggle网页下载,但是在国内访问kaggle好像是有问题。幸好从microsoft网页下载也可以方便地下载到。

Download Kaggle Cats and Dogs Dataset from Official Microsoft Download Center

        下载展开后目录结构如下所示:

        (1) cats-vs-dogs\cat

        (2) cats-vs-dogs\dog

3.2 基于image_dataset_from_directory()将整个数据集读入

from tensorflow.keras.preprocessing import image_dataset_from_directory
import os, shutil, pathlib

cats_dogs_dir = pathlib.Path("F:\DL\cats-vs-dogs")

cats_dogs_dataset = image_dataset_from_directory(
    cats_dogs_dir,
    image_size=(180, 180),
    batch_size=64)

print(cats_dogs_dataset)

 Found 25000 files belonging to 2 classes.

<BatchDataset shapes: ((None, 180, 180, 3), (None,)), types: (tf.float32, tf.int32)>

        如上所示,这并不是真正地将整个数据集完全读入。image_dataset_from_directory()返回的是一个BatchDataset对象。但是从打印信息来看,可以知道这个数据集有25000张图像,分属于两个类别,这个是image_dataset_from_directory()自动地从两个子目录“cat”,“dog”识别出来的。image_dataset_from_directory()有很多选项,但是这个函数的用法不是本文的重点,所以这里就不细说了。

        有了BatchDataset对象,我们就可以跟前一章的例子一样逐个batch提取并确认其中数据样本了。这里作为示例,我们从中随机挑几张图片出来看看。

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

batched_iter = cats_dogs_dataset.__iter__()

next_batch = next(batched_iter) 
next_batch_data = next_batch[0]
next_batch_label = next_batch[1]
print(next_batch_data.shape, next_batch_label.shape)

#k = 0
fig,ax = plt.subplots(2,4, figsize=[16,8])
idx = [i for i in np.random.randint(0,next_batch_data.shape[0],8)]   
print(idx)
for i in range(4):            
    ax[0][i].imshow(next_batch_data[idx[i]]/255.)
    ax[1][i].imshow(next_batch_data[idx[i+4]]/255.)                
(64, 180, 180, 3) (64,)
[58, 13, 40, 58, 28, 60, 54, 3]