TensorFlow2-实战-手写数字识别(二):模型版【初始化参数】-->【循环(①根据参数W、B通过模型前向计算,计算出输入X对应的输出Y;②计算Loss;③计算梯度;④利用梯度下降来更新参数)】
时间:2023-09-27 14:20:41
import os

# TF_CPP_MIN_LOG_LEVEL must be set BEFORE TensorFlow is imported: the native
# logger reads the variable once, on its first log call, and messages emitted
# while the library loads would otherwise slip through unfiltered.
# '2' suppresses INFO and WARNING messages from the C++ backend.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import pandas as pd  # noqa: E402
import tensorflow as tf  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras import layers, optimizers, datasets  # noqa: E402
# Load the MNIST dataset: a (train, validation) pair of (images, labels) tuples.
mnist_dataset = datasets.mnist.load_data()
dataset_train, dataset_val = mnist_dataset[0], mnist_dataset[1]
(x, y) = dataset_train
# NOTE(review): the validation split is unpacked here but never used by the
# training code visible in this file.
(x_val, y_val) = dataset_val
print("type(x) = {0}, type(y) = {1}".format(type(x), type(y)))

# Convert the ndarray images to a float32 Tensor and scale pixels to [0, 1].
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
print("x.shape = {0}, x[0] = \n{1}".format(x.shape, pd.DataFrame(x[0].numpy())))  # x.shape = (60000, 28, 28)

# Convert the integer class labels to an int32 Tensor.
y = tf.convert_to_tensor(y, dtype=tf.int32)
print("y.shape = {0}, y = {1}".format(y.shape, y))  # y.shape = (60000,), y = [5 0 4 ... 5 6 8]

# One-hot encode the labels so they can be compared with the 10-way model output.
y = tf.one_hot(y, depth=10)
print("将y转为one-hot编码之后:y.shape = {0}, y = \n{1}".format(y.shape, y))  # after one-hot: y.shape = (60000, 10)

# Build the training pipeline: slice (x, y) per sample, then group into batches of 200.
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200)
# Assemble the network layer by layer through the Keras Sequential API:
# two ReLU hidden layers followed by a 10-unit output layer with no activation.
model = keras.Sequential()
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(10))

# Plain stochastic gradient descent performs the parameter updates.
optimizer = optimizers.SGD(learning_rate=1e-3)
# Train for a single epoch.
def train_epoch(epoch):
    """Run one full pass over ``train_dataset``, updating ``model`` in place.

    For each batch this performs a forward pass, computes the sum of squared
    errors divided by the batch size, derives the gradients of that loss with
    respect to the model's trainable variables, and applies one optimizer
    step. The loss is printed every 100 batches.

    Args:
        epoch: Index of the current epoch; used only in the progress message.
    """
    for step, (images, targets) in enumerate(train_dataset):
        # Flatten every image: [b, 28, 28] => [b, 784].
        flat_images = tf.reshape(images, (-1, 28 * 28))
        batch_size = flat_images.shape[0]

        # The tape records the forward computation so that the gradients can
        # be obtained from it automatically afterwards.
        with tf.GradientTape() as tape:
            # Step 1: forward pass, [b, 784] => [b, 10].
            predictions = model(flat_images)
            # Step 2: squared-error loss, summed and averaged over the batch.
            squared_error = tf.square(predictions - targets)
            loss = tf.reduce_sum(squared_error) / batch_size

        # Step 3: gradients of the loss w.r.t. every trainable variable.
        trainable = model.trainable_variables
        gradients = tape.gradient(loss, trainable)

        # Step 4: gradient-descent update (w' = w - lr * grad).
        optimizer.apply_gradients(zip(gradients, trainable))

        if step % 100 != 0:
            continue
        print("epoch = {0}, batch_idx = {1}, loss = {2}".format(epoch, step, loss.numpy()))
def train(epochs=30):
    """Train the model by running ``train_epoch`` once per epoch.

    Args:
        epochs: Number of passes over the training data. Defaults to 30,
            the value that used to be hard-coded, so ``train()`` behaves
            exactly as before.
    """
    for epoch in range(epochs):
        train_epoch(epoch)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
