transformers 4 RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long
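When using transformers 4.0, the following error is raised:

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)

The cause is the tensor's dtype. The embedding lookup expects int64 (Long) indices but receives int32 (Int) indices. To fix it, convert the dtype immediately before the line that calls the embedding layer, as identified in the traceback. If the input is x, add the line "x = torch.tensor(x).to(torch.int64)". If x is already a tensor, "x = x.to(torch.int64)" is enough, and it avoids the warning that torch.tensor() emits when it copy-constructs from an existing tensor.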
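Below is a minimal sketch of this conversion. It assumes the indices may arrive as a Python list, a NumPy array or an existing tensor. `torch.as_tensor(..., dtype=torch.int64)` handles all three cases. The helper name `to_long_indices` is hypothetical.

One common way to end up with an IntTensor is that NumPy versions before 2.0 use 32-bit integers by default on Windows, and `torch.from_numpy` keeps that dtype. Treat this as a likely cause rather than one the traceback confirms; the Windows paths in the traceback make it plausible.

```python
import numpy as np
import torch


def to_long_indices(x):
    """Hypothetical helper: return x as an int64 (Long) tensor for embedding lookups."""
    # torch.as_tensor accepts lists, NumPy arrays and tensors. For an existing
    # tensor it only converts the dtype, without the copy-construct warning
    # that torch.tensor(x) would emit.
    return torch.as_tensor(x, dtype=torch.int64)


# int32 token ids, e.g. from a NumPy array created with the Windows default int type
ids_np = np.array([[101, 2054, 102]], dtype=np.int32)
ids_int = torch.from_numpy(ids_np)  # dtype torch.int32, i.e. an IntTensor
print(ids_int.dtype)

ids_long = to_long_indices(ids_int)  # dtype torch.int64, i.e. a LongTensor
print(ids_long.dtype)

emb = torch.nn.Embedding(num_embeddings=30522, embedding_dim=8)
print(emb(ids_long).shape)  # one 8-dim vector per token id
```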
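If the error persists after the change but now occurs at a different location (see the examples below), apply the same fix at each new location in turn.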
1 Full error example 1
The full error message is:
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\pipelines.py", line 1874, in __call__
start, end = self.model(**fw_args)[:2]
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 1621, in forward
return_dict=return_dict,
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 843, in forward
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 198, in forward
inputs_embeds = self.word_embeddings(input_ids)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py", line 126, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1852, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)
The traceback contains the following frame:
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 198, in forward
    inputs_embeds = self.word_embeddings(input_ids)
To fix it, add the following line in modeling_bert.py immediately before line 198, at the same indentation:
input_ids = input_ids.to(torch.int64)
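The following sketch shows where the line goes. `ToyBertEmbeddings` is a stripped-down, hypothetical stand-in for the embeddings module in modeling_bert.py, not the real class. It keeps only the word-embedding lookup, and the cast sits directly before the `self.word_embeddings(input_ids)` call.

```python
import torch
from torch import nn


class ToyBertEmbeddings(nn.Module):
    """Hypothetical, simplified stand-in for the BERT embeddings module."""

    def __init__(self, vocab_size=30522, hidden_size=16):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        # The fix: cast the indices to int64 (Long) right before the lookup.
        input_ids = input_ids.to(torch.int64)
        inputs_embeds = self.word_embeddings(input_ids)
        return inputs_embeds


module = ToyBertEmbeddings()
ids = torch.tensor([[101, 7592, 102]], dtype=torch.int32)  # an IntTensor, as in the error
out = module(ids)
print(out.shape)
```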
2 Full error example 2
After rerunning the program, the first error is gone, but a similar one now appears at a different line:
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\pipelines.py", line 1874, in __call__
start, end = self.model(**fw_args)[:2]
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 1622, in forward
return_dict=return_dict,
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 844, in forward
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 201, in forward
token_type_embeddings = self.token_type_embeddings(token_type_ids)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py", line 126, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1852, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)
Likewise, the relevant frame is:
File "D:\ProgramData\Anaconda3\lib\site-packages\transformers\models\bert\modeling_bert.py", line 201, in forward
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
So add the following line immediately before it, at the same indentation:
token_type_ids = token_type_ids.to(torch.int64)
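Each fix only reveals the next failing input. It can therefore help to check up front which index tensors arrive as int32, so that every affected line is patched in one pass.

The sketch below assumes the model inputs are available as a dict, like the `fw_args` that the traceback passes as `self.model(**fw_args)`. `find_non_long_indices`, `INDEX_KEYS` and the sample values are all hypothetical. The key list is limited to the three BERT inputs that go through embedding layers (word, token-type and position embeddings).

```python
import torch

# Hypothetical list of inputs that BERT feeds into nn.Embedding layers.
INDEX_KEYS = ("input_ids", "token_type_ids", "position_ids")


def find_non_long_indices(inputs, keys=INDEX_KEYS):
    """Return the names of index tensors in `inputs` whose dtype is not int64."""
    return [
        key
        for key in keys
        if torch.is_tensor(inputs.get(key)) and inputs[key].dtype != torch.int64
    ]


# Hypothetical model inputs with int32 indices, as in the tracebacks above.
fw_args = {
    "input_ids": torch.tensor([[101, 2054, 102]], dtype=torch.int32),
    "token_type_ids": torch.tensor([[0, 0, 0]], dtype=torch.int32),
    "attention_mask": torch.tensor([[1, 1, 1]], dtype=torch.int32),  # not an embedding index
}
print(find_non_long_indices(fw_args))
```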
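3 Solution
In summary, the error is caused by the tensor dtype: the embedding layer expects int64 (Long) indices but receives int32 (Int) indices. Convert the indices to int64 immediately before the line that raises the error. For an input x, add "x = torch.tensor(x).to(torch.int64)", or "x = x.to(torch.int64)" if x is already a tensor. If the error persists after the change but now occurs at a different location, as in the examples above, apply the same fix at each new location in turn.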
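An alternative to patching the library line by line is to apply the same cast once to all index tensors before they reach the model. Edits to files under site-packages are lost whenever transformers is reinstalled or upgraded.

This sketch assumes you control the code that builds the model inputs. `cast_index_inputs` and `INDEX_KEYS` are hypothetical names, and the key list only covers the BERT embedding inputs (input_ids, token_type_ids, position_ids).

```python
import torch

# Hypothetical list of inputs that BERT feeds into nn.Embedding layers.
INDEX_KEYS = ("input_ids", "token_type_ids", "position_ids")


def cast_index_inputs(inputs, keys=INDEX_KEYS):
    """Return a copy of `inputs` with the embedding-index tensors cast to int64."""
    fixed = dict(inputs)  # shallow copy: the caller's dict is left unchanged
    for key in keys:
        value = fixed.get(key)
        if torch.is_tensor(value):
            fixed[key] = value.to(torch.int64)
    return fixed


fw_args = {
    "input_ids": torch.tensor([[101, 2054, 102]], dtype=torch.int32),
    "token_type_ids": torch.tensor([[0, 0, 0]], dtype=torch.int32),
    "attention_mask": torch.tensor([[1, 1, 1]], dtype=torch.int32),  # left as-is
}
fixed = cast_index_inputs(fw_args)
print(fixed["input_ids"].dtype, fixed["token_type_ids"].dtype)
# The fixed dict could then be passed on, e.g. model(**fixed).
```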