在pytorch中 保存 和加载神经网络
2023-09-14 09:05:37 时间
import torch
import matplotlib.pyplot as plt
torch.manual_seed(1) # reproducible
fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
保存神经网络和神经网络当前训练后的状态
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(
相关文章
- 常用神经网络架构总结
- 基于萤火虫算法优化的BP神经网络预测模型(Matlab代码实现)
- 基于粒子群优化算法的BP神经网络预测模型(Matlab代码实现)
- 【评价模型】模糊神经网络的水质评价模型
- 压缩网络模型,或者是融合多个神经网络
- 【DL with Pytorch】第 6 章 : 用循环神经网络分析数据序列
- 【Pytorch深度学习实战】(8)双向循环神经网络(BiRNN)
- 基于萤火虫优化的BP神经网络(预测应用) - 附代码
- Python实现PSO粒子群优化循环神经网络LSTM分类模型项目实战
- 图神经网络的攻击防御
- android NDK 神经网络API——是给tensorflow lite调用的底层API,应用开发者使用tensorflow lite即可
- Self Organizing Maps (SOM): 一种基于神经网络的聚类算法,本质上感觉和kmeans迭代没啥区别
- Pytorch总结十二之 深度神经网络模型:NIN、GoogLeNet
- 【CV】第 3 章:使用 PyTorch 构建深度神经网络
- python工具方法 1 tensorflow简单全连接神经网络,识别minist手写数字
- Pytorch 1 简单卷积神经网络 minist分类
- pytorch学习笔记(十一):循环神经网络RNN(简介)
- pytorch学习笔记(十):卷积神经网络CNN(进阶篇)