模型输出

模型输出的详细类型定义。

OutputType

OutputType 定义了模型输出支持的数据类型:

OutputType = Union[np.ndarray, 'torch.Tensor', List[Any]]

支持 NumPy 数组、PyTorch 张量或任意类型的列表。

ModelOutput

ModelOutput 是 Twinkle 用于表示模型输出的标准类。该类适配于 transformers/megatron 等模型结构。

class ModelOutput(TypedDict, total=False):
    logits: OutputType
    loss: OutputType

ModelOutput 本质上是一个 Dict。其字段来自于模型的输出和 loss 计算。

  • logits: 一般是 [BatchSize * SequenceLength * VocabSize] 尺寸,和 labels 配合计算 loss

  • loss: 实际残差

ModelOutput 是 Twinkle 中所有模型输出的标准接口。

使用示例:

from twinkle.data_format import ModelOutput

# 在模型的 forward 方法中
def forward(self, inputs):
    ...
    return ModelOutput(
        logits=logits,
        loss=loss
    )

注意:ModelOutput 使用 TypedDict 定义,意味着它在运行时是一个普通的 dict,但在类型检查时会提供类型提示。