图像分类-flower_photos 实验研究
数据集: flower_photos
- daisy: 633张图片 雏菊
- dandelion: 898张图片 蒲公英
- roses: 641张图片 玫瑰
- sunflowers: 699张图片 向日葵
- tulips: 799张图片 郁金香
数据存储在本地磁盘,读取用的是 tf.keras.preprocessing.image_dataset_from_directory(),其中的 image_size 用作 image resize,batch_size 用作 batch
(福利推荐:阿里云、腾讯云、华为云服务器最新限时优惠活动,云服务器1核2G仅88元/年、2核4G仅698元/3年,点击这里立即抢购>>>)
最后的 train_ds = train_ds.shuffle().cache().prefetch(),这样做的目的是减少 IO blocking
下面是模型搭建的代码:
model = tf.keras.Sequential([ tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255), tf.keras.layers.Conv2D(32, 3, activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(32, 3, activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(32, 3, activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(num_classes) ])
此处把 pixel 的 rescale 放进了 Sequential,当做模型搭建的一部分,有利于模型部署
callbacks 里面用到了 ReduceLROnPlateau,Tensorboard,EarlyStopping
上图是训练模型四次的 log 记录图,其中 val_acc 的区间在 [0.6499, 0.6785],这个是正常现象,所以训练出来的模型准确率是会存在波动的
代码地址: https://github.com/MaoXianXin/Tensorflow_tutorial,但是需要在如上图的地方 flower dataset 这个 commit 处开一个新分支,然后找到 3.py 这个脚本,就能重复上图的实验了
因为上面的实验,准确率才 [0.6499, 0.6785],我们需要进行优化,第一个改进是添加 data augmentation,此处我们直接在模型搭建环节添加,代码如下所示
model = tf.keras.Sequential([ tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255), augmentation_dict[args.key], tf.keras.layers.Conv2D(32, 3, activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(32, 3, activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(32, 3, activation='relu'), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(num_classes) ])
augmentation_dict[args.key],这个就是添加的 data augmentation,此处我们只添加单种,具体 data augmentation 种类如下所示
augmentation_dict = { 'RandomFlip': tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"), 'RandomRotation': tf.keras.layers.experimental.preprocessing.RandomRotation(0.2), 'RandomContrast': tf.keras.layers.experimental.preprocessing.RandomContrast(0.2), 'RandomZoom': tf.keras.layers.experimental.preprocessing.RandomZoom(height_factor=0.1, width_factor=0.1), 'RandomTranslation': tf.keras.layers.experimental.preprocessing.RandomTranslation(height_factor=0.1, width_factor=0.1), 'RandomCrop': tf.keras.layers.experimental.preprocessing.RandomCrop(img_height, img_width), 'RandomFlip_prob': RandomFlip_prob("horizontal_and_vertical"), 'RandomRotation_prob': RandomRotation_prob(0.2), 'RandomTranslation_prob': RandomTranslation_prob(height_factor=0.1, width_factor=0.1), }
接下来我们看下实验结果的 log 记录图
可以看到,val_acc 大于 0.6785 (未添加数据增强) 的有 RandomTranslation > RandomRotation_prob > RandomRotation > RandomFlip_prob = RandomFlip > RandomZoom > RandomTranslation_prob
从结果看来,数据增强是有效的,接下来我们进行第二个改进,更换更强的网络模型,我们这里选择 MobileNetV2
这里我们分两种情况进行实验,第一种是把 MobileNetV2 当做 feature extraction 来使用,这个要求我们 freeze 模型的 卷积部分,只训练添加进去的 top-classifier 部分,下面上代码
data_augmentation = tf.keras.Sequential([ augmentation_dict[args.key], ]) preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input base_model = tf.keras.applications.MobileNetV2(input_shape=img_size, include_top=False, weights='imagenet') base_model.trainable = False inputs = tf.keras.Input(shape=img_size) x = data_augmentation(inputs) x = preprocess_input(x) x = base_model(x, training=False) x = tf.keras.layers.GlobalAveragePooling2D()(x) x = tf.keras.layers.Dropout(0.2)(x) outputs = tf.keras.layers.Dense(num_classes)(x) model = tf.keras.Model(inputs, outputs) print(model.summary())
实验结果如下图所示
可以看到,准确率提升很显著,从数据增强的 0.7316 提升到了 0.8937,这主要得益于 pre-trained model 是在 ImageNet 大数据集上做过训练,提取到的特征泛化性更好
为了进一步提升模型的准确率,我们采用第二种方式,对 pre-trained model 做 fine-tune,就是在第一种方式的基础上,我们 unfreeze 部分卷积层,因为浅层的卷积提取的特征都是很基础的特征,意味着很通用,但是深层的卷积提取的特征都是和数据集高度相关的,这里我们要解决的是 flower_photos,所以可以对深层的一部分卷积做训练,以进一步提高模型的准确率
下面上代码
data_augmentation = tf.keras.Sequential([ augmentation_dict[args.key], ]) preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input base_model = tf.keras.applications.MobileNetV2(input_shape=img_size, include_top=False, weights='imagenet') base_model.trainable = True # Let's take a look to see how many layers are in the base model print("Number of layers in the base model: ", len(base_model.layers)) # Fine-tune from this layer onwards fine_tune_at = 100 # Freeze all the layers before the `fine_tune_at` layer for layer in base_model.layers[:fine_tune_at]: layer.trainable = False inputs = tf.keras.Input(shape=img_size) x = data_augmentation(inputs) x = preprocess_input(x) x = base_model(x, training=False) x = tf.keras.layers.GlobalAveragePooling2D()(x) x = tf.keras.layers.Dropout(0.2)(x) outputs = tf.keras.layers.Dense(num_classes)(x) model = tf.keras.Model(inputs, outputs) model.load_weights('./save_models') print(model.summary()) optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4) model.compile( optimizer=optimizer, loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) K.set_value(model.optimizer.learning_rate, 1e-4)
这里有个特别需要注意的地方是 learning_rate 的设置,K.set_value(model.optimizer.learning_rate, 1e-4),这个地方还是我特地查看了下 learning_rate 的 log 记录图才发现的不对劲
可以看到,进行 fine-tune 的话,模型准确率进一步提升,从 0.8937 —> 0.9482
到此为止,我们实现了在 flower_photos 数据集上 val_acc = 0.9482,下一步可能会用 RandAugment 或者 Semi-supervised 来提升模型的泛化能力
代码地址: https://github.com/MaoXianXin/Tensorflow_tutorial
你还在原价购买阿里云、腾讯云、华为云、天翼云产品?那就亏大啦!现在申请成为四大品牌云厂商VIP用户,可以3折优惠价购买云服务器等云产品,并且可享四大云服务商产品终身VIP优惠价,还等什么?赶紧点击下面对应链接免费申请VIP客户吧:
相关文章
- 学生数据库管理系统
- SpringDataJpa 用MySQL语句怎么分页,spring全家桶SpringDataJpa 用MySQL语句怎么分页
- Docker创建MySQL容器模板命令
- Elasticsearch对应MySQL的对应关系
- 使用SpringDataJpa保存(save)报错误:SQL Error: 1062, SQLState: 23000 控制台会报:Duplicate entry ‘数‘ for key ‘PRIMA
- Navicat Premium 连接sqlserver数据库时提示安装Client失败,解决方案
- Mysql查询当前用户所有数据库语句(SHOW DATABASES)
- MySQL语句-查看当前数据库有哪些表(SHOW TABLES)
- MySQL5.0版本以上新增的 information_schema 数据库是什么?
- MariaDB数据库备份之逻辑备份
- MariaDB数据库创建用户
- MariaDB数据库给用户授权
- MariaDB数据库刷新权限表命令
- MariaDB数据库删除用户命令
- PhpStudy 2016搭建-sqli-libs靶场
- MySQL手动注入步骤
- Pikachu靶场-SQL注入-数字型注入(post)过关步骤
- Pikachu靶场-SQL注入-字符型注入(get)过关步骤
- 利用SQL注入漏洞实现MySQL数据库读写文件
- Kali-工具-sqlmap常见用法