打印结果:
type(x) = <class 'numpy.ndarray'>, type(y) = <class 'numpy.ndarray'>
x.shape = (60000, 28, 28), x[0] =
0 1 2 3 4 ... 23 24 25 26 27
0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
1 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
2 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
3 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
4 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
5 0.0 0.0 0.0 0.0 0.000000 ... 0.498039 0.0 0.0 0.0 0.0
6 0.0 0.0 0.0 0.0 0.000000 ... 0.250980 0.0 0.0 0.0 0.0
7 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
8 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
9 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
10 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
11 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
12 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
13 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
14 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
15 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
16 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
17 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
18 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
19 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
20 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
21 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
22 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
23 0.0 0.0 0.0 0.0 0.215686 ... 0.000000 0.0 0.0 0.0 0.0
24 0.0 0.0 0.0 0.0 0.533333 ... 0.000000 0.0 0.0 0.0 0.0
25 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
26 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
27 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.0 0.0 0.0 0.0
[28 rows x 28 columns]
y.shape = (60000,), y = [5 0 4 ... 5 6 8]
将y转为one-hot编码之后:y.shape = (60000, 10), y =
[[0. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 1. 0.]]
epoch = 0, batch_idx = 0, loss = 2.328511953353882
epoch = 0, batch_idx = 100, loss = 1.009535312652588
epoch = 0, batch_idx = 200, loss = 0.8182420134544373
epoch = 1, batch_idx = 0, loss = 0.7303786277770996
epoch = 1, batch_idx = 100, loss = 0.710280179977417
epoch = 1, batch_idx = 200, loss = 0.6227332353591919
epoch = 2, batch_idx = 0, loss = 0.5902734994888306
epoch = 2, batch_idx = 100, loss = 0.6123036742210388
epoch = 2, batch_idx = 200, loss = 0.5449265837669373
epoch = 3, batch_idx = 0, loss = 0.5221412181854248
epoch = 3, batch_idx = 100, loss = 0.5586401224136353
epoch = 3, batch_idx = 200, loss = 0.49894028902053833
epoch = 4, batch_idx = 0, loss = 0.47923576831817627
epoch = 4, batch_idx = 100, loss = 0.5219769477844238
epoch = 4, batch_idx = 200, loss = 0.46610695123672485
epoch = 5, batch_idx = 0, loss = 0.44889944791793823
epoch = 5, batch_idx = 100, loss = 0.49393826723098755
epoch = 5, batch_idx = 200, loss = 0.44056564569473267
epoch = 6, batch_idx = 0, loss = 0.4255021810531616
epoch = 6, batch_idx = 100, loss = 0.4713611900806427
epoch = 6, batch_idx = 200, loss = 0.41985756158828735
epoch = 7, batch_idx = 0, loss = 0.40636977553367615
epoch = 7, batch_idx = 100, loss = 0.45304715633392334
epoch = 7, batch_idx = 200, loss = 0.40254470705986023
epoch = 8, batch_idx = 0, loss = 0.39018934965133667
epoch = 8, batch_idx = 100, loss = 0.4377802312374115
epoch = 8, batch_idx = 200, loss = 0.38769394159317017
epoch = 9, batch_idx = 0, loss = 0.3761257529258728
epoch = 9, batch_idx = 100, loss = 0.42440712451934814
epoch = 9, batch_idx = 200, loss = 0.37483566999435425
epoch = 10, batch_idx = 0, loss = 0.3637397885322571
epoch = 10, batch_idx = 100, loss = 0.41257765889167786
epoch = 10, batch_idx = 200, loss = 0.3634665012359619
epoch = 11, batch_idx = 0, loss = 0.35271918773651123
epoch = 11, batch_idx = 100, loss = 0.4018203616142273
epoch = 11, batch_idx = 200, loss = 0.3533402979373932
epoch = 12, batch_idx = 0, loss = 0.3428434729576111
epoch = 12, batch_idx = 100, loss = 0.39214104413986206
epoch = 12, batch_idx = 200, loss = 0.3441687822341919
epoch = 13, batch_idx = 0, loss = 0.333955317735672
epoch = 13, batch_idx = 100, loss = 0.3832662105560303
epoch = 13, batch_idx = 200, loss = 0.33575889468193054
epoch = 14, batch_idx = 0, loss = 0.3258323669433594
epoch = 14, batch_idx = 100, loss = 0.3750748336315155
epoch = 14, batch_idx = 200, loss = 0.32803112268447876
epoch = 15, batch_idx = 0, loss = 0.3183256685733795
epoch = 15, batch_idx = 100, loss = 0.3674846291542053
epoch = 15, batch_idx = 200, loss = 0.32090339064598083
epoch = 16, batch_idx = 0, loss = 0.31145304441452026
epoch = 16, batch_idx = 100, loss = 0.3604132831096649
epoch = 16, batch_idx = 200, loss = 0.3143956661224365
epoch = 17, batch_idx = 0, loss = 0.3050137758255005
epoch = 17, batch_idx = 100, loss = 0.3537922203540802
epoch = 17, batch_idx = 200, loss = 0.30833786725997925
epoch = 18, batch_idx = 0, loss = 0.2989848256111145
epoch = 18, batch_idx = 100, loss = 0.3475683629512787
epoch = 18, batch_idx = 200, loss = 0.30268821120262146
epoch = 19, batch_idx = 0, loss = 0.2933279275894165
epoch = 19, batch_idx = 100, loss = 0.3417797088623047
epoch = 19, batch_idx = 200, loss = 0.29743343591690063
epoch = 20, batch_idx = 0, loss = 0.2880209684371948
epoch = 20, batch_idx = 100, loss = 0.33625274896621704
epoch = 20, batch_idx = 200, loss = 0.2924404740333557
epoch = 21, batch_idx = 0, loss = 0.2830372750759125
epoch = 21, batch_idx = 100, loss = 0.3309747576713562
epoch = 21, batch_idx = 200, loss = 0.2876732051372528
epoch = 22, batch_idx = 0, loss = 0.2783474326133728
epoch = 22, batch_idx = 100, loss = 0.32599449157714844
epoch = 22, batch_idx = 200, loss = 0.2831261456012726
epoch = 23, batch_idx = 0, loss = 0.27390143275260925
epoch = 23, batch_idx = 100, loss = 0.3213098645210266
epoch = 23, batch_idx = 200, loss = 0.27885332703590393
epoch = 24, batch_idx = 0, loss = 0.2696850299835205
epoch = 24, batch_idx = 100, loss = 0.3168952465057373
epoch = 24, batch_idx = 200, loss = 0.2747831344604492
epoch = 25, batch_idx = 0, loss = 0.2656662166118622
epoch = 25, batch_idx = 100, loss = 0.3127139210700989
epoch = 25, batch_idx = 200, loss = 0.2708702087402344
epoch = 26, batch_idx = 0, loss = 0.2618401050567627
epoch = 26, batch_idx = 100, loss = 0.3087657690048218
epoch = 26, batch_idx = 200, loss = 0.26713475584983826
epoch = 27, batch_idx = 0, loss = 0.25819507241249084
epoch = 27, batch_idx = 100, loss = 0.30497828125953674
epoch = 27, batch_idx = 200, loss = 0.26355743408203125
epoch = 28, batch_idx = 0, loss = 0.2547600567340851
epoch = 28, batch_idx = 100, loss = 0.30135565996170044
epoch = 28, batch_idx = 200, loss = 0.2601349949836731
epoch = 29, batch_idx = 0, loss = 0.251463383436203
epoch = 29, batch_idx = 100, loss = 0.2978888750076294
epoch = 29, batch_idx = 200, loss = 0.25687384605407715
Process finished with exit code 0
相关文章
- MNIST数据集合在PaddlePaddle环境下使用简单神经网络识别效果
- 基于BP神经网络的调制方式识别算法MATLAB仿真,识别不同SNR下的MFSK和MPSK
- 分别使用BP/RBF/GRNN神经网络识别航迹异常matlab仿真
- C#,基于视频的目标识别算法(Moving Object Detection)的原理、挑战及其应用
- C++ OpenCV 图像转换,识别图像轮廓,画矩形
- 语音识别六十年
- Android | 教你如何在安卓上实现通用卡证识别,一键各种卡绑定
- Qt数据库应用22-文件编码格式识别
- 新手指导:教你如何查看识别hadoop是32位还是64位
- 简单、直观的实现优于复杂、难懂的实现,最近开发扑克识别过程的总结
- QTP——使用DOM识别树形节点进行Web测试