TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线
2023-09-14 09:04:51 时间
TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线
目录
输出结果
设计代码
import tensorflow as tf
from sklearn.datasets import load_digits
#from sklearn.cross_validation import train_test_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
# load data
digits = load_digits() X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)
def add_layer(inputs, in_size, out_size, layer_name, activation_function=None, ):
# add one more layer and return the output of this layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, )
Wx_plus_b = tf.matmul(inputs, Weights) + biases
# here to dropout
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b, )
tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs
# define placeholder for inputs to network
keep_prob = tf.placeholder(tf.float32)
xs = tf.placeholder(tf.float32, [None, 64])
ys = tf.placeholder(tf.float32, [None, 10])
# add output layer
l1 = add_layer(xs, 64, 50, 'l1', activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function=tf.nn.softmax)
# the loss between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1]))
tf.summary.scalar ('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.Session()
merged = tf.summary.merge_all()
# summary writer goes in here
train_writer = tf.summary.FileWriter("logs4/train", sess.graph)
test_writer = tf.summary.FileWriter("logs4/test", sess.graph)
sess.run(tf.global_variables_initializer())
for i in range(500):
# here to determine the keeping probability
sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 0.5})
if i % 50 == 0:
# record loss
train_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})
test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test, keep_prob: 1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)
相关文章
TF:利用sklearn自带数据集使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线
相关文章
- 2021电赛F题智能送药小车方案分析(openMV数字识别,红线循迹,STM32HAL库freeRTOS,串级PID快速学习,小车自动返回)[通俗易懂]
- python屏幕文字识别_python识别图片文字
- python截图识别文字_Python文字截图识别OCR工具实例解析
- 语音识别系列︱利用达摩院ModelScope进行语音识别+标点修复(四)
- js自动生成二维码_jquery 生成二维码无法识别
- 基于深度学习的【木板】表面缺陷检测与识别
- [CVPR | 论文简读] 基于双交叉注意学习的细粒度视觉分类和对象再识别
- 内置AI算法的智能分析网关,如何将智能识别技术应用到生活场景中?
- 解决:无法将“php”项识别为 cmdlet、函数、脚本文件或可运行程序的名称
- BioRxiv|PointVS:识别重要的蛋白质-药物关联的机器学习打分函数
- PHP批量识别Nginx网站日志内的百度真假爬虫记录
- keras图片数字识别入门AI机器学习
- 洞察 | 联邦学习、同态加密、模糊提取器?隐私保护增强的新一代生物识别技术了解一下
- 全面解析反欺诈(羊毛盾)API,助你识别各类欺诈风险
- TensorFlow学习笔记(三)MNIST数字识别问题详解大数据
- Linux系统是怎样识别硬盘设备和硬盘分区的?
- 设备Linux如何识别外部存储设备(linux识别存储)
- 从引擎声预测车辆故障!深度学习应用于通用声音识别
- 如何在Linux系统中正确识别网卡?(linux网卡识别)
- Linux无法识别U盘:解决之道(linux无法识别u盘)
- Linux如何识别USB设备解决你的设备无法使用问题(linux识别usb)
- Oracle主键无法识别排查及解决办法(oracle主键无法识别)
- Oracle RAW主键用于唯一识别表行(oracle raw主键)
- phpfeof用来识别文件末尾字符的方法