Problem

Sometimes we need to compute the CrossEntropyLoss for multiple batches at once, as in the following code snippet:

....
criterion = nn.CrossEntropyLoss()

....

for input, target in self.dataloader:
    optimizer.zero_grad()

    .....
    # output shape (5, 4, 14)
    # target shape (5, 4)
    loss = criterion(output, target)

Judging from the example in the official docs, the input usually has shape (N, C), i.e. (batch size, number of classes), and the target has shape (N,).
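As a reminder, a minimal sketch of this standard 2-D case looks like the following (the shapes here are chosen purely for illustration):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# input: (N, C) = (batch size, number of classes); target: (N,) with one class index per sample
input = torch.randn(5, 14)
target = torch.randint(0, 14, (5,))
loss = criterion(input, target)   # a single scalar (mean over the 5 samples)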

That is the single-batch case; in some real training code, however, the loss is computed over multiple batches at once, as in the snippet at the top.

Running the code at the top as-is raises the following error:
ValueError: Expected target size (5, 14), got torch.Size([5, 4])

This happens because the example at the top comes from an NLP task: the input shape is (5, 4, 14), i.e. (batch size, sequence length, number of classes), so there is one extra dimension compared with the documented case. nn.CrossEntropyLoss expects the class dimension to come right after the batch dimension, so with an input of (5, 4, 14) it treats 4 as the number of classes and 14 as an extra dimension, and therefore expects a target of shape (5, 14).
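A minimal sketch of the fix, assuming output and target keep the shapes above, is to move the class dimension into second place before calling the loss:

# output: (5, 4, 14) = (batch, seq_len, num_classes)
# target: (5, 4)     = (batch, seq_len)
# nn.CrossEntropyLoss accepts (N, C, d1) with target (N, d1), so permute to (5, 14, 4)
loss = criterion(output.permute(0, 2, 1), target)

The full example in the next section does exactly this permute before calling the loss.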


Analysis

Let's pull the output and target values out with the debugger and compute the loss on them in isolation. In the code below I also re-implemented the CrossEntropyLoss computation myself in numpy; you can skip straight to the part at the end that calls nn.CrossEntropyLoss.

import torch
import numpy as np


def my_softmax(x):
    # Row-wise softmax over a 2-D array of shape (seq_len, num_classes)
    output = np.zeros(x.shape)
    for n in range(x.shape[0]):
        exp_x = np.exp(x[n, :])
        output[n] = exp_x / np.sum(exp_x)

    return output


def my_log_softmax(x):
    # log(softmax(x)), computed row by row
    return np.log(my_softmax(x))


def my_nll_loss(P, Y, reduction='mean', ignore_index=-100):
    # P: log-probabilities of shape (seq_len, num_classes); Y: target class indices (seq_len,)
    loss = []
    for n in range(len(Y)):
        if Y[n] == ignore_index:
            # Positions marked with ignore_index contribute 0 to the loss
            loss.append(-0.)
            continue
        # Log-likelihood of the true class (negated below)
        p_n = P[n][Y[n]]
        loss.append(p_n.item())

    if reduction == 'mean':
        return -np.mean(loss)

    if reduction == 'sum':
        return -np.sum(loss)

    return -np.array(loss)


def batch_cross_entropy():
    # [B, S, C]: batch size, sequence length, number of classes (values captured from the debugger)
    output = np.array([[[-6.9800e-01, 6.8742e-01, 2.5055e-01, -6.6209e-01, -4.6491e-01,
                         -1.3935e-01, -1.7100e-01, 4.0013e-02, -3.6995e-01, -8.5358e-01,
                         -4.9449e-01, -4.5180e-01, -2.7848e-01, -1.1511e+00],
                        [-7.7217e-01, 5.0190e-01, 3.3348e-01, -4.0213e-01, -4.6606e-01,
                         -6.0082e-02, 4.7225e-01, 1.4079e-01, -1.7741e-01, -7.9565e-01,
                         -5.7972e-01, -4.8082e-01, -1.8605e-02, -9.5264e-01],
                        [-6.8221e-01, 3.7776e-01, 3.4762e-02, -6.9478e-01, -2.2510e-01,
                         3.0994e-01, -1.3499e-01, -1.6287e-01, -1.6151e-01, -2.4974e-01,
                         -4.6694e-01, -6.1922e-01, 2.4364e-01, -9.0690e-01],
                        [-8.0960e-01, 5.0074e-01, -1.8677e-01, -7.8651e-01, -4.1738e-01,
                         4.1874e-01, -2.3718e-01, -2.1826e-01, -3.3325e-01, -9.2656e-02,
                         -4.6586e-01, -8.4838e-01, 1.6432e-01, -6.5928e-01]],

                       [[-2.4560e-01, 6.9763e-01, 1.8138e-01, -3.2625e-02, -2.4262e-01,
                         -2.5643e-01, 1.1205e-01, 2.4543e-02, -4.4613e-01, -1.0645e+00,
                         -3.6831e-01, -4.1188e-02, -2.0788e-02, -1.0442e+00],
                        [-2.8846e-01, 8.2847e-01, -5.4134e-02, -7.8471e-01, 1.3351e-02,
                         -7.4033e-01, -6.3344e-01, -3.5146e-01, -8.5599e-01, -1.0859e+00,
                         -1.6991e-01, 4.7074e-02, 1.0111e-01, -5.1003e-01],
                        [-6.1263e-01, 7.3131e-01, 5.7170e-01, -3.8304e-02, 2.6139e-02,
                         -1.1358e-01, 5.1920e-01, 3.4961e-01, -2.8680e-01, -8.5890e-01,
                         -5.1087e-01, -3.2754e-01, 2.2287e-01, -6.6090e-01],
                        [-5.7762e-01, -1.6064e-01, -5.4849e-01, -5.2790e-02, -3.1316e-01,
                         5.7697e-01, 1.8820e-01, 1.9771e-03, 2.3494e-01, 4.6401e-02,
                         -6.0379e-01, -5.6362e-01, 1.0715e-01, -6.7643e-01]],

                       [[-7.5844e-01, 8.9643e-01, 4.2627e-02, -3.2765e-01, -3.2391e-01,
                         -3.7126e-01, 1.3792e-02, 1.6282e-03, -5.8745e-01, -4.6443e-01,
                         -2.7597e-01, -3.4279e-01, 1.0330e-03, -6.5268e-01],
                        [-6.7271e-01, 8.8120e-01, 4.4617e-01, -9.2040e-01, -3.0459e-01,
                         -3.1417e-01, -3.9815e-01, 1.0694e-01, -7.2992e-01, -5.3737e-01,
                         -1.6901e-01, -3.7259e-01, 9.2190e-02, -9.0215e-01],
                        [-6.4774e-01, 7.2040e-01, 7.7526e-01, -1.0923e+00, -8.9171e-02,
                         -6.2309e-05, 3.4601e-01, -6.7397e-02, -5.2992e-01, -4.7396e-01,
                         -2.0592e-01, -2.9428e-01, 2.7567e-01, -1.0032e+00],
                        [-9.6423e-01, 6.1445e-01, -6.5032e-01, -5.5757e-01, -6.0174e-01,
                         -1.6667e-01, 1.9756e-01, -5.3273e-01, -2.6795e-01, -1.6678e-01,
                         -4.7283e-01, -7.7119e-01, 8.7784e-02, -4.2825e-01]],

                       [[-2.8459e-01, 6.0364e-01, 5.0745e-01, -1.1500e-01, -2.8906e-01,
                         -2.1891e-01, 3.1818e-01, 2.6412e-01, -3.1559e-01, -9.2631e-01,
                         -2.5491e-01, -1.3816e-02, -2.7776e-01, -1.3621e+00],
                        [-6.3529e-01, 8.0968e-01, 5.9280e-01, -6.2296e-01, -3.4726e-01,
                         -1.6531e-01, 6.7529e-02, 3.7592e-01, -7.3573e-01, -1.0816e+00,
                         -3.1254e-01, -4.2386e-01, -2.4192e-01, -1.1896e+00],
                        [-7.9503e-01, 5.1963e-01, 5.1673e-01, -6.4723e-01, -8.6342e-02,
                         -2.1490e-01, 2.7284e-02, 2.6488e-01, -7.0478e-01, -1.1432e+00,
                         -2.9212e-01, -5.3028e-01, -4.8153e-01, -8.5909e-01],
                        [-7.9562e-01, 5.3502e-01, 1.2687e-01, -6.4034e-01, -1.4381e-01,
                         1.0957e-01, 2.4598e-02, 2.3910e-02, -6.8106e-01, -5.3939e-01,
                         -2.7420e-01, -4.9182e-01, 5.0746e-02, -8.6493e-01]],

                       [[-2.5208e-01, 9.5292e-02, 1.4688e-01, 4.0238e-01, -3.0913e-01,
                         -2.0094e-02, 3.9704e-01, 5.1999e-01, 1.2463e-01, -6.6643e-01,
                         -4.4233e-01, 4.3938e-03, -3.6015e-01, -1.0695e+00],
                        [-4.2988e-01, 3.2485e-01, 1.2833e-01, -7.1189e-01, -1.7690e-01,
                         -3.1612e-01, -4.5157e-01, -1.4707e-01, -2.3045e-01, -9.6345e-01,
                         -3.4908e-01, -4.5350e-01, -1.7349e-01, -7.9216e-01],
                        [-6.3809e-02, -5.2756e-02, -2.1734e-01, -1.5490e-01, -5.1187e-02,
                         -2.3425e-01, -3.4012e-01, -1.7033e-01, 5.0935e-02, -2.8938e-01,
                         -6.8729e-02, -2.7069e-01, -3.3257e-01, -4.0449e-01],
                        [-4.5155e-01, 1.0152e-01, -4.5864e-01, -2.4100e-01, -3.2433e-01,
                         3.0919e-01, 1.1523e-01, -3.3954e-01, 2.0666e-01, -1.9090e-01,
                         -4.4507e-01, -8.4536e-01, 2.5585e-01, -6.3963e-01]]])

    # [B, S]
    target = np.array([[0, 4, 1, -100],
                       [0, 8, 4, 1],
                       [0, 8, 7, 1],
                       [0, 11, 5, 1],
                       [0, 8, 6, 1]])

    my_crossentropy = np.zeros((output.shape[0], output.shape[1]))
    for i in range(output.shape[0]):
        my_crossentropy[i] = my_nll_loss(my_log_softmax(output[i]), target[i], reduction='none')
    
    # Note: reduction must be set to 'none' here. With the default 'mean', all values
    # (across every position) would be averaged together; that is sometimes reasonable
    # and sometimes not, so the safest approach is to do the reduction yourself.
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    loss = criterion(torch.from_numpy(output).permute(0, 2, 1), torch.from_numpy(target).long())

    print("my_crossentropy:", my_crossentropy)
    print("crossentropy:", loss)


batch_cross_entropy()
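If the numpy re-implementation is correct, the two printed arrays should agree, with the ignored position showing up as 0 in both. A quick programmatic check (not part of the original code) could be added at the end of batch_cross_entropy():

    # loss is a (5, 4) float64 tensor; compare it element-wise with the numpy result
    assert np.allclose(my_crossentropy, loss.numpy())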

(Screenshot of the printed my_crossentropy and crossentropy values.)

Positions whose target index is marked as -100 have to be excluded from the computation, so the reduction needs to be handled separately.
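For example, assuming the reduction='none' loss tensor and the numpy target from above, a mean over only the non-ignored positions could be computed like this (a sketch, not part of the original code):

valid = torch.from_numpy(target) != -100   # mask out positions marked with ignore_index
masked_mean = loss[valid].mean()           # average over the valid positions only

This reproduces what reduction='mean' with the default ignore_index=-100 would compute, while keeping explicit control over which positions count.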

