zl程序教程

您现在的位置是:首页 >  大数据

当前栏目

TensorFlow下数据加载——tf.data的使用

数据 加载 Data Tensorflow TF 使用
2023-09-11 14:22:29 时间

1. 概述

在Pytorch中数据加载是通过torch.utils.data.datasettorch.utils.data.dataloader完成的,而在TensorFlow中现在主推的是使用tf.data实现数据加载。此前,在TensorFlow中读取数据一般有两种方法:

新出的数据加载工具更加简洁高效,也是后面主推的方式,若是使用TensorFlow的Eager模式就必须使用这种数据加载方式,在这个模式下的数据加载也会存在一些细微不同(Eager模式下丢掉了Session,可以像python Debug一样调试程序,对于数据迭代可以使用python的内部函数iter实现)。

下面是tf.data.Dataset的类继承关系图:
在这里插入图片描述
一般用到的是tf.data.Dataset.from_tensor_slices完成数据加载,但是根据上图的继承关系也提供了另外3种数据加载方式:

  • 1)tf.data.TextLineDataset():这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。
  • 2)tf.data.FixedLengthRecordDataset():这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。
  • 3)tf.data.TFRecordDataset():顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。

2. DataSet的常用函数

使用tf.data.Dataset实现数据加载只需要使用调用一个函数就可以了,类似于下面的形式:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))

在上面的数据集定义好之后会有一些函数可供选择:

  • 1)tf.data.Dataset.map(f, num_parallel_calls)
    Dataset.map 转换通过将函数 f 应用于输入数据集的每对元素(data, label)来生成新数据集。比如在上面的例子中,就是把(filename, label)中filename指定的图像读取出来并调整大小。num_parallel_calls指定使用多少个线程来进行map操作。可以设置为CPU的最大核心数目(=multiprocessing.cpu_count())。如果不指定的话,只使用一个线程处理数据。
  • 2)tf.data.Dataset.batch(batch_size)
    这个函数特别重要。 假如输入图像大小为(227,227,3),模型的输入shape为(None,227,227,3),其中None是batch_size。
  • 3)tf.data.Dataset.repeat(count)
    重复这个数据集多少次。如果不传count这个参数,默认会无限重复这个数据集。加入count=1,那么当你训练完一轮之后,就会报错tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence。在实际使用中,基本可以不传count参数,无限重复这个数据集。

3. 数据加载

3.1 Eager模式

使用过TensorFlow的大家都会知道,TF通过计算图将计算的定义和执行分隔开, 这是一种声明式(declaretive)的编程模型。确实,这种静态图的执行模式优点很多,但是在debug时确实非常不方便(类似于对编译好的C语言程序调用,此时是我们无法对其进行内部的调试),因此有了Eager Execution。

引入的Eager Execution模式后,TensorFlow就拥有了类似于Pytorch一样动态图模型能力,我们可以不必再等到see.run(*)才能看到执行结果,可以方便在IDE随时调试代码,查看OPs执行结果。在代码中添加一句话就可以实现Eager模式的启用:

tf.enable_eager_execution()

3.2 非Eager模式下的数据加载

使用make_one_shot_iterator()初始化迭代

def get_item_by_tf(file_path, label):
    img_byte = tf.io.read_file(file_path)
    img_decode = tf.image.decode_jpeg(img_byte)
    img_decode = tf.cast(img_decode, dtype=tf.float32)
    img_regular = tf.divide(tf.subtract(img_decode, 127.0), 255.0)
    # img_resize = tf.image.resize_images(img_decode)
    return img_regular, label

if __name__ == "__main__":
    data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
    img_list = os.listdir(data_root)
    img_list = [os.path.join(data_root, item) for item in img_list]
    img_list = tf.constant(img_list)
    dataset = tf.data.Dataset.from_tensor_slices(img_list)

    batch_size_var = 32
    dataset = dataset.map(get_item_by_tf, num_parallel_calls=mlt.cpu_count())
    dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()

    next_op = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        try:
            while True:
                img_data = sess.run(next_op)
                print(img_data[0].shape)
        except tf.errors.OutOfRangeError:
            print("load end!")

在遇到一些特殊情况下使用TF自带的读取工具无法有效读取数据,需要使用到第三方的书读取,这里就是要使用到tf.py_func(tf.numpy_function)来进行读取,只需要将对应的回调函数进行替换就行了:

def cv_reader_call(file_path):
    file_path = file_path.decode()
    img = cv2.imread(file_path, 0)[:,:,np.newaxis]
    return img

def get_item_by_cv(file_path):
    # return tf.numpy_function(cv_reader_call, [file_path], [tf.uint8])
    return tf.py_func(cv_reader_call, [file_path], [tf.uint8])

使用make_initializable_iterator()初始化迭代

if __name__ == "__main__":
    data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
    img_list = os.listdir(data_root)
    img_list = [os.path.join(data_root, item) for item in img_list]
    img_list = tf.constant(img_list)
    dataset = tf.data.Dataset.from_tensor_slices(img_list)

    batch_size_var = tf.placeholder(dtype=tf.int64, shape=[])
    dataset = dataset.map(get_item_by_cv, num_parallel_calls=mlt.cpu_count())
    dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()

    data_iter = dataset.make_initializable_iterator()
    next_op = data_iter.get_next()

    with tf.Session() as sess:
        sess.run(data_iter.initializer, feed_dict={batch_size_var: 32})

3.3 Eager模式下的数据加载

tf.enable_eager_execution()

if __name__ == "__main__":
    data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
    img_list = os.listdir(data_root)
    img_list = [os.path.join(data_root, item) for item in img_list]
    img_list = tf.constant(img_list)
    dataset = tf.data.Dataset.from_tensor_slices(img_list)

    batch_size_var = 32
    dataset = dataset.map(get_item_by_cv, num_parallel_calls=mlt.cpu_count())
    dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()

    data_iter = iter(dataset)
    while True:
        try:
            img_data = next(data_iter)
            print(img_data[0].shape)
        except StopIteration:
            print("load end!")

4. REF

  1. 『TensorFlow』数据读取类_data.Dataset
  2. tf.data学习指南(超实用、超详细)