【pytorch】onnx
2023-04-18 13:06:17 时间
t7 / pth -> onnx
pytorch任意形式的model(.t7、.pth等等)转.onnx全都可以采用固定格式。
完整实现:
def pth2onnx(self, simplify_onnx_sw=True):
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
model = torch.nn.DataParallel(self.model)
_state_dict = torch.load(pth_path, map_location=torch.device('cpu'))
model.load_state_dict(_state_dict, strict=True)
model.eval()
torch.onnx.export(model.module,
torch.randn(batch_size, *C.input_shape),
pure_onnx_path,
input_names=["input"],
output_names=["output"]
)
if simplify_onnx_sw:
os.system('python -m onnxsim {} {}'.format(pure_onnx_path, simplified_onnx_path))
print('
Simplified onnx has been save to {}
'.format(simplified_onnx_path))
os.remove(pure_onnx_path)
else:
print('
Pure onnx has been save to {}
'.format(pure_onnx_path))
实验举例:
model_dir = './'
pth_path = model_dir + 'A.pth'
onnx_path = model_dir + 'A.onnx'
batch_size = 1
input_shape = (3, 112, 112)
cfg = Config()
cfg.load_from_file(args.model_cfg_file)
model = PFLD_SE3_eval(cfg.model_conf.layer_cfg, cfg.model_conf.num_points)
model.load(pth_path)
model.eval()
torch.onnx.export(model,
torch.randn(batch_size, *input_shape),
onnx_path,
input_names=["input"],
output_names=["output_0", "output_1"],
)
print('
onnx has been save to {}
'.format(onnx_path))
如在mac下执行,还需要加上这行环境配置:
os.environ['KMP_DUPLICATE_LIB_OK']='True'
可能的报错:
ImportError: cannot import name 'get_all_providers' from 'onnxruntime.capi._pybind_state'
mac下的通用解决方法:
brew install libomp
如果还是报相同错误,则可能是版本问题。换版本即可。例如我是执行:
pip install onnxruntime==1.2.0
相关文章
- 【技术种草】cdn+轻量服务器+hugo=让博客“云原生”一下
- CLB运维&运营最佳实践 ---访问日志大洞察
- vnc方式登陆服务器
- 轻松学排序算法:眼睛直观感受几种常用排序算法
- 十二个经典的大数据项目
- 为什么使用 CDN 内容分发网络?
- 大数据——大数据默认端口号列表
- Weld 1.1.5.Final,JSR-299 的框架
- JavaFX 2012:彻底开源
- 提升as3程序性能的十大要点
- 通过凸面几何学进行独立于边际的在线多类学习
- 利用行动影响的规律性和部分已知的模型进行离线强化学习
- ModelLight:基于模型的交通信号控制的元强化学习
- 浅谈Visual Source Safe项目分支
- 基于先验知识的递归卡尔曼滤波的代理人联合状态和输入估计
- 结合网络结构和非线性恢复来提高声誉评估的性能
- 最佳实践丨云开发CloudBase多环境管理实践
- TimeVAE:用于生成多变量时间序列的变异自动编码器
- 具有线性阈值激活的神经网络:结构和算法
- 内网渗透之横向移动 -- 从域外向域内进行密码喷洒攻击