利用 torch.nn 实现前馈神经网络解决多分类任务,使用至少三种不同的激活函数
时间:2023-02-18 16:32:56
1 导入包
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import TensorDataset,DataLoader
import torchvision
from IPython import display
from torchvision import transforms
2 加载数据
# Download MNIST (if not cached under ./datasets/) and wrap both splits in
# shuffling mini-batch DataLoaders. ToTensor() maps each image to a
# float tensor scaled to [0, 1].
batch_size = 256
mnist_train = torchvision.datasets.MNIST(
    root='./datasets/', train=True, download=True,
    transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(
    root='./datasets/', train=False, download=True,
    transform=transforms.ToTensor())
train_iter = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_iter = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
3 定义平铺层
class FlattenLayer(torch.nn.Module):
    """Flatten each sample in a batch to a 1-D vector.

    Maps (batch, *dims) -> (batch, prod(dims)). Generalized from the
    original hard-coded 784 so it works for any per-sample shape.
    """

    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        # -1 lets view() infer the flattened length instead of
        # hard-coding 784 (which only worked for 28x28 inputs).
        return x.view(x.shape[0], -1)
4 模型选择
# Layer sizes: 28x28 input images, two hidden layers, 10 digit classes.
num_input, num_hidden1, num_hidden2, num_output = 28 * 28, 512, 256, 10

def choose_model(model_type):
    """Build a 784-512-256-10 feed-forward classifier.

    Args:
        model_type: 'ReLU' or 'ELU' selects that activation; any other
            value falls back to Sigmoid (matching the original if/elif/else).

    Returns:
        An nn.Sequential model producing raw logits of shape (batch, 10).
    """
    act_cls = {'ReLU': nn.ReLU, 'ELU': nn.ELU}.get(model_type, nn.Sigmoid)
    model = nn.Sequential()
    # nn.Flatten() flattens (batch, 1, 28, 28) -> (batch, 784), same as the
    # custom FlattenLayer defined above.
    model.add_module("flatten", nn.Flatten())
    model.add_module("linear1", nn.Linear(num_input, num_hidden1))
    # Bug fix: the original registered both activations under the same name
    # "activation"; add_module() with an existing name replaces the entry in
    # place, so the built net had NO activation between linear2 and linear3.
    # Use unique names and a fresh activation instance per position.
    model.add_module("activation1", act_cls())
    model.add_module("linear2", nn.Linear(num_hidden1, num_hidden2))
    model.add_module("activation2", act_cls())
    model.add_module("linear3", nn.Linear(num_hidden2, num_output))
    return model

model = choose_model('ReLU')
print(model)
5 参数初始化
def _init_linear(layer):
    """Xavier-uniform weights and constant 0.1 bias for every Linear layer."""
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
        nn.init.constant_(layer.bias, 0.1)

# apply() recurses over every submodule, same as iterating model.modules().
model.apply(_init_linear)
6 定义训练函数
def train(net, train_iter, test_iter, loss, num_epochs):
    """Train `net` for `num_epochs` epochs, evaluating on the test set
    after each epoch.

    Relies on the module-level `optimizer` (created later in this script)
    for the parameter updates.

    Args:
        net: the model to train.
        train_iter: DataLoader yielding (x, y) training batches.
        test_iter: DataLoader yielding (x, y) test batches.
        loss: criterion, e.g. nn.CrossEntropyLoss().
        num_epochs: number of passes over the training data.

    Returns:
        (train_ls, test_ls, train_acc, test_acc): four per-epoch lists —
        summed batch losses (not averaged, as in the original) and
        accuracies for the train/test sets.
    """
    train_ls, test_ls, train_acc, test_acc = [], [], [], []
    for epoch in range(num_epochs):
        net.train()
        train_ls_sum, train_acc_sum, n = 0.0, 0, 0
        for x, y in train_iter:
            y_pred = net(x)
            l = loss(y_pred, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_ls_sum += l.item()
            train_acc_sum += (y_pred.argmax(dim=1) == y).sum().item()
            n += y_pred.shape[0]
        train_ls.append(train_ls_sum)
        train_acc.append(train_acc_sum / n)

        # Evaluation pass: no_grad() avoids building gradient graphs, and
        # eval() disables train-only layers should the model ever gain any.
        net.eval()
        test_ls_sum, test_acc_sum, n = 0.0, 0, 0
        with torch.no_grad():
            for x, y in test_iter:
                y_pred = net(x)
                l = loss(y_pred, y)
                test_ls_sum += l.item()
                test_acc_sum += (y_pred.argmax(dim=1) == y).sum().item()
                n += y_pred.shape[0]
        test_ls.append(test_ls_sum)
        test_acc.append(test_acc_sum / n)
        print('epoch %d, train_loss %.6f,test_loss %f, train_acc %.6f,test_acc %f'
              % (epoch + 1, train_ls[epoch], test_ls[epoch], train_acc[epoch], test_acc[epoch]))
    # Bug fix: the original returned `test_acc_sum` (a raw correct-count
    # from the last epoch) instead of the per-epoch `test_acc` list.
    return train_ls, test_ls, train_acc, test_acc
7 定义损失函数和优化器
# Hyper-parameters: 20 epochs of plain SGD at learning rate 0.01.
num_epochs, lr = 20, 0.01
# CrossEntropyLoss combines log-softmax with NLL, so the net emits raw logits.
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
8 模型训练
# Run the training loop; each returned value is a per-epoch list of metrics.
train_loss,test_loss,train_acc,test_acc = train(model,train_iter,test_iter,loss,num_epochs)
相关文章
- 使用 Redis 实现延时队列
- redis 使用setnx实现锁
- Redis数据结构存储系统:第四章:底层实现原理
- Kubernetes 网络流量流转路径
- SpringDataRedis:第一章:简介
- WebAssembly 可以取代 Kubernetes 吗?
- 软件测试|Monkey基本参数介绍
- springboot 之集成Redis
- Redis事件循环
- Grafana 系列文章(一):基于 Grafana 的全栈可观察性 Demo
- Docker 基础 - 3
- Docker 基础 - 2
- Docker 基础 - 1
- Crossplane - 比 Terraform 更先进的云基础架构管理平台?
- Cert Manager 申请SSL证书流程及相关概念-三
- Cert Manager 申请SSL证书流程及相关概念-二
- Cert Manager 申请 SSL 证书流程及相关概念 - 一
- APM vs NPM
- API 网关的功能用途及实现方式
- Ansible 学习笔记 - 批量巡检站点 URL 状态