pytorch中计算loss的函数总结
前言
- 文章来源:
CSDN@LawsonAbs
- 代码见我GitHub 下的CrossEntropyLoss 的使用 + BCELoss的学习 等小节
深度学习中,为模型计算损失是一个重要的课题,然而又有很多种不同的损失计算方法,常用如 CrossEntropyLoss
,相应的pytorch
中也实现了这些基本方法,那么在pytorch
中,到底有哪些方法与loss相关呢?它们又是怎么使用的呢?看下面分解。
1. BCELoss
1.1 定义
BCELoss
简称Binary Cross Entropy Loss
,使用的计算公式就是简单的二分类交叉熵损失
L
=
−
[
y
l
o
g
y
^
+
(
1
−
y
)
l
o
g
(
1
−
y
^
)
]
L = - [ylog \hat y + (1-y)log(1-\hat y)]
L=−[ylogy^+(1−y)log(1−y^)]
具体通过案例来说一下怎么使用这个函数。
1.2 代码示例
import torch as t
import torch.nn as nn
softmax = nn.Softmax(dim=-1)
target = t.tensor([0,1])
# 使用BCELoss() 计算损失
bce = nn.BCELoss()
b = t.tensor([0.2939,0.939])
b1 = softmax(b) # 要归一化一下~
print(b1)
loss_2 = bce(b1,target.float())
print(loss_2)
执行结果:
这里的b1 = tensor([0.3441,0.6559])
与 target = t.tensor([0,1])
合着一起的含义就是:第一个样本标签为0,其预测的概率是0.3441;第二个样本的标签是1,其预测的概率是0.6559。根据上面的公式就可以得到计算公式:
L
=
l
1
+
l
2
=
[
(
1
−
0
)
l
o
g
(
1
−
0.3441
)
]
+
[
1
∗
l
o
g
(
0.6559
)
]
\begin{aligned} L =& l_1+l_2 \\ =&[(1-0)log(1-0.3441)] + [1*log(0.6559)] \end{aligned}
L==l1+l2[(1−0)log(1−0.3441)]+[1∗log(0.6559)]
也就是如下这个过程:
from math import log
loss = -(log(1-0.3441) + log(0.6559))/2
print(loss)
得到的结果是:0.4217469406824472
。
1.2 再看一个反直觉的代码
# 下面这个示例说明BCELoss 不仅需要看标签为1的值,还需要计算标签为0 的值
# 而我们在使用BCELoss的时候,因为给出的标签都是0或者1,所以就导致0乘的那部分的损失没有了。
# 如果使用LSR[label smoothing regularization 策略,则会计算相应的损失了]
import torch as t
import torch.nn as nn
pred= t.tensor([9.1276e-01])
b = t.tensor([1.0,])
criterion = nn.BCELoss()
loss_ab = criterion(pred,b)
print(loss_ab)
print(t.log(pred))
c = t.tensor([0.92,])
loss_ac = criterion(pred,c)
print(loss_ac)
print(c*t.log(pred) + (1-c) * t.log(1-pred))
计算结果是:
根据标签和预测值我们可以知道 pred 到 c的“距离”是要比到 b 近的(也就是pred更靠近c),但使用BCELoss计算出来的结果却是 loss_ab < loss_ac 。 所以在这里使用交叉熵是不是就不大ok了?
2. BCEWithLogitsLoss
分类损失,需要注意的问题有两个:
- 二分类损失
- 传入的参数是未经
sigmoid
处理的logits,所以叫WithLogits
但它是多个二分类问题【正负样本都得计算】,所以计算方式不同于二分类交叉熵。
- 手写代码实现
下面就自己写一下这个代码的实现:
先给出直接使用BCEWithLogitsLoss
的结果
"""
BCEWithLogitsLoss函数的学习
"""
import torch
import numpy as np
pred = np.array([[-0.4089, -1.2471, 0.5907],
[-0.4897, -0.8267, -0.7349],
[0.5241, -0.1246, -0.4751]])
label = np.array([[0, 1, 1],
[0, 0, 1],
[1, 0, 1]])
pred = torch.from_numpy(pred).float()
label = torch.from_numpy(label).float()
crition1 = torch.nn.BCEWithLogitsLoss(reduction='none')
loss1 = crition1(pred, label)
print(loss1)
再使用如下过程计算:
# 下面就来模拟一下 BCE 损失的计算过程
# 先计算 positive sample
a = torch.sigmoid(pred)
print(a)
eps = 1e-20
b = -torch.log(a+eps)
loss_pos = label*b
# 接着计算negetive sample
loss_neg = -(1-label)*torch.log(1-a+eps)
print(loss_pos)
print(loss_neg)
loss = loss_pos + loss_neg
print(loss)
更加详细的过程可以参考我的github
3. CrossEntropyLoss
3.1 问题一: 怎么用的?
3.2 问题二:CrossEntropyLoss
与 BCELoss
有什么区别?
4. NLLLoss
4.1 做什么?
负对数似然损失。简单的说就是: − ∑ i N l o g y i -\sum_{i}^{N}log y_i −i∑Nlogyi
4.2 定义
定义得看类。类的构造函数如下:
weight
代表的是size_average
:已废弃的参数ignore_index
:指定某个下标被忽略,从而不被贡献到输入的梯度中reduce
:已废弃reduction
:指定输出的归一操作
4.3 实现
损失的实现我们就只需要注意观察一下forward()
函数即可,
相关文章
- pytorch交叉熵损失函数计算_pytorch loss不下降
- 矩阵特征值和特征向量详细计算过程(转载)_矩阵特征值的详细求法
- 启科量子自主研发量子计算模拟后端 QuSprout 并正式开源
- 【计算理论】图灵机 ( 图灵机示例 )
- HJ2 计算某字符出现次数
- 在java代码中执行js脚本,实现计算出字符串“(1+2)*(1+3)”的结果详解编程语言
- 2013年那些被风吹走的云 云计算安全问题盘点
- IBM错过几次科技革命后还能跟上AI和云计算的浪潮吗?
- Oracle函数计算年龄:一种可靠而高效的方式(oracle函数计算年龄)
- MySQL的中位数算法,快速准确计算数据集中间的值(mysql 中位数算法)
- MySQL中如何计算中位数(mysql中位数如何计算)
- 添加C时间戳搞定时间间隔计算的最佳方案(C时间戳 为oracle)
- MySQL利用函数实现上下取整差值计算(mysql上下取差值)
- 利用Oracle 揭示乘法计算的精妙之处(oracle乘法计算公式)