构建新的 Loss

Twinkle 中的 loss 基类定义为:

class Loss:

    def __call__(self, inputs: InputFeature, outputs: ModelOutput, **kwargs):
        ...

损失的输入为模型的 InputFeature,输出为模型标准 ModelOutput,kwargs 可以在模型的 calculate_loss 中传入。由于它是一个带有 __call__ 方法的类,因此开发者也可以使用 Callable:

def my_loss(inputs: InputFeature, outputs: ModelOutput, extra_data1: int, extra_data2: dict):
    ...
    return loss

在模型中这样使用:

model.set_loss(my_loss)
model.calculate_loss(extra_data1=10, extra_data2={})

你也可以将 Loss 上传到 ModelScope/Hugging Face 的 Hub 中,在使用时动态拉取:

model.set_loss('ms://my_group/my_loss')

具体可以参考插件文档的介绍。