利用pytorch构建一个完整的自定义神经网络
2023-09-11 14:21:06 时间
一、自定义神经网络
import torch import tqdm class TwoLayerNet(torch.nn.Module): def __init__(self, D_in, H, D_out): """ 在构造函数中,我们实例化了两个nn.Linear模块,并将它们作为成员变量。 """ super(TwoLayerNet, self).__init__() self.linear1 = torch.nn.Linear(D_in, H) self.linear2 = torch.nn.Linear(H, D_out) def forward(self, x): #.clamp(input, min, max, out=None) → Tensor表示将结果限定在【min,max】之间 h_relu = self.linear1(x).clamp(min=0) y_pred = self.linear2(h_relu) return y_pred #定义输入输出,构建模型、损失函数、优化器 N, D_in, H, D_out = 64, 1000, 100, 10 # N是批大小; D_in 是输入维度;H 是隐藏层维度; D_out 是输出维度 x = torch.randn(N, D_in) #x.shape=torch.Size([64, 1000]) y = torch.randn(N, D_out)#y.shape=torch.Size([64, 10]) model = TwoLayerNet(D_in, H, D_out) loss_fn = torch.nn.MSELoss(reduction='sum')# 构造损失函数和优化器。 optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)# SGD构造函数中对model.parameters()的调用 #K开始模型训练 for i in tqdm.tqdm(range(500)): y_pred = model(x)# 前向传播:通过向模型传递x计算预测值y loss = loss_fn(y_pred, y)#计算并输出loss if (i+1)%100==0: print(i, loss.item()) optimizer.zero_grad()# 清零梯度,反向传播,更新权重 loss.backward() optimizer.step()
验证一下结果:
model(x[10,:])
y[10,:]
Out[32]:
相关文章
- 机器学习笔记 - win10安装Pytorch-GPU版本并训练第一个神经网络
- 【PyTorch】教程:对抗学习实例生成
- Pytorch之CNN:从代码认知CNN经典架构—基于Pytorch框架的代码实现经典轻量化卷积神经网络的算法集合(SqueezeNet/MobileNet3/ShuffleNet)
- 使用pytorch创建神经网络并解决线性拟合和分类问题
- 用pytorch的两种方法创建神经网络
- VIT pytorch源码
- pytorch设计卷积神经网络的广义分析
- Jetson NX系统烧录以及CUDA、cudnn、pytorch等环境的安装
- 【Pytorch】expand()用法==》扩展某个维度
- 【DL with Pytorch】第 4 章 : 卷积神经网络
- 【DL with Pytorch】第 2 章 : 神经网络的构建块
- 【Pytorch深度学习实战】(6)递归神经网络(RNN)
- Pytorch——torch.nn.init 中实现的初始化函数
- Pytorch网络模型转Onnx格式,多种方法(opencv、onnxruntime、c++)调用Onnx
- pytorch的Dataloader的shuffle
- 【youcans的深度学习 D02】PyTorch例程:创建 LeNet 模型进行图像分类
- 【PyTorch】安装支持cuda的pytorch-1.10.2
- 【深度学习】Pytorch面试题:什么是 PyTorch?PyTorch 的基本要素是什么?Conv1d、Conv2d 和 Conv3d 有什么区别?
- Pytorch总结十二之 深度神经网络模型:NIN、GoogLeNet
- 【Pytorch with fastai】第 17 章 :基础神经网络
- 【Pytorch with fastai】第 10 章 :NLP 深入探讨 RNN
- 【Pytorch Lighting】第 3 章:使用预训练模型进行迁移学习
- Pytorch搭建CIFAR10神经网络
- pytorch 35 yolov5_obb项目解读+使用技巧+调优经验(提升map)
- 超简单的pytorch(GPU版)安装教程(亲测有效)
- pytorch及相关工具的安装