基于VGG对五种类别图片的迁移学习
2023-09-11 14:19:30 时间
数据集的介绍
分为训练集和测试集两个部分,每个部分都包含5个类别的数据,分别为汽车、恐龙、大象、花以及马。
代码实现
主要分为以下的五个步骤:
- 读取本地的图片数据及类别
- VGG模型结构的修改(添加自定义的分类层)
- freeze掉原始的VGG模型
- 编译、训练并保存模型
相关包的导入
import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input
from tensorflow.python.keras import layers
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.callbacks import ModelCheckpoint
读取本地的图片数据及类别
class TransferModel(object):
def __init__(self):
self.train_dir = './data/train'
self.test_dir = './data/test'
self.model_size = (224,224)
self.batch_size = 32
self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
self.base_model = VGG16(include_top=False)
def get_local_data(self):
"""
读取本地的图片数据以及类别
:return:训练数据和测试数据的迭代器
"""
train_gen = self.train_generator.flow_from_directory(directory=self.train_dir,
target_size=self.model_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(directory=self.test_dir,
target_size=self.model_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
return train_gen,test_gen
对train_gen以及test_gen进行打印,可以得到以下的结果:
Found 400 images belonging to 5 classes.
Found 100 images belonging to 5 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65E80>
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65520>
VGG模型结构的修改
def refine_vgg_model(self):
x = self.base_model.outputs[0]
# 采用GlobalAveragePooling2D减少模型的参数
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(1024,activation=tf.nn.relu)(x)
y_predict = layers.Dense(5,activation=tf.nn.softmax)(x)
model = keras.Model(inputs=self.base_model.inputs,outputs=y_predict)
return model
先是对VGG_nontop模型的输出进行GlobalAveragePooling2D减少全连接的参数,然后自定义构建两个全连接层,获得新的模型。
freeze掉原始的VGG模型参数
def freeze_vgg_model(self):
for layer in self.base_model.layers:
layer.trainable = False
编译、训练并保存模型
def compile(self,model):
model.compile(optimizer=adam_v2.Adam(),
loss=sparse_categorical_crossentropy,
metrics=['accuracy'])
def fit(self,model,train_gen,test_gen):
check = ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5',
monitor='val_accuracy',
save_best_only=True,
save_weights_only=True,
mode='auto',
period=1)
model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[check])
主函数
if __name__ == '__main__':
tm = TransferModel()
train_gen,test_gen = tm.get_local_data()
model = tm.refine_vgg_model()
# print(tm.refine_vgg_model().summary())
tm.freeze_vgg_model()
tm.compile(model)
tm.fit(model,train_gen,test_gen)
训练结束后将得到以下文件:
模型预测
def predict(self,model):
model.load_weights('./ckpt/transfer_02-0.93.h5')
image = load_img('./data/test/bus/300.jpg',target_size=(224,224))
# print(image)
image = img_to_array(image)
# print("图片的形状:", image.shape)
# 形状从3维度修改成4维
img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
# print("改变形状结果:", img.shape)
# 3、处理图像内容,归一化处理等,进行预测
img = preprocess_input(img)
print(img.shape)
y_predict = model.predict(img)
index = np.argmax(y_predict, axis=1)
#
print(self.label_dict[str(index[0])])
预测结果如下:
相关文章
- 华为鲲鹏分析扫描工具,代码迁移工具,Docker安装
- 学习ASP.NET Core Blazor编程系列四——迁移
- [DeeplearningAI笔记]卷积神经网络2.9-2.10迁移学习与数据增强
- Google Earth Engine——北纬85度和南纬60度之间所有地区到最近的人口密集区的迁移时间数据集
- es库数据迁移
- 关于数据迁移,测试应该做什么?
- 实用操作--迁移到Spring Boot 3 和 Spring 6 需要关注的JAVA新特性
- Kafka的灵魂伴侣Logi-KafkaManger(4)之运维管控–集群运维(数据迁移和集群在线升级)
- 预训练模型迁移学习
- 迁移学习_pytorch简单实战
- 《Core Data应用开发实践指南》一3.5 通过迁移管理器来迁移数据
- (续)使用MindSpore_hub 进行 加载模型用于推理或迁移学习
- 【转载】 迁移学习(Transfer learning),多任务学习(Multitask learning)和端到端学习(End-to-end deep learning)
- PyTorch深度学习实战 | 搭建卷积神经网络进行图像分类与图像风格迁移
- 从Unix到Linux的迁移部署:移植、升级和测试
- 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)
- EF 数据迁移 常见错误