zl程序教程

您现在的位置是:首页 >  工具

当前栏目

深度学习6 基于ImageDataGenerator的数据集划分

学习数据 基于 深度 划分
2023-09-14 09:15:04 时间

基于keras的ImageDataGenerator,在划分数据集时不需要人为的把数据分为两部分,可以通过设置datagen.flow_from_directory的subset='validation'或'training'来实现,其中的seed是随机种子控制读取数据的顺序。

1、面向分类任务的数据集划分

from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(rescale=1.0/ 255, 
               rotation_range=180, #整数。随机旋转的度数范围[0,180]
               zoom_range=0.1, # 图片可能以比例(-0.1,0.1)内缩放
               width_shift_range=0.05,  # 图片可能在左右比例(-0.05,0.05)(百分比)内水平移动
               height_shift_range=0.05,  # 图片可能在上下比例(-0.05,0.05)(百分比)内水平移动
               fill_mode='nearest', # {"constant", "nearest", "reflect" or "wrap"} 之一。默认为 'nearest'。输入边界以外的点根据给定的模式填充:
               horizontal_flip=True, 
               vertical_flip=True, 
               validation_split=0.2)#该模式下的迭代器可以扩充数据
datagen = ImageDataGenerator(rescale=1.0/ 255, validation_split=0.2)#()#
seed=1
train_gen = datagen.flow_from_directory(data_path,
                                                target_size=shape,
                                                batch_size=batch_size,
                                                seed=seed,
                                                class_mode='categorical',
                                                subset='training',
                                                shuffle=True)
val_gen = datagen.flow_from_directory(data_path,
                                                target_size=shape,
                                                batch_size=batch_size,
                                                seed=seed,
                                                class_mode='categorical',
                                                subset='validation',
                                                shuffle=True)
print(val_gen.class_indices)#输出类别名称对应的编号

2、面向语义分割的数据集划分

语义分割的数据集存在一定特殊情况,即输入输出都是图像,且只有一个类别。因此,其文件的存储方式应该是IMAGE/0/img,MASK/0/img。其中IMAGE和MASK是一级目录,分别代表训练数据和语义分割结果。0是二级目录,纯粹为了占位,但是两个一级目录下的二级目录必须存在,且文件名应该相同。img指具体的文件,在两个一级目录(IMAGE和MASK)下的名称必须保持一样。
因此,在datagen.flow_from_directory中需要设置class_mode=None

import keras
img_size = (512, 512)
batch_size = 40
seed = 1
path='数据-%s'
#  图像数据生成器,可在该函数内实现数据集增强,validation_split=0.2是分割验证集的比例,即验证集占20%,与下面的subset='training'和subset='validation'相对应
datagen = keras.preprocessing.image.ImageDataGenerator(
            rescale=1.0 / 255,
            #featurewise_center=True,
            #featurewise_std_normalization=True,
            validation_split=0.2
)
#  将上面的图像数据生成器应用于训练集图像及标签,得到训练集的生成器
traingen_imgs = datagen.flow_from_directory(
                directory=path%'IMAGE',
                target_size=img_size,
                class_mode=None,
                batch_size=batch_size,
                seed=seed,
                subset='training',
                color_mode='grayscale'
)
traingen_labels = datagen.flow_from_directory(
                directory=path%'MASK',
                target_size=img_size,
                class_mode=None,
                batch_size=batch_size,
                seed=seed,
                subset='training',
                color_mode='grayscale'
)
#  将上面的图像数据生成器应用于验证集图像及标签,得到验证集的生成器
validgen_imgs = datagen.flow_from_directory(
                directory=path%'IMAGE',
                target_size=img_size,
                class_mode=None,
                batch_size=batch_size,
                seed=seed,
                subset='validation',
                color_mode='grayscale'
)
validgen_labels = datagen.flow_from_directory(
                  directory=path%'MASK',
                  target_size=img_size,
                  class_mode=None,
                  batch_size=batch_size,
                  seed=seed,
                  subset='validation',
                  color_mode='grayscale'
)

数据迭代器的使用

def data_gen(image,mask):
    while True:
        yield (image.next(), mask.next())
model.fit_generator(generator=data_gen(traingen_imgs, traingen_labels),    
                    steps_per_epoch=len(traingen_imgs.classes)/traingen_imgs.batch_size,
                    validation_data=data_gen(validgen_imgs, validgen_labels),
                    validation_steps=len(validgen_imgs.classes)/validgen_imgs.batch_size,
                    callbacks=callbacks,
                    epochs=30)