zl程序教程

您现在的位置是:首页 >  后端

当前栏目

深入浅出PyTorch中的nn.CrossEntropyLoss

PyTorch 深入浅出 NN
2023-09-27 14:19:49 时间

一、前言

nn.CrossEntropyLoss 常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。

import torch
import torch.nn as nn

二、理论基础

对于 C   ( C > 2 ) C\,(C>2) C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(还未经过 Softmax)为 { x c } c = 1 C \{x_c\}_{c=1}^C {xc}c=1C,经过 Softmax 后得到

q i = exp ⁡ ( x i ) ∑ c = 1 C exp ⁡ ( x c ) q_i=\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} qi=c=1Cexp(xc)exp(xi)

从而该样本的交叉熵损失为

H ( p , q ) = − ∑ i = 1 C p i log ⁡ q i = − ∑ i = 1 C p i log ⁡ exp ⁡ ( x i ) ∑ c = 1 C exp ⁡ ( x c ) H(p,q)=-\sum_{i=1}^C p_i\log q_i=-\sum_{i=1}^C p_i\log\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=i=1Cpilogqi=i=1Cpilogc=1Cexp(xc)exp(xi)

其中 ( p 1 , p 2 , ⋯   , p C ) (p_1,p_2,\cdots,p_C) (p1,p2,,pC) 是 One-Hot 向量。

不妨令 p y = 1   ( y ∈ { 1 , 2 , ⋯   , C } ) p_y=1\,(y\in\{1,2,\cdots,C\}) py=1(y{1,2,,C}),其余为 0 0 0,因此上式变为

H ( p , q ) = − log ⁡ exp ⁡ ( x y ) ∑ c = 1 C exp ⁡ ( x c ) H(p,q)=-\log\frac{\exp(x_y)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=logc=1Cexp(xc)exp(xy)

现在考虑有 batch 的情形,不妨设 batch size 为 N N N,神经网络的输出为 { x n c } n c ,    n = 1 , ⋯   , N ,    c = 1 , ⋯   , C \{x_{nc}\}_{nc},\;n=1,\cdots,N,\;c=1,\cdots,C {xnc}nc,n=1,,N,c=1,,C,第 n n n 个样本的真实类别记为 y n   ( y n ∈ { 1 , 2 , ⋯   , C } ) y_n\,(y_n\in\{1,2,\cdots,C\}) yn(yn{1,2,,C}),第 n n n 个样本的交叉熵损失记为 l n l_n ln,则仿照上式就有

l n = − log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) l_n=-\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=logc=1Cexp(xnc)exp(xn,yn)

接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为 w = ( w 1 , w 2 , ⋯   , w C ) \boldsymbol{w}=(w_1,w_2,\cdots,w_C) w=(w1,w2,,wC)

📌 模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚

安排了权重后,相应的损失为

l n = − w y n log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) l_n=-w_{y_n}\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=wynlogc=1Cexp(xnc)exp(xn,yn)

计算完 l 1 , l 2 , ⋯   , l N l_1,l_2,\cdots,l_N l1,l2,,lN 后,我们既可以一次性将它们全部返回(对应 reduction=none),也可以返回它们的均值(对应 reduction=mean),还可以返回它们的(对应 reduction=sum):

