DataLoader
DataLoader 是 PyTorch 中用于加载处理后的数据集,并提供数据给模型的组件。该组件的工作流程为:
传入数据集 -> 构建 sampler 和 batch_sampler -> 索引数据 -> 调用 sampler 拿到索引 -> 从 dataset 中取出一个 batch -> 进行 collate_fn 操作 -> 吐出数据
DataLoader 的整体工作方式类似于:
for data in dataloader:
...
可以看出 dataloader 包含 __iter__ 方法,返回一个迭代器出来。在 DDP、TP、Ulysses 等不同训练条件下,由于每个 rank 取出的数据不同,因此一般 sampler 有多种实现,较为复杂。
在 Twinkle 中,我们采取了一个非常简单直接的方案,将 DeviceMesh 传递给 DataLoader,由于 DeviceMesh 中包含了集群结构,因此 DeviceMesh 可以给出所有 rank 需要的数据分片。
因此我们额外开发了 DeviceMeshSampler 和 DeviceMeshFetcher,分别用于普通数据集和流式数据集两类的取样工作。
另外,由于 LazyDataset 的存在,导致数据集实际取出数据时可能包含了无效数据或者抛出异常,因此提供了 RetrySampler 来进行跳过和重试。
DataLoader 的使用非常简单:
dataloader = DataLoader(dataset)
for data in dataloader:
...
在 torchrun 条件下,由于整体同构,因此全局只需要一个 device_mesh,这个参数无需通过 DataLoader 的构造传入,infra 模块会自动分析并传入。
DataLoader 也支持在 Ray 模式下工作:
def create_dataset():
dataset = Dataset(...)
dataset.map(...)
dataset.encode(...)
return dataset
dataloader = DataLoader(create_dataset, device_mesh=actor_device_mesh, remote_group='actor')
for data in dataloader:
...
DataLoader 的 dataset 参数可以传入一个 Callable 来返回一个 Dataset,这样可以做到数据集的构建代码放在 driver 中,但实际的构建在 Dataloader 的 worker 中,防止了跨进程的 pickle,提高速度。
dataloader 的 @remote_class 装饰器的执行范围也是 first,这意味着它只会有一个 worker 用来取出数据。
开发者无需担心 dataloader 返回的 data 占用 driver 内存,data 通常是一个引用句柄,到了需要使用的 worker 才会实际传递并解包。 Dataloader 默认不设置任何的 collate_fn,而是将这个过程交由模型处理。