模型输入
twinkle用于表示模型输入的类是InputFeature,该类适配于transformers/megatron等模型结构。
class ModelOutput(TypedDict, total=False):
logits: OutputType
loss: OutputType
ModelOutput本质上是一个Dict。其字段来自于模型的输出和loss计算。
logits: 一般是[BatchSize * SequenceLength * VocabSize]尺寸,和labels配合计算loss
loss: 实际残差
ModelOutput是twinkle中所有模型输出的标准接口。