ℓ = { ( l 1 , ⋯   , l N ) , reduction=none ∑ n = 1 N l n / ∑ n = 1 N w y n , reduction=mean ∑ n = 1 N l n , reduction=sum \ell=\begin{cases} (l_1,\cdots,l_N),&\text{reduction=none} \\ \sum_{n=1}^N l_n/\sum_{n=1}^N w_{y_n},&\text{reduction=mean} \\ \sum_{n=1}^N l_n,&\text{reduction=sum} \\ \end{cases} =(l1,,lN),n=1Nln/n=1Nwyn,n=1Nln,reduction=nonereduction=meanreduction=sum

在 NLP 任务中,我们往往将填充词元添加到每个序列的末尾,这样一来不同长度的序列可以进行批量加载。训练过程中,我们不希望网络预测出的填充词元被算入损失函数中。不妨设填充词元在词表中的索引为 i i i,则此时应对 l n l_n ln 作如下修正:

l n = − w y n ⋅ I ( y n ≠ i ) ⋅ log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) , where    I ( x ) = { 1 , x    is True 0 , x    is False l_n=-w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})},\qquad \text{where}\; \mathbb{I}(x)= \begin{cases} 1,&x\; \text{is True} \\ 0,&x\; \text{is False} \end{cases} ln=wynI(yn=i)logc=1Cexp(xnc)exp(xn,yn),whereI(x)={1,0,xis Truexis False

另外,该场景下的 reduction=mean 对应的损失变为

ℓ = ∑ n = 1 N l n ∑ n = 1 N w y n ⋅ I ( y n ≠ i ) \ell=\sum_{n=1}^N\frac{l_n}{\sum_{n=1}^Nw_{y_n}\cdot \mathbb{I}(y_n\neq i)} =n=1Nn=1NwynI(yn=i)ln

📌 需要注意的是,在PyTorch中 y n ∈ { 0 , 1 , ⋯   , C − 1 } y_n\in\{0,1,\cdots,C-1\} yn{0,1,,C1},这里我们之所以用 { 1 , 2 , ⋯   , C } \{1,2,\cdots,C\} {1,2,,C} 是为了更自然地衔接上下文

三、主要参数

nn.CrossEntropyLoss 的主要参数如下:

nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)

⚠️ size_averagereduce 参数已经弃用,取而代之的是 reduction 参数,所以这里不再讲解


有了前面的铺垫,我们就可以很容易理解这些参数了:

  • weight:长度为 C C C 的张量,一般在数据不平衡时才会使用;
  • ignore_index:需要忽略的类别的索引,默认为 − 100 -100 100,即不忽略;
  • reduction:决定以何种形式返回损失。为 none 时返回 N N N 个样本的损失,为 mean 时返回 N N N 个样本的损失均值,为 sum 时返回 N N N 个样本的损失的和。默认为 mean
  • label_smoothing:决定是否开启标签平滑(不了解标签平滑的读者可参考这篇文章),数值在 [ 0 , 1 ] [0,1] [0,1] 内。默认为 0 0 0,即不开启。

3.1 输入与输出

输入分为 inputtargetinput 通常为 ( N , C ) (N,C) (N,C) 的形状(即 batch_size × num_classes),target 通常为 ( N , ) (N,) (N,) 的形状,其中的每个分量均位于 [ 0 , C − 1 ] ∩ Z [0,C-1] \cap \mathbb{Z} [0,C1]Z 中,代表样本属于的类别。

📌 inputtarget 还可以是其他类型的输入,但本文只讨论这种使用最为广泛的输入
📌 input 是神经网络的原始输出(未经过 Softmax),nn.CrossEntropyLoss 会自动对其应用 Softmax

torch.manual_seed(0)
batch_size = 3
num_classes = 5
criterion_1 = nn.CrossEntropyLoss(reduction='none')
criterion_2 = nn.CrossEntropyLoss()
criterion_3 = nn.CrossEntropyLoss(reduction='sum')

inputs = torch.randn(batch_size, num_classes)  # 避免与input关键字冲突(当然这无所谓)
target = torch.randint(num_classes, size=(batch_size, ))

print(criterion_1(inputs, target))  # 输出3个样本的loss
# tensor([1.4639, 3.0493, 2.3056])
print(criterion_2(inputs, target))  # 输出3个样本的loss的均值
# tensor(2.2729)
print(criterion_3(inputs, target))  # 输出3个样本的loss的和
# tensor(6.8188)

print(sum(criterion_1(inputs, target)) == criterion_3(inputs, target))
# tensor(True)
print(sum(criterion_1(inputs, target)) / batch_size == criterion_2(inputs, target))
# tensor(True)

