交叉熵

交叉熵是模型SFT和PT训练中最常用的一类损失。用于对labels的精确概率拟合。

class CrossEntropyLoss(Loss):

    def __init__(self, **kwargs):
        self.reduction = kwargs.get('reduction', 'mean')

    def __call__(self, inputs, outputs, **kwargs):
        import torch
        logits = outputs['logits'].view(-1, outputs['logits'].shape[-1])
        labels = inputs['labels'].view(-1)
        return torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)

构造中可以传入reduction参数,支持sum, mean, none等(和torch.nn.CrossEntropyLoss输入相同)。

在Transformers模型中目前使用sum。目的是在optimizer.step之前统计有效token数量并在grad层面取单token平均。