InputProcessor

InputProcessor 承载了不同任务的数据准备过程。

class InputProcessor:

    def __init__(self, device_mesh: Optional[DeviceMesh] = None,
                 padding_free: bool = False,
                 framework: Literal['transformers', 'megatron'] = 'transformers',
                 **kwargs):
        ...

    def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs) -> Union[InputFeature, List[InputFeature]]:
        # 整体处理的入口
        ...

    def prepare_inputs(self, inputs: Union[List[InputFeature], InputFeature], **kwargs) -> List[InputFeature]:
        # 移动到 cuda 设备上
        ...

    def pad_cp(self, inputs: List[InputFeature], **kwargs) ->List[InputFeature]:
        # 处理 cp
        ...

    def split_cp(self, inputs: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]:
        # 处理 cp
        ...

    def collate_fn(self, inputs: List[InputFeature], micro_batch_size: Optional[int] = None,
                   variable_seq_lengths=False, **kwargs) -> List[InputFeature]:
        # data_collator
        ...
  • device_mesh: 用于切分 cp。如果没有 cp,device_mesh 参数可以不传。

  • padding_free: 是否将多个样本拼接为一个,这个功能和 PackingDataset 比较相似,但 PackingDataset 会让每个 batch 长度基本一致,而 padding_free 仅考虑本 batch 内部的拼接。

    • 使用 PackingDataset 会自动在 InputProcessor 内触发 padding_free

  • framework: 支持 transformers 和 megatron。不同的模型架构返回的模型输入略有不同

Twinkle 将 collate_fn 放入 InputProcessor 中,因为不同的任务(sft/grpo 等)对输入需求是不同的。目前 InputProcessor 默认执行在模型端,因为这样可以将 DataLoader 和模型进行解耦。 因为 collate_fn 和运行任务、Megatron 的 micro_batch_size 等信息有关,如果在 DataLoader 中运行,会导致 DataLoader 无法独立成为组件,其逻辑也会变得复杂。

InputProcessor 实现了 call 方法,因此你可以使用自己的 function 来完成自己的任务数据准备流程:

def my_processor(inputs: Union[InputFeature, List[InputFeature]]) -> Union[InputFeature, List[InputFeature]]:
    return ...

model.set_processor(my_processor)
# 或者
dataloader.set_processor(my_processor)