zl程序教程

您现在的位置是:首页 >  其他

当前栏目

DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程

训练算法技术学习数据 优化 自定义 过程
2023-09-14 09:04:47 时间

DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程

 

目录

输出结果

设计思路

核心代码

更多输出


 

 

相关文章
DL之DNN优化技术:采用三种激活函数(sigmoid、relu、tanh)构建5层神经网络,权重初始值(He初始化和Xavier初始化)影响隐藏层的激活值分布的直方图可视化
DL之DNN优化技术:自定义MultiLayerNet【5*100+ReLU】对MNIST数据集训练进而比较三种权重初始值(Xavier初始化、He初始化)性能差异
DL之DNN优化技术:利用MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程
DL之DNN优化技术:DNN中参数初始化【Lecun参数初始化、He参数初始化和Xavier参数初始化】的简介、使用方法详细攻略
DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程全部代码

 

输出结果

更多输出详见最后

 

设计思路

 

核心代码

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

x_train = x_train[:1000]
t_train = t_train[:1000]


max_epochs = 20
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.01


    bn_network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10, 
                                    weight_init_std=weight_init_std, use_batchnorm=True)
    network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
                                weight_init_std=weight_init_std)
    optimizer = SGD(lr=learning_rate)  
    train_acc_list = []                              
    bn_train_acc_list = []
    iter_per_epoch = max(train_size / batch_size, 1) 


    for i in range(1000000000):
        #定义x_batch、t_batch
        batch_mask = np.random.choice(train_size, batch_size)
        x_batch = x_train[batch_mask]
        t_batch = t_train[batch_mask]
    
        for _network in (bn_network, network):
            grads = _network.gradient(x_batch, t_batch) 
            optimizer.update(_network.params, grads)     
    
        if i % iter_per_epoch == 0: 
            train_acc = network.accuracy(x_train, t_train)       
            bn_train_acc = bn_network.accuracy(x_train, t_train) 
            train_acc_list.append(train_acc)
            bn_train_acc_list.append(bn_train_acc)
    
            print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - " + str(bn_train_acc)) 
    
            epoch_cnt += 1
            if epoch_cnt >= max_epochs: 
                break
                
    return train_acc_list, bn_train_acc_list  

 

 

更多输出

============== 1/16 ==============
epoch:0 | 0.093 - 0.085
epoch:1 | 0.117 - 0.08
epoch:2 | 0.117 - 0.081
epoch:3 | 0.117 - 0.1
epoch:4 | 0.117 - 0.125
epoch:5 | 0.117 - 0.143
epoch:6 | 0.117 - 0.163
epoch:7 | 0.117 - 0.191
epoch:8 | 0.117 - 0.213
epoch:9 | 0.117 - 0.236
epoch:10 | 0.117 - 0.258
epoch:11 | 0.117 - 0.268
epoch:12 | 0.117 - 0.28
epoch:13 | 0.117 - 0.297
epoch:14 | 0.117 - 0.31
epoch:15 | 0.117 - 0.322
epoch:16 | 0.117 - 0.335
epoch:17 | 0.117 - 0.36
epoch:18 | 0.116 - 0.378
epoch:19 | 0.117 - 0.391

============== 2/16 ==============
epoch:0 | 0.087 - 0.099
epoch:1 | 0.097 - 0.108
epoch:2 | 0.097 - 0.151
epoch:3 | 0.097 - 0.185
epoch:4 | 0.097 - 0.216
epoch:5 | 0.097 - 0.226
epoch:6 | 0.097 - 0.243
epoch:7 | 0.097 - 0.281
epoch:8 | 0.097 - 0.306
epoch:9 | 0.097 - 0.323
epoch:10 | 0.097 - 0.344
epoch:11 | 0.097 - 0.364
epoch:12 | 0.097 - 0.38
epoch:13 | 0.097 - 0.394
epoch:14 | 0.097 - 0.402
epoch:15 | 0.097 - 0.415
epoch:16 | 0.097 - 0.441
epoch:17 | 0.097 - 0.454
epoch:18 | 0.097 - 0.464
epoch:19 | 0.097 - 0.48

============== 3/16 ==============
epoch:0 | 0.104 - 0.108
epoch:1 | 0.364 - 0.111
epoch:2 | 0.499 - 0.121
epoch:3 | 0.587 - 0.153

……

epoch:17 | 0.116 - 0.62
epoch:18 | 0.116 - 0.615
epoch:19 | 0.116 - 0.652

============== 16/16 ==============
epoch:0 | 0.092 - 0.092
epoch:1 | 0.094 - 0.288
epoch:2 | 0.116 - 0.373
epoch:3 | 0.116 - 0.407
epoch:4 | 0.116 - 0.416
epoch:5 | 0.116 - 0.418
epoch:6 | 0.116 - 0.488
epoch:7 | 0.117 - 0.493
epoch:8 | 0.117 - 0.502
epoch:9 | 0.117 - 0.517
epoch:10 | 0.117 - 0.52
epoch:11 | 0.117 - 0.507
epoch:12 | 0.117 - 0.524
epoch:13 | 0.117 - 0.521
epoch:14 | 0.117 - 0.523
epoch:15 | 0.117 - 0.522
epoch:16 | 0.117 - 0.522
epoch:17 | 0.116 - 0.523
epoch:18 | 0.116 - 0.481
epoch:19 | 0.116 - 0.509

 

相关文章
CSDN:2019.04.09起