pytorch 保存训练好的模型
2023-02-18 16:32:56 时间
1 保存和加载整个模型
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')
2 仅保存和加载模型参数
torch.save(model_obj.state_dict(), 'params.pth')
model_obj.load_state_dict(torch.load('params.pth'))
3 选择保存网络中的一部分参数或者额外保存其余的参数
torch.save({'state_dict': net.state_dict(), 'linear1':net.linear1.state_dict(),
'optimizer': optimizer.state_dict(),'num_epoch':num_epochs },
'detail.pth')
model = torch.load('detail.pth')
net = DNN(num_input,num_hidden1,num_hidden2,num_output)
net.load_state_dict(model['state_dict'])
参考:
[日常] PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型
相关文章
- neo4j的安装部署
- Nifi组件脚本开发—ExecuteScript 使用指南(三)
- Nifi组件脚本开发—ExecuteScript 使用指南(二)
- Nifi组件脚本开发—ExecuteScript 使用指南(一)
- 数据结构—哈希表
- 设计模式——从HttpServletRequestWrapper了解装饰者模式
- 设计模式——单例模式
- 设计模式汇总
- 设计模式——责任链(结合Tomcat中Filter机制)
- web服务器专题:tomcat(三)tomcat-users.xml 配置文件
- web服务器专题:tomcat(二)模块组件与server.xml 配置文件
- web服务器专题:tomcat(一)基础及模块
- JDBC的基础接口及其用法
- 基于JQuery的前端form表单操作
- jQuery中的ajax
- NETs相关基因构建预后模型干湿结合发12分+SCI
- 用 Minio 快速启动 Velero 实现 Kubernetes资源备份
- MySQL(1) - 用户管理及root密码修改
- Jmeter系列(40)- 详解 Jmeter 图形化 HTML 压测报告
- Jmeter系列(39)- 详解 Jmeter CLI 模式