TensorFlow2 in Practice - Handwritten Digit Recognition (Part 2): Model Version [initialize parameters] --> [loop: ① forward pass through the model with parameters W and B to compute the output Y for input X; ② compute the loss; ③ compute the gradients; ④ update the parameters via gradient descent]

2023-09-27 14:20:41
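The loop in the title is plain mini-batch gradient descent. Written out to match the code below (the loss is the mean squared error over a batch of size b, and the learning rate \eta is 0.001), one update step for each trainable parameter w is:

L = \frac{1}{b} \sum_{i=1}^{b} \lVert \mathrm{out}_i - y_i \rVert^2, \qquad w \leftarrow w - \eta \, \frac{\partial L}{\partial w}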
import os
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Load the dataset
mnist_dataset = datasets.mnist.load_data()
dataset_train, dataset_val = mnist_dataset[0], mnist_dataset[1]
(x, y) = dataset_train
(x_val, y_val) = dataset_val
print("type(x) = {0}, type(y) = {1}".format(type(x), type(y)))

# Convert the ndarrays to Tensors; scale pixel values from [0, 255] to [0, 1]
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
print("x.shape = {0}, x[0] = \n{1}".format(x.shape, pd.DataFrame(x[0].numpy())))  # x.shape = (60000, 28, 28)
y = tf.convert_to_tensor(y, dtype=tf.int32)
print("y.shape = {0}, y = {1}".format(y.shape, y))  # y.shape = (60000,), y = [5 0 4 ... 5 6 8]

# Convert y to one-hot encoding
y = tf.one_hot(y, depth=10)
print("After one-hot encoding: y.shape = {0}, y = \n{1}".format(y.shape, y))  # After one-hot encoding: y.shape = (60000, 10)

# Build the training dataset from (x, y)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200)
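# 60000 samples / 200 per batch = 300 batches per epoch (batch_idx 0..299)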

# Build the model via the Keras API
model = keras.Sequential([
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(10)])
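# The layers are built lazily on the first call; the network then maps
# [b, 784] => [b, 512] => [b, 256] => [b, 10].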

# Create the optimizer
optimizer = optimizers.SGD(learning_rate=0.001)
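# Plain SGD (no momentum): each apply_gradients() call performs w' = w - lr * grad.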


# Train for one epoch
def train_epoch(epoch):
    # Iterate over the dataset one batch at a time
    for batch_idx, (x, y) in enumerate(train_dataset):
        # TensorFlow uses a gradient tape (tf.GradientTape) to record the forward
        # pass, then backpropagates through it to obtain the gradients.
        with tf.GradientTape() as tape:
            # Flatten the input features
            x = tf.reshape(x, (-1, 28 * 28))  # [b, 28, 28] => [b, 784]
            # Step 1: forward pass through the model to get the outputs for the inputs
            out = model(x)  # [b, 784] => [b, 10]
            # Step 2: compute the loss (mean squared error over the batch)
            loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]
        # Step 3: compute the gradients of the loss w.r.t. the model's trainable Variables
        grads = tape.gradient(loss, model.trainable_variables)
        # Step 4: update the parameters via gradient descent (optimize and update
        # w1, w2, w3, b1, b2, b3; w' = w - lr * grad)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if batch_idx % 100 == 0:
            print("epoch = {0}, batch_idx = {1}, loss = {2}".format(epoch, batch_idx, loss.numpy()))


def train():
    for epoch in range(30):
        train_epoch(epoch)


if __name__ == '__main__':
    train()
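Note that the script loads the validation split (x_val, y_val) but never touches it. As a rough follow-up, here is a minimal sketch of an accuracy check on that split, assuming the model and variables from the script above (this block is an addition and did not produce the log below):

# Hedged sketch: evaluate accuracy on the held-out split (x_val, y_val).
x_val = tf.convert_to_tensor(x_val, dtype=tf.float32) / 255.
x_val = tf.reshape(x_val, (-1, 28 * 28))             # [10000, 28, 28] => [10000, 784]
logits = model(x_val)                                # [10000, 10]
pred = tf.cast(tf.argmax(logits, axis=1), tf.int32)  # predicted digit per image
correct = tf.equal(pred, tf.convert_to_tensor(y_val, dtype=tf.int32))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
print("val accuracy = {0}".format(accuracy.numpy()))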

Output:

type(x) = <class 'numpy.ndarray'>, type(y) = <class 'numpy.ndarray'>
x.shape = (60000, 28, 28), x[0] = 
     0    1    2    3         4  ...         23   24   25   26   27
0   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
1   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
2   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
3   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
4   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
5   0.0  0.0  0.0  0.0  0.000000 ...   0.498039  0.0  0.0  0.0  0.0
6   0.0  0.0  0.0  0.0  0.000000 ...   0.250980  0.0  0.0  0.0  0.0
7   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
8   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
9   0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
10  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
11  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
12  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
13  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
14  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
15  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
16  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
17  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
18  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
19  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
20  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
21  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
22  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
23  0.0  0.0  0.0  0.0  0.215686 ...   0.000000  0.0  0.0  0.0  0.0
24  0.0  0.0  0.0  0.0  0.533333 ...   0.000000  0.0  0.0  0.0  0.0
25  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
26  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0
27  0.0  0.0  0.0  0.0  0.000000 ...   0.000000  0.0  0.0  0.0  0.0

[28 rows x 28 columns]
y.shape = (60000,), y = [5 0 4 ... 5 6 8]
After one-hot encoding: y.shape = (60000, 10), y = 
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]]
epoch = 0, batch_idx = 0, loss = 2.328511953353882
epoch = 0, batch_idx = 100, loss = 1.009535312652588
epoch = 0, batch_idx = 200, loss = 0.8182420134544373
epoch = 1, batch_idx = 0, loss = 0.7303786277770996
epoch = 1, batch_idx = 100, loss = 0.710280179977417
epoch = 1, batch_idx = 200, loss = 0.6227332353591919
epoch = 2, batch_idx = 0, loss = 0.5902734994888306
epoch = 2, batch_idx = 100, loss = 0.6123036742210388
epoch = 2, batch_idx = 200, loss = 0.5449265837669373
epoch = 3, batch_idx = 0, loss = 0.5221412181854248
epoch = 3, batch_idx = 100, loss = 0.5586401224136353
epoch = 3, batch_idx = 200, loss = 0.49894028902053833
epoch = 4, batch_idx = 0, loss = 0.47923576831817627
epoch = 4, batch_idx = 100, loss = 0.5219769477844238
epoch = 4, batch_idx = 200, loss = 0.46610695123672485
epoch = 5, batch_idx = 0, loss = 0.44889944791793823
epoch = 5, batch_idx = 100, loss = 0.49393826723098755
epoch = 5, batch_idx = 200, loss = 0.44056564569473267
epoch = 6, batch_idx = 0, loss = 0.4255021810531616
epoch = 6, batch_idx = 100, loss = 0.4713611900806427
epoch = 6, batch_idx = 200, loss = 0.41985756158828735
epoch = 7, batch_idx = 0, loss = 0.40636977553367615
epoch = 7, batch_idx = 100, loss = 0.45304715633392334
epoch = 7, batch_idx = 200, loss = 0.40254470705986023
epoch = 8, batch_idx = 0, loss = 0.39018934965133667
epoch = 8, batch_idx = 100, loss = 0.4377802312374115
epoch = 8, batch_idx = 200, loss = 0.38769394159317017
epoch = 9, batch_idx = 0, loss = 0.3761257529258728
epoch = 9, batch_idx = 100, loss = 0.42440712451934814
epoch = 9, batch_idx = 200, loss = 0.37483566999435425
epoch = 10, batch_idx = 0, loss = 0.3637397885322571
epoch = 10, batch_idx = 100, loss = 0.41257765889167786
epoch = 10, batch_idx = 200, loss = 0.3634665012359619
epoch = 11, batch_idx = 0, loss = 0.35271918773651123
epoch = 11, batch_idx = 100, loss = 0.4018203616142273
epoch = 11, batch_idx = 200, loss = 0.3533402979373932
epoch = 12, batch_idx = 0, loss = 0.3428434729576111
epoch = 12, batch_idx = 100, loss = 0.39214104413986206
epoch = 12, batch_idx = 200, loss = 0.3441687822341919
epoch = 13, batch_idx = 0, loss = 0.333955317735672
epoch = 13, batch_idx = 100, loss = 0.3832662105560303
epoch = 13, batch_idx = 200, loss = 0.33575889468193054
epoch = 14, batch_idx = 0, loss = 0.3258323669433594
epoch = 14, batch_idx = 100, loss = 0.3750748336315155
epoch = 14, batch_idx = 200, loss = 0.32803112268447876
epoch = 15, batch_idx = 0, loss = 0.3183256685733795
epoch = 15, batch_idx = 100, loss = 0.3674846291542053
epoch = 15, batch_idx = 200, loss = 0.32090339064598083
epoch = 16, batch_idx = 0, loss = 0.31145304441452026
epoch = 16, batch_idx = 100, loss = 0.3604132831096649
epoch = 16, batch_idx = 200, loss = 0.3143956661224365
epoch = 17, batch_idx = 0, loss = 0.3050137758255005
epoch = 17, batch_idx = 100, loss = 0.3537922203540802
epoch = 17, batch_idx = 200, loss = 0.30833786725997925
epoch = 18, batch_idx = 0, loss = 0.2989848256111145
epoch = 18, batch_idx = 100, loss = 0.3475683629512787
epoch = 18, batch_idx = 200, loss = 0.30268821120262146
epoch = 19, batch_idx = 0, loss = 0.2933279275894165
epoch = 19, batch_idx = 100, loss = 0.3417797088623047
epoch = 19, batch_idx = 200, loss = 0.29743343591690063
epoch = 20, batch_idx = 0, loss = 0.2880209684371948
epoch = 20, batch_idx = 100, loss = 0.33625274896621704
epoch = 20, batch_idx = 200, loss = 0.2924404740333557
epoch = 21, batch_idx = 0, loss = 0.2830372750759125
epoch = 21, batch_idx = 100, loss = 0.3309747576713562
epoch = 21, batch_idx = 200, loss = 0.2876732051372528
epoch = 22, batch_idx = 0, loss = 0.2783474326133728
epoch = 22, batch_idx = 100, loss = 0.32599449157714844
epoch = 22, batch_idx = 200, loss = 0.2831261456012726
epoch = 23, batch_idx = 0, loss = 0.27390143275260925
epoch = 23, batch_idx = 100, loss = 0.3213098645210266
epoch = 23, batch_idx = 200, loss = 0.27885332703590393
epoch = 24, batch_idx = 0, loss = 0.2696850299835205
epoch = 24, batch_idx = 100, loss = 0.3168952465057373
epoch = 24, batch_idx = 200, loss = 0.2747831344604492
epoch = 25, batch_idx = 0, loss = 0.2656662166118622
epoch = 25, batch_idx = 100, loss = 0.3127139210700989
epoch = 25, batch_idx = 200, loss = 0.2708702087402344
epoch = 26, batch_idx = 0, loss = 0.2618401050567627
epoch = 26, batch_idx = 100, loss = 0.3087657690048218
epoch = 26, batch_idx = 200, loss = 0.26713475584983826
epoch = 27, batch_idx = 0, loss = 0.25819507241249084
epoch = 27, batch_idx = 100, loss = 0.30497828125953674
epoch = 27, batch_idx = 200, loss = 0.26355743408203125
epoch = 28, batch_idx = 0, loss = 0.2547600567340851
epoch = 28, batch_idx = 100, loss = 0.30135565996170044
epoch = 28, batch_idx = 200, loss = 0.2601349949836731
epoch = 29, batch_idx = 0, loss = 0.251463383436203
epoch = 29, batch_idx = 100, loss = 0.2978888750076294
epoch = 29, batch_idx = 200, loss = 0.25687384605407715

Process finished with exit code 0
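A side note on the loss: mean squared error over one-hot targets and raw logits works, and the log above shows it converging, but for classification the more common choice is softmax cross-entropy. A minimal sketch of the swapped-in loss line (a variation, not what produced the log above):

# Inside the GradientTape, replacing the MSE line:
loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y, out, from_logits=True))

With from_logits=True the softmax is folded into the loss, which is numerically more stable than applying a softmax in the model and passing probabilities.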