Deep Learning 3: Transfer Learning, Saving Features in Batches and Training a Fully Connected Head
2023-09-14 09:15:04
1. Save features in batches with h5py, using global average pooling to shrink the feature maps and reduce memory usage
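The trick the script below relies on is h5py's resizable datasets: create the dataset with maxshape=(None, ...) on the first axis, then resize and write to grow it batch by batch. A minimal standalone sketch of just that mechanism (the file name demo.h5 is only for illustration):

import h5py
import numpy as np

with h5py.File('demo.h5', 'w') as f:
    # unlimited first axis so the dataset can be extended later
    d = f.create_dataset('x', data=np.zeros((2, 4)), maxshape=(None, 4), chunks=True)
    d.resize((5, 4))          # grow from 2 rows to 5
    d[2:5] = np.ones((3, 4))  # write the appended rows
    print(d.shape)            # (5, 4)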
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import applications, Model
from tensorflow.keras.layers import GlobalAveragePooling2D
import numpy as np
import h5py
import math
import os, gc

batch_size = 100
shape = (240, 480)  # target image height and width
# models = ["VGG16", "VGG19", "InceptionV3", "ResNet50"]  # other backbones to try
models = ["EfficientNetB0"]
file_name = 'data_EfficientNetB0_240.h5'
h5f = h5py.File(file_name, 'a')  # 'a' (append): recent h5py versions no longer default to it
def save_h5(h5f, data, target):
    """Append `data` along the first axis of dataset `target`, creating it if absent."""
    shape_list = list(data.shape)
    if target not in h5f:
        shape_list[0] = None  # unlimited first axis so the dataset can grow later
        h5f.create_dataset(target, data=data, maxshape=tuple(shape_list), chunks=True)
        return
    dataset = h5f[target]
    len_old = dataset.shape[0]
    len_new = len_old + data.shape[0]
    shape_list[0] = len_new
    dataset.resize(tuple(shape_list))  # grow the first axis to fit the new batch
    dataset[len_old:len_new] = data    # write the new rows at the end
for modelname in models:
    print(modelname)
    # look the constructor up on keras.applications instead of building an eval string
    base = getattr(applications, modelname)(include_top=False,
                                            input_shape=shape + (3,),  # add the channel axis
                                            weights='imagenet')
    model_output = GlobalAveragePooling2D()(base.layers[-1].output)
    model = Model(inputs=base.input, outputs=model_output)
    # model.summary(); import sys; sys.exit()

    # note: tf.keras EfficientNet models already rescale inputs internally, so for
    # EfficientNetB0 this rescale may normalize twice; drop it if accuracy suffers
    datagen = ImageDataGenerator(rescale=1.0 / 255)
    root_path = './features/'
    if not os.path.exists(root_path):
        os.makedirs(root_path)
    data_path = '训练集'  # root directory of the training images
    dirList = ['']
    for path in dirList:
        generator = datagen.flow_from_directory(data_path + path,
                                                target_size=shape,
                                                batch_size=batch_size,
                                                class_mode='categorical',
                                                shuffle=True)
        print(generator.class_indices)  # class-name -> label index mapping
        step = math.ceil(len(generator.classes) / generator.batch_size)
        for i in range(step):
            x, y = next(generator)
            features = model.predict(x)  # pooled bottleneck features for this batch
            save_h5(h5f, data=np.array(features), target='features')
            save_h5(h5f, data=np.array(y), target='label')
            print(step, i, y.shape, features.shape)
            del x, y
            gc.collect()
        # caution: with shuffle=True the saved feature order does not match
        # generator.filenames, so this is only an unordered list of file names
        np.save(open('%s%s_fileName_%s.npy' % (root_path, modelname, path), 'wb'),
                generator.filenames)
    gc.collect()
h5f.close()  # flush everything to disk
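After the extraction loop finishes, it is worth confirming what actually landed in the h5 file before moving on to training. A quick read-only check, assuming the file name used above:

import h5py

with h5py.File('data_EfficientNetB0_240.h5', 'r') as f:
    print(f['features'].shape)  # (num_images, 1280) for EfficientNetB0 after pooling
    print(f['label'].shape)     # (num_images, num_classes), one-hot labels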
2. Load the saved features in batches with a custom data generator and train the fully connected head
from tensorflow.keras.models import load_model
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense,Dropout,Input,GlobalAveragePooling2D
import numpy as np
import h5py
from sklearn.model_selection import train_test_split
def data_gen(dtype='train', rate=0.2, batch_size=500):
    file_name = 'data_EfficientNetB0_240.h5'
    h5f = h5py.File(file_name, 'r')  # read-only access is enough here
    while True:  # Keras expects generators to loop forever
        for index in range(0, h5f['features'].shape[0], batch_size):
            data = h5f['features'][index:index + batch_size]
            i_label = h5f['label'][index:index + batch_size]
            # the fixed random_state keeps each chunk's train/test split identical
            # across epochs, so training and validation samples never mix
            X_train, X_test, y_train, y_test = train_test_split(
                data, i_label, train_size=rate, test_size=1 - rate, random_state=15)
            if dtype == 'train':
                yield (X_train, y_train)
            else:
                yield (X_test, y_test)
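# Optional sanity check (a sketch): pull one batch from the generator and
# confirm the shapes before wiring it into training.
# gen = data_gen(dtype='train', rate=0.9, batch_size=500)
# xb, yb = next(gen)
# print(xb.shape, yb.shape)  # e.g. (450, 1280) and (450, 137)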
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, History
file_name = 'data_EfficientNetB0_240.h5'
h5f = h5py.File(file_name, 'r')
dshape = h5f['features'].shape  # (num_samples, feature_dim)
if True:  # build a fresh head; flip to False to resume from the checkpoint below
    inputs = Input(shape=dshape[1:])
    # x = GlobalAveragePooling2D()(inputs)  # unnecessary: features were pooled in part 1
    x = Dense(200, activation='relu')(inputs)
    x = Dropout(0.5)(x)
    output = Dense(137, activation='softmax')(x)  # 137 classes in this dataset
    model3 = Model(inputs=inputs, outputs=output)
    model3.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
else:
    model3 = load_model('model-fcn_all_b0.h5')
model3.summary()
if True:
    history = History()
    model_checkpoint = ModelCheckpoint('model-fcn_all_b0.h5', monitor='val_loss',
                                       save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_acc', patience=50, verbose=2, mode='auto')
    callbacks = [
        history,
        model_checkpoint,
        early_stopping,
    ]
    rate = 0.9         # fraction of each chunk used for training
    batch_size = 5000  # rows read from the h5 file per generator step
    # model.fit accepts generators in TF 2.x; fit_generator is deprecated
    model3.fit(data_gen(dtype='train', rate=rate, batch_size=batch_size),
               steps_per_epoch=dshape[0] // batch_size,
               epochs=1000,
               validation_data=data_gen(dtype='test', rate=rate, batch_size=batch_size),
               validation_steps=dshape[0] // batch_size,
               callbacks=callbacks,
               verbose=1)  # shuffle is ignored for generator input, so it is omitted
loss, acc = model3.evaluate(data_gen(dtype='test', rate=rate, batch_size=batch_size),
                            steps=dshape[0] // batch_size)
print('loss:', loss, 'acc:', acc)
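Once the head is trained, predicting on raw images requires putting the frozen backbone back in front of it. A minimal sketch of stitching the two together, assuming the backbone, input size, and checkpoint name from part 1 (this composition step is not part of the original scripts):

from tensorflow.keras import applications, Model
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import load_model

base = applications.EfficientNetB0(include_top=False, input_shape=(240, 480, 3),
                                   weights='imagenet')
pooled = GlobalAveragePooling2D()(base.layers[-1].output)
head = load_model('model-fcn_all_b0.h5')  # the fully connected head trained above
full = Model(inputs=base.input, outputs=head(pooled))
# full.predict(images) now maps preprocessed images straight to the 137 classes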