模型输出
模型输出的详细类型定义。
OutputType
OutputType 定义了模型输出支持的数据类型:
OutputType = Union[np.ndarray, 'torch.Tensor', List[Any]]
支持 NumPy 数组、PyTorch 张量或任意类型的列表。
ModelOutput
ModelOutput 是 Twinkle 用于表示模型输出的标准类。该类适配于 transformers/megatron 等模型结构。
class ModelOutput(TypedDict, total=False):
logits: OutputType
loss: OutputType
ModelOutput 本质上是一个 Dict。其字段来自于模型的输出和 loss 计算。
logits: 一般是 [BatchSize * SequenceLength * VocabSize] 尺寸,和 labels 配合计算 loss
loss: 实际残差
ModelOutput 是 Twinkle 中所有模型输出的标准接口。
使用示例:
from twinkle.data_format import ModelOutput
# 在模型的 forward 方法中
def forward(self, inputs):
...
return ModelOutput(
logits=logits,
loss=loss
)
注意:ModelOutput 使用 TypedDict 定义,意味着它在运行时是一个普通的 dict,但在类型检查时会提供类型提示。