zl程序教程

您现在的位置是:首页 >  其他

当前栏目

基于VGG对五种类别图片的迁移学习

迁移学习 基于 图片 五种 类别
2023-09-11 14:19:30 时间

数据集的介绍

在这里插入图片描述
分为训练集和测试集两个部分,每个部分都包含5个类别的数据,分别为汽车、恐龙、大象、花以及马。

代码实现

主要分为以下的五个步骤:

  1. 读取本地的图片数据及类别
  2. VGG模型结构的修改(添加自定义的分类层)
  3. freeze掉原始的VGG模型
  4. 编译、训练并保存模型

相关包的导入

import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input
from tensorflow.python.keras import layers
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.callbacks import ModelCheckpoint

读取本地的图片数据及类别

class TransferModel(object):

    def __init__(self):
        self.train_dir = './data/train'
        self.test_dir = './data/test'
        self.model_size = (224,224)
        self.batch_size = 32

        self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
        self.base_model = VGG16(include_top=False)

    def get_local_data(self):
        """
        读取本地的图片数据以及类别
        :return:训练数据和测试数据的迭代器
        """
        train_gen = self.train_generator.flow_from_directory(directory=self.train_dir,
                                                             target_size=self.model_size,
                                                             batch_size=self.batch_size,
                                                             class_mode='binary',
                                                             shuffle=True)
        test_gen = self.test_generator.flow_from_directory(directory=self.test_dir,
                                                           target_size=self.model_size,
                                                           batch_size=self.batch_size,
                                                           class_mode='binary',
                                                           shuffle=True)
        return train_gen,test_gen

对train_gen以及test_gen进行打印,可以得到以下的结果:

Found 400 images belonging to 5 classes.
Found 100 images belonging to 5 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65E80>
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65520>

VGG模型结构的修改

    def refine_vgg_model(self):
        x = self.base_model.outputs[0]
		
		# 采用GlobalAveragePooling2D减少模型的参数
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(1024,activation=tf.nn.relu)(x)
        y_predict = layers.Dense(5,activation=tf.nn.softmax)(x)

        model = keras.Model(inputs=self.base_model.inputs,outputs=y_predict)


        return model

先是对VGG_nontop模型的输出进行GlobalAveragePooling2D减少全连接的参数,然后自定义构建两个全连接层,获得新的模型。

freeze掉原始的VGG模型参数

    def freeze_vgg_model(self):
        for layer in self.base_model.layers:
            layer.trainable = False

编译、训练并保存模型

    def compile(self,model):
        model.compile(optimizer=adam_v2.Adam(),
                      loss=sparse_categorical_crossentropy,
                      metrics=['accuracy'])

    def fit(self,model,train_gen,test_gen):
        check = ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5',
                                monitor='val_accuracy',
                                save_best_only=True,
                                save_weights_only=True,
                                mode='auto',
                                period=1)
        model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[check])

主函数

if __name__ == '__main__':
    tm = TransferModel()
    train_gen,test_gen = tm.get_local_data()

    model = tm.refine_vgg_model()
    # print(tm.refine_vgg_model().summary())
    tm.freeze_vgg_model()
    tm.compile(model)
    tm.fit(model,train_gen,test_gen)

训练结束后将得到以下文件:
在这里插入图片描述

模型预测

    def predict(self,model):
        model.load_weights('./ckpt/transfer_02-0.93.h5')

        image = load_img('./data/test/bus/300.jpg',target_size=(224,224))
        # print(image)
        image = img_to_array(image)
        # print("图片的形状:", image.shape)

        # 形状从3维度修改成4维
        img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
        # print("改变形状结果:", img.shape)

        # 3、处理图像内容,归一化处理等,进行预测
        img = preprocess_input(img)
        print(img.shape)
        y_predict = model.predict(img)
        index = np.argmax(y_predict, axis=1)
        #
        print(self.label_dict[str(index[0])])

预测结果如下:
在这里插入图片描述