TF CNN: Training and evaluating a CNN(2+2) model on the CIFAR-10 dataset (with TensorBoard visualization)
Contents
1. Training code for the CNN(2+2) model on CIFAR-10
from datetime import datetime
import time
import tensorflow as tf
import cifar10

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    # Replaces the deprecated tf.contrib.framework.get_or_create_global_step().
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,     # where event logs and checkpoints are written
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),  # stop after max_steps batches
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  # Override the flag defaults in code; the assignments must match the
  # declared flag types (string, integer, boolean).
  FLAGS.train_dir = 'cifar10_train/'
  FLAGS.max_steps = 1000000
  FLAGS.log_device_placement = False
  FLAGS.log_frequency = 10
  tf.app.run()
Console output
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
2018-09-21 11:15:53.399945: step 0, loss = 4.67 (0.7 examples/sec; 177.888 sec/batch)
2018-09-21 11:17:13.770461: step 10, loss = 4.62 (15.9 examples/sec; 8.037 sec/batch)
2018-09-21 11:19:10.122213: step 20, loss = 4.36 (11.0 examples/sec; 11.635 sec/batch)
2018-09-21 11:21:01.145664: step 30, loss = 4.34 (11.5 examples/sec; 11.102 sec/batch)
2018-09-21 11:22:55.463296: step 40, loss = 4.37 (11.2 examples/sec; 11.432 sec/batch)
2018-09-21 11:24:43.938444: step 50, loss = 4.45 (11.8 examples/sec; 10.848 sec/batch)
2018-09-21 11:26:36.091383: step 60, loss = 4.29 (11.4 examples/sec; 11.215 sec/batch)
2018-09-21 11:28:27.229967: step 70, loss = 4.12 (11.5 examples/sec; 11.114 sec/batch)
2018-09-21 11:30:24.759522: step 80, loss = 4.04 (10.9 examples/sec; 11.753 sec/batch)
2018-09-21 11:32:04.392507: step 90, loss = 4.14 (12.8 examples/sec; 9.963 sec/batch)
2018-09-21 11:33:50.161788: step 100, loss = 4.08 (12.1 examples/sec; 10.577 sec/batch)
2018-09-21 11:35:27.867156: step 110, loss = 4.05 (13.1 examples/sec; 9.771 sec/batch)
2018-09-21 11:36:59.189017: step 120, loss = 3.99 (14.0 examples/sec; 9.132 sec/batch)
2018-09-21 11:38:44.246431: step 130, loss = 3.93 (12.2 examples/sec; 10.506 sec/batch)
2018-09-21 11:40:27.267226: step 140, loss = 4.12 (12.4 examples/sec; 10.302 sec/batch)
2018-09-21 11:42:20.492360: step 150, loss = 3.94 (11.3 examples/sec; 11.323 sec/batch)
2018-09-21 11:44:05.324174: step 160, loss = 3.93 (12.2 examples/sec; 10.483 sec/batch)
2018-09-21 11:45:45.123575: step 170, loss = 3.80 (12.8 examples/sec; 9.980 sec/batch)
2018-09-21 11:47:31.441841: step 180, loss = 3.95 (12.0 examples/sec; 10.632 sec/batch)
2018-09-21 11:49:19.129222: step 190, loss = 3.90 (11.9 examples/sec; 10.769 sec/batch)
2018-09-21 11:50:58.325049: step 200, loss = 4.15 (12.9 examples/sec; 9.920 sec/batch)
2018-09-21 11:52:34.784594: step 210, loss = 3.92 (13.3 examples/sec; 9.646 sec/batch)
2018-09-21 11:54:32.453522: step 220, loss = 3.81 (10.9 examples/sec; 11.767 sec/batch)
2018-09-21 11:56:33.002429: step 230, loss = 3.87 (10.6 examples/sec; 12.055 sec/batch)
2018-09-21 11:58:19.417427: step 240, loss = 3.67 (12.0 examples/sec; 10.641 sec/batch)
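The throughput figures in the log lines above follow directly from _LoggerHook's arithmetic: over each window of log_frequency steps it divides the elapsed wall time by the step count, and examples/sec additionally multiplies by the batch size (128 by default in the cifar10 module). A standalone check against the step-10 log line:

```python
def hook_throughput(duration, log_frequency=10, batch_size=128):
    """Reproduce _LoggerHook's reported metrics for one logging window.

    duration: wall-clock seconds elapsed over log_frequency steps.
    """
    examples_per_sec = log_frequency * batch_size / duration
    sec_per_batch = duration / log_frequency
    return examples_per_sec, sec_per_batch

# The step-10 line reports 8.037 sec/batch, i.e. the 10-step window
# took about 80.37 s of wall time:
eps, spb = hook_throughput(80.37)
print('%.1f examples/sec; %.3f sec/batch' % (eps, spb))
# -> 15.9 examples/sec; 8.037 sec/batch, matching the log line
```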
2. Evaluating the CNN(2+2) model
Evaluating the model's accuracy on the CIFAR-10 test set: by around 60,000 steps the model already reaches 85.99% accuracy, at 100,000 steps it reaches 86.38%, and after 150,000 steps the accuracy stabilizes at roughly 86.66%.
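The accuracy numbers above come from evaluating precision@1 on the test set, which TensorFlow's companion cifar10_eval.py script computes with tf.nn.in_top_k and a running counter over test batches. A minimal pure-Python sketch of that precision computation (function name and toy values are illustrative, not from the original script):

```python
def precision_at_1(logits_batches, label_batches):
    """Fraction of examples whose highest logit matches the label,
    mirroring tf.nn.in_top_k(logits, labels, 1) aggregated over batches."""
    true_count = 0
    total_count = 0
    for logits, labels in zip(logits_batches, label_batches):
        for row, label in zip(logits, labels):
            # argmax over the per-class scores (10 classes for CIFAR-10;
            # 3 in the toy data below for brevity)
            pred = max(range(len(row)), key=lambda i: row[i])
            true_count += int(pred == label)
            total_count += 1
    return true_count / total_count

# Toy example with made-up logits:
logits_batches = [[[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]],
                  [[0.0, 0.1, 3.0]]]
label_batches = [[1, 0], [2]]
print(precision_at_1(logits_batches, label_batches))  # 1.0: all three correct
```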
3. Viewing the loss curve in TensorBoard
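MonitoredTrainingSession writes event files (including the loss and other summaries registered by the cifar10 module) to the checkpoint directory, so the loss curve can be viewed by pointing TensorBoard at that directory. The path below assumes the train_dir set in the training code:

```shell
tensorboard --logdir=cifar10_train/
# then open http://localhost:6006 in a browser and select the Scalars tab
```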