四、从零开始实现 nn.CrossEntropyLoss

为了加深理解,接下来我们从零开始实现 nn.CrossEntropyLoss(当然会和官方不同,为了追求可读性会采用傻瓜式实现)。

首先确定框架(为简便起见这里不考虑 label_smoothing):

class CrossEntropyLoss(nn.Module):

    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction
        
    def forward(self, inputs, target):
        pass

为方便计算,我们对第二章节的损失计算公式进行改写

l n = w y n ⋅ I ( y n ≠ i ) ⋅ [ − x n , y n + log ⁡ ∑ c = 1 C exp ⁡ ( x n c ) ] l_n=w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot[-x_{n,y_n}+\log\sum_{c=1}^C\exp(x_{nc})] ln=wynI(yn=i)[xn,yn+logc=1Cexp(xnc)]

采用更符合 Python 的表述方式来改写上式

l n = w [ y n ] ⋅ I ( y n ≠ i ) ⋅ [ − x n [ y n ] + log ⁡ ∑ c = 1 C exp ⁡ ( x n [ c ] ) ] l_n=\boldsymbol{w}[y_n]\cdot \mathbb{I}(y_n\neq i)\cdot[-\boldsymbol{x_n}[y_n]+\log\sum_{c=1}^C\exp(\boldsymbol{x_n}[c])] ln=w[yn]I(yn=i)[xn[yn]+logc=1Cexp(xn[c])]

其中 w = ( w 1 , ⋯   , w C ) ,    x n = ( x n 1 , ⋯   , x n C ) \boldsymbol{w}=(w_1,\cdots,w_C),\;\boldsymbol{x_n}=(x_{n1},\cdots,x_{nC}) w=(w1,,wC),xn=(xn1,,xnC)。再令 X = ( x 1 ; ⋯   ; x N ) ,    y = ( y 1 , ⋯   , y C ) {\bf X}=(\boldsymbol{x_1};\cdots;\boldsymbol{x_N}),\;\boldsymbol{y}=(y_1,\cdots,y_C) X=(x1;;xN),y=(y1,,yC),则显然 X {\bf X} X 就是我们的 input y \boldsymbol{y} y 就是 target,于是我们可以进行批量计算

( l 1 , ⋯   , l N ) = w [ y ] ∗ I ( y ≠ i ) ∗ ( − X [ range ( len ( y ) ) ,   y ] + log ⁡ ( sum ( exp ⁡ ( X ) ,   dim = 1 ) ) ) (l_1,\cdots,l_N)=\boldsymbol{w}[\boldsymbol{y}] *\mathbb{I}(\boldsymbol{y}\neq i)* (-{\bf X}[\text{range}(\text{len}(\boldsymbol{y})),\,\boldsymbol{y}]+\log(\text{sum}(\exp({\bf X}),\,\text{dim}=1))) (l1,,lN)=w[y]I(y=i)(X[range(len(y)),y]+log(sum(exp(X),dim=1)))

其中 ∗ * 代表按元素相乘。上式采用了广播机制。

class CrossEntropyLoss(nn.Module):

    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, inputs, target):
        if self.weight is not None:
            n_samples_weight = self.weight[target]  # 每个样本的权重
        else:
            n_samples_weight = torch.ones_like(target).float()  # 不提供权重则默认全为1
        indicator = (target != self.ignore_index).long().float()  # long()方法可以将布尔型张量转化成0-1张量
        raw_loss = -inputs[torch.arange(len(target)), target] + torch.log(torch.sum(torch.exp(inputs), dim=1))
        result = n_samples_weight * indicator * raw_loss
        if self.reduction == 'mean':
            return torch.sum(result) / n_samples_weight.dot(indicator)
        elif self.reduction == 'sum':
            return torch.sum(result)
        else:
            return result

输出结果与 PyTorch 官方的 nn.CrossEntropyLoss 的完全相同,这里不再展示,读者可自行验证。