TensorFlow下数据加载——tf.data的使用
1. 概述
在Pytorch中数据加载是通过torch.utils.data.dataset
与torch.utils.data.dataloader
完成的,而在TensorFlow中现在主推的是使用tf.data
实现数据加载。此前,在TensorFlow中读取数据一般有两种方法:
- 1)使用placeholder读内存中的数据
- 2)使用queue读硬盘中的数据(参考:十图详解tensorflow数据读取机制(附代码))
新出的数据加载工具更加简洁高效,也是后面主推的方式,若是使用TensorFlow的Eager模式就必须使用这种数据加载方式,在这个模式下的数据加载也会存在一些细微不同(Eager模式下丢掉了Session,可以像python Debug一样调试程序,对于数据迭代可以使用python的内部函数iter
实现)。
下面是tf.data.Dataset
的类继承关系图:
一般用到的是tf.data.Dataset.from_tensor_slices
完成数据加载,但是根据上图的继承关系也提供了另外3种数据加载方式:
- 1)
tf.data.TextLineDataset()
:这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。 - 2)
tf.data.FixedLengthRecordDataset()
:这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。 - 3)
tf.data.TFRecordDataset()
:顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。
2. DataSet的常用函数
使用tf.data.Dataset
实现数据加载只需要使用调用一个函数就可以了,类似于下面的形式:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
在上面的数据集定义好之后会有一些函数可供选择:
- 1)
tf.data.Dataset.map(f, num_parallel_calls)
Dataset.map 转换通过将函数 f 应用于输入数据集的每对元素(data, label)来生成新数据集。比如在上面的例子中,就是把(filename, label)中filename指定的图像读取出来并调整大小。num_parallel_calls指定使用多少个线程来进行map操作。可以设置为CPU的最大核心数目(=multiprocessing.cpu_count())。如果不指定的话,只使用一个线程处理数据。 - 2)
tf.data.Dataset.batch(batch_size)
这个函数特别重要。 假如输入图像大小为(227,227,3),模型的输入shape为(None,227,227,3),其中None是batch_size。 - 3)
tf.data.Dataset.repeat(count)
重复这个数据集多少次。如果不传count这个参数,默认会无限重复这个数据集。加入count=1,那么当你训练完一轮之后,就会报错tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
。在实际使用中,基本可以不传count参数,无限重复这个数据集。
3. 数据加载
3.1 Eager模式
使用过TensorFlow的大家都会知道,TF通过计算图将计算的定义和执行分隔开, 这是一种声明式(declaretive)的编程模型。确实,这种静态图的执行模式优点很多,但是在debug时确实非常不方便(类似于对编译好的C语言程序调用,此时是我们无法对其进行内部的调试),因此有了Eager Execution。
引入的Eager Execution模式后,TensorFlow就拥有了类似于Pytorch一样动态图模型能力,我们可以不必再等到see.run(*)才能看到执行结果,可以方便在IDE随时调试代码,查看OPs执行结果。在代码中添加一句话就可以实现Eager模式的启用:
tf.enable_eager_execution()
3.2 非Eager模式下的数据加载
使用make_one_shot_iterator()
初始化迭代
def get_item_by_tf(file_path, label):
img_byte = tf.io.read_file(file_path)
img_decode = tf.image.decode_jpeg(img_byte)
img_decode = tf.cast(img_decode, dtype=tf.float32)
img_regular = tf.divide(tf.subtract(img_decode, 127.0), 255.0)
# img_resize = tf.image.resize_images(img_decode)
return img_regular, label
if __name__ == "__main__":
data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
img_list = os.listdir(data_root)
img_list = [os.path.join(data_root, item) for item in img_list]
img_list = tf.constant(img_list)
dataset = tf.data.Dataset.from_tensor_slices(img_list)
batch_size_var = 32
dataset = dataset.map(get_item_by_tf, num_parallel_calls=mlt.cpu_count())
dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()
next_op = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
try:
while True:
img_data = sess.run(next_op)
print(img_data[0].shape)
except tf.errors.OutOfRangeError:
print("load end!")
在遇到一些特殊情况下使用TF自带的读取工具无法有效读取数据,需要使用到第三方的书读取,这里就是要使用到tf.py_func(tf.numpy_function)
来进行读取,只需要将对应的回调函数进行替换就行了:
def cv_reader_call(file_path):
file_path = file_path.decode()
img = cv2.imread(file_path, 0)[:,:,np.newaxis]
return img
def get_item_by_cv(file_path):
# return tf.numpy_function(cv_reader_call, [file_path], [tf.uint8])
return tf.py_func(cv_reader_call, [file_path], [tf.uint8])
使用make_initializable_iterator()
初始化迭代
if __name__ == "__main__":
data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
img_list = os.listdir(data_root)
img_list = [os.path.join(data_root, item) for item in img_list]
img_list = tf.constant(img_list)
dataset = tf.data.Dataset.from_tensor_slices(img_list)
batch_size_var = tf.placeholder(dtype=tf.int64, shape=[])
dataset = dataset.map(get_item_by_cv, num_parallel_calls=mlt.cpu_count())
dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()
data_iter = dataset.make_initializable_iterator()
next_op = data_iter.get_next()
with tf.Session() as sess:
sess.run(data_iter.initializer, feed_dict={batch_size_var: 32})
3.3 Eager模式下的数据加载
tf.enable_eager_execution()
if __name__ == "__main__":
data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
img_list = os.listdir(data_root)
img_list = [os.path.join(data_root, item) for item in img_list]
img_list = tf.constant(img_list)
dataset = tf.data.Dataset.from_tensor_slices(img_list)
batch_size_var = 32
dataset = dataset.map(get_item_by_cv, num_parallel_calls=mlt.cpu_count())
dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()
data_iter = iter(dataset)
while True:
try:
img_data = next(data_iter)
print(img_data[0].shape)
except StopIteration:
print("load end!")
4. REF
相关文章
- 一些关于大数据的总结
- [Ubuntu] ubuntu的tty下挂载移动硬盘拷贝数据
- 数据浪潮之间的前端工程师
- 使用ListView控件展示数据
- ch6-定制数据对象(打包代码和数据)
- 【第三篇】学习 android 事件总线androidEventbus之list数据事件的传递,发送list数据事件到另外一个Activity
- spring boot:使接口返回统一的RESTful格式数据(spring boot 2.3.1)
- Ds中有数据,但Gridview上未显示的原因小结
- atitit.数据验证--db数据库数据验证约束
- 使用nodejs对Marketing Cloud的contact主数据进行修改操作
- 遍历vue里面的数据。得到的数组。多了个后缀__ob__: Observer怎么处理?
- BigData:大数据开发的简介、核心知识(linux基础+Java/Python编程语言+Hadoop{HDFS、HBase、Hive}+Docker)、经典场景应用之详细攻略
- Python爬虫案例:下载文章数据,转制成PDF格式
- 语音识别数据加载以及图谱
- 031:vue+openlayers加载GPX数据(代码示例)
- leaflet加载Geotiff数据,并在地图上显示(101)
- leaflet 加载WKT数据(示例代码050)
- 时间序列数据库——索引用ES、聚合分析时加载数据用什么?docvalues的列存储貌似更优优势一些。那分布式计算呢?ES做
- (02)Cartographer源码无死角解析-(03) 新数据运行与地图保存、加载地图启动仅定位模式
- 动作识别0-09:mmaction2(SlowFast)-源码无死角解析(5)-数据加载,预处理2
- Python实现聚类分析和数据降维
- 数据治理为何突然成了流量明星?
- 基于深度学习的三维重建网络PatchMatchNet(二):dtu数据集介绍及PatchMatchNet中加载数据部分代码解析