zl程序教程

您现在的位置是:首页 >  后端

当前栏目

pytorch加载bert模型报错

PyTorch 报错 模型 加载 bert
2023-09-14 09:08:40 时间

背景

使用pytorch加载huggingface下载的albert-base-chinede模型出错

Exception has occurred: OSError
Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.

模型地址:https://huggingface.co/models?search=albert_chinese

方法一:

参考以下文章删除缓存目录,问题还是存在
https://blog.csdn.net/znsoft/article/details/107725285
https://github.com/huggingface/transformers/issues/6159

方法二:

使用另一台电脑加载相同模型,加载成功,查看两台电脑的torch、transformers版本,发现一个torch为1.1,另一个为torch1.7.x
参考pytorch官网,torch1.6之后修改了模型保存方式,高版本保存的模型,低版本无法加载

The 1.6 release of PyTorch switched torch.save to use a new zipfile-based file format. torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.

解决方法:

  1. 升级torch为高版本
  2. 如果因为cuda兼容等问题无法升级,可以在高版本上加载模型,然后重新save并添加_use_new_zipfile_serialization=False
from transformers import *
import torch

pretrained = 'D:/07_data/albert_base_chinese'
tokenizer = BertTokenizer.from_pretrained(pretrained)
model = AlbertForMaskedLM.from_pretrained(pretrained)

# 它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), 'pytorch_model_unzip.bin', _use_new_zipfile_serialization=False)

其他保存方法请参考:
https://blog.csdn.net/fendouaini/article/details/105322537
https://huggingface.co/transformers/serialization.html#serialization-best-practices