zl程序教程

您现在的位置是:首页 >  后端

当前栏目

Pytorch RuntimeError 解决办法

PyTorch 解决办法 RuntimeError
2023-06-13 09:15:24 时间

问题描述

在Pytorch训练自定义数据集中发生如下错误:

RuntimeError: result type Float can't be cast to the desired output type Long

RuntimeError:结果类型 Float 无法转换为所需的输出类型 Long

loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights]))

问题解决

BCEWithLogitsLoss 要求它的目标是一个float 张量,而不是long。所以应该通过dtype=torch.float32指定张量的类型。

将上述代码修改如下:

loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights], dtype=torch.float32))

参考文章:Pytorch 抛出错误 RuntimeError: result type Float can’t be cast to the desired output type Long答案 - 爱码网 (likecs.com)