RemoteClass

所有 Twinkle 中支持 Ray 和 HTTP 中使用的组件均通过 @remote_class@remote_function 进行了装饰。该装饰器会拦截类的构造,在 Ray 模式下,将类的构造转为 worker 执行。

from twinkle import remote_class, remote_function

@remote_class(execute='first')
class MyComponent:

    def __init__(self, **kwargs):
        ...

    @remote_function(dispatch='slice_dp', collect='first')
    def func(self, *args, **kwargs):
        ...
        return ...

开发者只需要编写上述代码,就可以将 MyComponent 类转入 worker 执行。其中:

  • remote_class: 将类标记为需要远端执行。如果 Twinkle 初始化设置为 local 模式,或者该类构造时没有传入 remote_group 设置,或者 remote_group 为当前 worker,都会在进程内构造该类。

  • remote_function: 将某个标记了 remote_class 的方法标记为可以在 Ray 中执行。其输入和输出均会被 Ray 压缩传递。

调用 MyComponent

import twinkle
from twinkle import DeviceGroup

device_groups = [
    DeviceGroup(
        name='default',
        ranks=4,
        device_type='cuda',
    )
]

twinkle.initialize('ray', groups=device_groups)

_my_component = MyComponent(remote_group='default')
_my_component.func(...)

通过这种方式,我们编写了一个 MyComponent,并在 Ray 集群中使用 4 张卡构造了一个叫 default 的组,把 MyComponent 构造在了该组中。

remote_class 在装饰类的时候的参数:

  • execute: 支持 first/all。first 仅会在该组的第 0 个设备上创建,一般用于 Dataset、DataLoader 的构造,all 会在所有设备上构造。

remote_function 在装饰方法的时候有下面的参数:

  • dispatch: 如何分发输入数据。支持 slice/all/slice_dp/函数 四种。slice 会将 list 输入均匀分发(非 list 会全部分发),all 进行全部分发,slice_dp 会将输入数据按照 device_mesh 的 dp 组进行切分分发,来保障模型输入数据的正确性,函数方式支持以自己的实现来分发输入数据:

def _dispatcher(length, i, args, kwargs, device_mesh):
    # length 是 worker 数量,i 是当前 rank,args 和 kwargs 是输入数据,在这里具体执行分发逻辑
    # device_mesh是隶属于目标组件的device_mesh
    return _args_rank, _kwargs_rank
  • execute: 支持 first/all,仅在第一个 worker 上执行,还是全部执行

  • collect: 如何收集返回的数据,支持 none/flatten/mean/sum/first/last_pp/函数

    • none: 不做任何处理

    • flatten: 将所有 worker 数据进行拉平,模仿单一 worker 执行的返回结构

    • mean/sum: 返回均值或累加值

    • first: 仅返回第一个 worker 的结果。一般用于所有 worker 需要输入,但输出结果相同的情况

    • last_pp: 返回最后一个 pipeline 的结果,用于 pp 并行的情况

    • 函数: 支持自定义收集方法

def _collect(all_results: List, device_mesh):
    # device_mesh是隶属于目标组件的device_mesh
    return ...
  • sync: 是否以 Ray 的同步方式执行,默认为 False

  • lazy_collect: 默认为 True,在这种情况下,会不在 driver 进程中收集结果,而在需要这些结果的 worker 中延迟展开,对于具体方法来说,某些方法需要在 driver 中收集,例如收集 loss、metric 等网络负载不大的情况,可以设置为 False