zl程序教程

您现在的位置是:首页 >  其它

当前栏目

2020年人工神经网络第二次作业-参考答案第八题

2020 作业 参考答案 第二次 人工神经网络 第八
2023-09-11 14:15:21 时间

如下是 2020年人工神经网络第二次作业 中第八题的参考答案。

 

01 第八题参考答案


1.题目分析

(1) 训练样本

根据题目中的两类样本点在坐标系的位置,可以获得训练样本输入数据矩阵x_train;和类别矩阵:y_train。

x_train= [[-3.  0.]
 [-2.  1.]
 [-2. -1.]
 [ 0.  2.]
 [ 0.  1.]
 [ 0. -1.]
 [ 0. -2.]
 [ 2.  1.]
 [ 3. -1.]
 [ 3.  0.]]
y_train= [0 0 0 1 1 1 1 0 0 0]

▲ 两个类别样本点所处的位置

▲ 两个类别样本点所处的位置

(2) LVQ网络结构

LVQ网络具有四个节点。其中前两个设定为第一类;后两个设定为第二类别。网络结构如下图所示:

▲ LVQ网络结构

▲ LVQ网络结构

对于网络四个节点进行随机初始化。初始化的范围在(-4.5,4.5)×(-2.5,2.5)。下面显示了四个神经节点在某次随机初始化之后所在的位置(四个绿色点):

▲ LVQ竞争层节点随机初始化的位置

▲ LVQ竞争层节点随机初始化的位置

2.求解过程

求解过程中相关程序参见后面附录中 作业中的程序

(1) 训练参数

  • 学习速率 η \eta η从0.1线性减少到0.01
  • 训练步骤:N=100

在训练过程中,对于10个样本采取随机的顺序对LVQ进行训练。训练采用LVQ的最基本的训练算法:即只对胜出的神经元进行奖赏性,或者惩罚性学习。

(2) 训练过程1

如下是LVQ初始化状态。
▲ LVQ初始状态

▲ LVQ初始状态

下图显示了学习过程网络节点位置演变过程。可以看到最终四个节点(分成两类)分别处在两类样本集合的各自的中心位置。效果很好。
▲ LVQ训练过程

▲ LVQ训练过程

下面显示了另外一次随机初始化之后网络训练的演化过程:

▲ LVQ训练中隐层节点演过过程

▲ LVQ训练中隐层节点演过过程

(3) 训练过程2

如下是另外一种随机初始化的过程。其中第二类隐层节点中(棕色大点)存在一个位于最左边的情况。由于它受到第一类训练样本的排斥,所以在整个训练过程中,它就移出整个范围。最终竞争节点第二类中中只有一个节点回到了第二类的中心。

▲ LVQ初始化位置

▲ LVQ初始化位置

下面的动态过程显示了整个训练结果的演化过程。
▲ LVQ训练过程

▲ LVQ训练过程

下面是另外一次出现这种情况的训练过程:

3.结果讨论

  • LVQ的训练结果会受到初始化的影响。有的时候,训练结果可以收到了一个全局最优的结果。有的时候,会出现部分竞争节点远离样本区域,使得可以用于学习样本的神经元节点减少。


 

※ 作业中的程序


1.LVQ程序

#!/usr/local/bin/python
# -*- coding: gbk -*-
#============================================================
# HW28.PY                      -- by Dr. ZhuoQing 2020-11-26
#
# Note:
#============================================================

from headm import *
import hw28data

x_train = hw28data.x_train
y_train = hw28data.y_train

W = array([[random.uniform(-4.5,4.5), random.uniform(-2.5,2.5)] for i in range(4)])

def show_W(w):
    plt.axis([-5,5,-3,3])
    plt.scatter(w[0:2,0], w[0:2,1], s=75, c='darkblue', label='Node-1')
    plt.scatter(w[2:4,0], w[2:4,1], s=75, c='darkred', label='Node-2')
    hw28data.show_train(x_train, y_train)

    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.grid(True)
    plt.legend(loc="upper left")

#------------------------------------------------------------
def WTA_nearest(x, v):
    err = [x - vv for vv in v]
    dist = [dot(e,e) for e in err]
    id = where(dist==amin(dist))[0][0]

    if id < 2: classid = 0
    else: classid = 1

    return id, classid

def LVQ_train(x, y, w, eta):
    iddim = list(range(x.shape[0]))
    random.shuffle(iddim)

    for id in iddim:
        iidd,classid = WTA_nearest(x[id], w)

        if classid == y[id]:
            w[iidd] = w[iidd] + eta* (x[id] - w[iidd])
        else: w[iidd] = w[iidd] - eta * (x[id] - w[iidd])

    return w

#------------------------------------------------------------
ETA_BEGIN       = 0.1
ETA_END         = 0.01
TRAIN_LOOP      = 100
plt.draw()
plt.pause(.2)
pltgif = PlotGIF()

for id,eta in enumerate(linspace(ETA_BEGIN, ETA_END, TRAIN_LOOP)):
    W = LVQ_train(x_train, y_train, W, eta)

    plt.clf()
    show_W(W)
    plt.title('Step:%d, Eta:%4.2f'%(id, eta))
    plt.draw()
    plt.pause(.001)
    plt.tight_layout()

    pltgif.append(plt)

pltgif.save(r'd:\temp\1.gif')
printf('\a')
plt.show()

#------------------------------------------------------------
#        END OF FILE : HW28.PY
#============================================================

2.作业2-8中数据子程序

#!/usr/local/bin/python
# -*- coding: gbk -*-
#============================================================
# HW28DATA.PY                  -- by Dr. ZhuoQing 2020-11-26
#
# Note:
#============================================================

from headm import *

x_train = array([[-3,0], [-2, 1], [-2,-1], [0,2], [0,1], [0,-1], [0,-2],\
                 [2,1], [2,-1], [3,0]]).astype('float32')
y_train = array([0,0,0,1,1,1,1,0,0,0])

#------------------------------------------------------------
def show_train(x, y):
    x_1 = x_train[y_train==0]
    x_2 = x_train[y_train==1]

    plt.scatter(x_1[:,0], x_1[:,1], s=30, c='blue', label='Class-1')
    plt.scatter(x_2[:,0], x_2[:,1], s=30, c='red',  label='Class-2')

#------------------------------------------------------------

if __name__ == "__main__":
    printff('x_train=', x_train)
    printff('y_train=', y_train)

    show_train(x_train, y_train)

    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.grid(True)
    plt.legend(loc="upper right")
    plt.tight_layout()
#    plt.show()

#------------------------------------------------------------
#        END OF FILE : HW28DATA.PY
#============================================================