模型输入

twinkle用于表示模型输入的类是InputFeature,该类适配于transformers/megatron等模型结构。

InputType = Union[List[List[int]], List[int], np.ndarray, Any]

class InputFeature(TypedDict, total=False):
    # Text-related fields
    input_ids: InputType
    attention_mask: InputType
    position_ids: InputType
    labels: InputType

InputFeature本质上是一个Dict。其输入来自于Template组件的输出。

  • input_ids: List[Messages]以模板进行嵌套之后的token list

  • attention_mask: 注意力掩膜

  • position_ids: 用于样本区分的位置编码

  • labels: 训练的label,已经进行了一个token的左位移

在packing或padding_free的情况下,input_ids等字段由多个样本的列表拼接而来。 在多模态场景下,InputFeature包含多模态其他字段。

InputFeature是twinkle中所有模板输出、模型输入的标准接口。