在多分类任务实验中手动实现并使用 L2 正则化
时间：2023-02-18 16:33:18
1 导入所需要的包
import torch
import numpy as np
import random
from IPython import display
from matplotlib import pyplot as plt
import torchvision
import torchvision.transforms as transforms
2 下载MNIST数据集
# Load both MNIST splits from ../Datasets/MNIST (downloading them if absent);
# every sample is passed through transforms.ToTensor().
mnist_train = torchvision.datasets.MNIST(
    root='../Datasets/MNIST',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
mnist_test = torchvision.datasets.MNIST(
    root='../Datasets/MNIST',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
3 读取数据
# Mini-batch size shared by the training and test loaders.
batch_size = 256
# Training batches are shuffled; test batches keep the dataset order.
train_iter = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
test_iter = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)
4 初始化参数并定义隐藏层的激活函数
# MLP layer sizes: 784 input features (presumably the flattened 28x28 MNIST
# image), 256 hidden units, 10 output classes.
num_inputs = 784
num_hiddens = 256
num_outputs = 10
def init_param(n_inputs=None, n_hiddens=None, n_outputs=None, std=0.01):
    """Create fresh parameters for the one-hidden-layer MLP.

    Weights are drawn with ``np.random.normal(0, std, ...)`` and stored as
    float32; biases start at zero with one entry per unit of their layer.

    Args:
        n_inputs: input feature count; defaults to module-level ``num_inputs``.
        n_hiddens: hidden unit count; defaults to module-level ``num_hiddens``.
        n_outputs: output count; defaults to module-level ``num_outputs``.
        std: standard deviation of the weight initialisation.

    Returns:
        Tuple ``(W1, b1, W2, b2)`` of leaf tensors with ``requires_grad=True``:
        W1 is (n_hiddens, n_inputs), b1 is (n_hiddens,),
        W2 is (n_outputs, n_hiddens), b2 is (n_outputs,).
    """
    # Resolve defaults at call time so the module-level sizes are honoured.
    n_inputs = num_inputs if n_inputs is None else n_inputs
    n_hiddens = num_hiddens if n_hiddens is None else n_hiddens
    n_outputs = num_outputs if n_outputs is None else n_outputs

    W1 = torch.tensor(np.random.normal(0, std, (n_hiddens, n_inputs)), dtype=torch.float32)
    # One bias per hidden unit; a shape-(1,) bias would be shared by every unit.
    b1 = torch.zeros(n_hiddens, dtype=torch.float32)
    W2 = torch.tensor(np.random.normal(0, std, (n_outputs, n_hiddens)), dtype=torch.float32)
    # One bias per output class.
    b2 = torch.zeros(n_outputs, dtype=torch.float32)

    params = (W1, b1, W2, b2)
    for param in params:
        param.requires_grad_(requires_grad=True)
    return params
def relu(x):
    """Element-wise rectified linear unit: the larger of each entry and 0."""
    zero = torch.tensor(0.0)
    return torch.max(x, zero)
5 定义模型
def net(X):
    """Forward pass: flatten, ReLU hidden layer, then a linear output layer.

    Reads the module-level parameters W1, b1, W2, b2 (rebound by
    ``init_param()`` in the training loop) and returns the raw logits,
    with no softmax applied.
    """
    flat = X.view((-1, num_inputs))
    hidden = relu(torch.matmul(flat, W1.t()) + b1)
    logits = torch.matmul(hidden, W2.t()) + b2
    return logits
6 定义交叉熵损失函数和优化器
# Cross-entropy on the raw logits from net(); per-sample losses are averaged
# over the batch.
loss = torch.nn.CrossEntropyLoss(reduction='mean')
def SGD(paras, lr):
    """Apply one in-place vanilla SGD step: ``p <- p - lr * p.grad``.

    Args:
        paras: iterable of tensors to update.
        lr: learning rate.

    Tensors whose ``.grad`` is ``None`` (no backward pass reached them yet)
    are left unchanged.
    """
    # Iterate the argument itself; the update must not be recorded by autograd.
    with torch.no_grad():
        for param in paras:
            if param.grad is None:
                continue
            param.sub_(lr * param.grad)
7 定义L2范数惩罚项（权重平方和的一半）
def l2_penalty(w):
    """Return half the squared L2 norm of ``w``, i.e. ``sum(w**2) / 2``."""
    squared = w.pow(2)
    return squared.sum() / 2
8 定义训练函数
def _regularized_loss(net, loss, X, y, mylambda):
    """Return ``loss(net(X), y)`` plus ``mylambda`` times the L2 penalties of W1 and W2.

    W1 and W2 are the module-level weight tensors (the same ones net() reads).
    """
    X = X.reshape(-1, num_inputs)
    return loss(net(X), y) + mylambda * l2_penalty(W1) + mylambda * l2_penalty(W2)


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when no samples were seen."""
    return total / count if count else float('nan')


def _evaluate(net, data_iter, loss, mylambda):
    """Per-sample average of the regularised objective over ``data_iter``.

    Runs under ``torch.no_grad()``: evaluation never calls ``backward()``,
    so recording an autograd graph would only waste memory and time.
    """
    total, count = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            n = y.shape[0]
            total += _regularized_loss(net, loss, X, y, mylambda).item() * n
            count += n
    return _mean_or_nan(total, count)


def train(net, train_iter, test_iter, loss, num_epochs, batch_size, lr=None, optimizer=None, mylambda=0):
    """Train ``net`` with an L2 penalty on W1/W2 and record per-epoch losses.

    Each epoch makes one optimisation pass over ``train_iter`` followed by one
    evaluation pass over ``test_iter``. The per-batch objective is
    ``loss(net(X), y) + mylambda * l2_penalty(W1) + mylambda * l2_penalty(W2)``,
    where W1/W2 are the module-level weight tensors.

    Args:
        net: callable mapping a batch of inputs to logits.
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` test batches.
        loss: loss function ``loss(logits, y)``; assumed to average over the
            batch (CrossEntropyLoss's default), since batch losses are
            re-weighted by batch size below.
        num_epochs: number of passes over the training data.
        batch_size: unused; kept for interface compatibility.
        lr: unused; the learning rate is owned by ``optimizer``.
        optimizer: ``torch.optim`` optimizer over the model parameters.
        mylambda: L2 regularisation strength.

    Returns:
        ``(train_ls, test_ls)``: one float per epoch each, the per-sample
        average of the regularised objective on that split, so train and
        test values are on the same scale.

    Raises:
        ValueError: if ``optimizer`` is None.
    """
    if optimizer is None:
        raise ValueError("train() requires an optimizer")
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        total, count = 0.0, 0
        for X, y in train_iter:
            batch_loss = _regularized_loss(net, loss, X, y, mylambda)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Weight each batch-mean by its size so a short last batch is not over-counted.
            n = y.shape[0]
            total += batch_loss.item() * n
            count += n
        train_ls.append(_mean_or_nan(total, count))
        test_ls.append(_evaluate(net, test_iter, loss, mylambda))
        if (epoch + 1) % 5 == 0:
            print('epoch: %d, train loss: %f, test loss: %f' % (epoch + 1, train_ls[-1], test_ls[-1]))
    return train_ls, test_ls
9 开始训练模型
# Single source of truth for the learning rate. The optimizer previously
# hard-coded lr=0.001 while this variable said 0.01; 0.001 is the rate that
# actually trained the model, so it is kept.
lr = 0.001
num_epochs = 20
# Regularisation strengths to compare; a freshly initialised model is trained for each.
Lamda = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
Train_ls, Test_ls = [], []
# One loss instance is reused for every run.
loss = torch.nn.CrossEntropyLoss()
for lamda in Lamda:
    print("current lambda is %f" % lamda)
    # Rebind the module-level parameters that net() reads.
    W1, b1, W2, b2 = init_param()
    optimizer = torch.optim.SGD([W1, b1, W2, b2], lr=lr)
    train_ls, test_ls = train(net, train_iter, test_iter, loss, num_epochs, batch_size, lr, optimizer, lamda)
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
10 绘制训练集和测试集的loss曲线
# Plot the per-epoch loss curves for every lambda: training loss on the left,
# test loss on the right. Epochs are numbered 1..num_epochs (integer positions).
epochs = np.arange(1, len(Train_ls[0]) + 1)
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
for ax, curves, split in zip(axes, (Train_ls, Test_ls), ('train', 'test')):
    for lam, curve in zip(Lamda, curves):
        ax.plot(epochs, curve, label=f'L2_Regularization:{lam}', linewidth=1.5)
    ax.set_xlabel('different epoch')
    ax.set_ylabel('loss')
    ax.set_title(f'{split} loss with L2_penalty')
# Both panels draw the lambdas in the same order, so with the default per-axes
# colour cycle a single legend (right of the second panel) labels both.
axes[1].legend(loc=2, bbox_to_anchor=(1.05, 1.0), borderaxespad=0.)
plt.tight_layout()
plt.show()
相关文章
- [MySQL] 导入数据时防止出现乱码
- [MySQL] GROUP BY 聚合函数的原理及聚合限制的原因：SELECT list is not in GROUP BY clause and contains nonaggregated column
- [PHP] websocket协议的生成
- [MySQL]mysql的ANY_VALUE()函数 解决 ONLY_FULL_GROUP_BY 模式
- [PHP] 框架中.env文件的加载过程
- [PHP] PHP7已经删除了preg_replace的e修饰符
- [日常]解决Connection to `ssl://pecl.php.net:443' failed
- [日常]k8s的前世今生
- [日常]windows下kill进程工具taskkill
- [CSS] 纯CSS的前端图标icon库并且修改大小和颜色
- [PHP] php中的索引数组和数组顺序问题
- [PHP] RBAC权限与审批流的简单数据库构想
- [日常] 修改编辑word中的页眉页脚
- [PHP] socket客户端时的超时问题
- [nginx]配置nginx支持websocket解决返回400错误问题
- [nginx]配置nginx支持websocket解决返回400错误问题
- [PHP] php5.3之前-php5.3-php7垃圾回收机制的进化
- [PHP] include语句的注意事项
- [TCP] tcp连接SYN超时重传次数和超时时间
- [PHP] 编译安装swoole