RemoteClass
All Twinkle components that can run under Ray or over HTTP are decorated with @remote_class and @remote_function. The @remote_class decorator intercepts construction of the class and, in Ray mode, moves that construction to the workers.
from twinkle import remote_class, remote_function

@remote_class(execute='first')
class MyComponent:

    def __init__(self, **kwargs):
        ...

    @remote_function(dispatch='slice_dp', collect='first')
    def func(self, *args, **kwargs):
        ...
        return ...
The code above is all that is needed to move the MyComponent class to worker execution. The two decorators work as follows:
remote_class: Marks the class as needing remote execution. The class is constructed in the current process instead if Twinkle was initialized in local mode, if no remote_group setting is passed to the constructor, or if remote_group is the current worker.
remote_function: Marks a method of a class decorated with remote_class as executable in Ray. Its input and output will be compressed and passed by Ray.
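The construction rule described above can be summarized as a small decision function. This is a minimal sketch of the rule as stated, not Twinkle's actual implementation: the function name `constructs_in_process` and its parameters are hypothetical, and it assumes the current worker can be identified by its group name.

```python
from typing import Optional


def constructs_in_process(mode: str,
                          remote_group: Optional[str],
                          current_group: Optional[str]) -> bool:
    """Return True when the class would be built in the calling process.

    Hypothetical helper that mirrors the three conditions described above;
    it is not part of the Twinkle API.
    """
    if mode == 'local':                 # Twinkle was initialized in local mode
        return True
    if remote_group is None:            # no remote_group passed to the constructor
        return True
    if remote_group == current_group:   # already running inside the target group
        return True
    return False                        # otherwise construction goes to the workers


# Driver process in Ray mode asking for the 'default' group: remote construction.
print(constructs_in_process('ray', 'default', None))       # False
# Same call made from a worker that belongs to 'default': built in place.
print(constructs_in_process('ray', 'default', 'default'))  # True
```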
Calling MyComponent:
import twinkle
from twinkle import DeviceGroup
device_groups = [
    DeviceGroup(
        name='default',
        ranks=4,
        device_type='cuda',
    )
]
twinkle.initialize('ray', groups=device_groups)
_my_component = MyComponent(remote_group='default')
_my_component.func(...)
This defines MyComponent, creates a group named default backed by 4 GPUs in the Ray cluster, and constructs MyComponent in that group.
Parameters when remote_class decorates a class:
execute: Supports first/all. With first, the class is constructed only on the 0th device of the group, which is generally used for Dataset and DataLoader. With all, it is constructed on every device.
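As a quick illustration of the two execute values, the sketch below lists which device indices of a group would receive an instance. The helper `construction_ranks` is hypothetical and only restates the rule above; it is not a Twinkle function.

```python
from typing import List


def construction_ranks(execute: str, num_devices: int) -> List[int]:
    """Hypothetical helper: device indices in a group that receive an instance."""
    if execute == 'first':
        return [0]                       # only the 0th device of the group
    if execute == 'all':
        return list(range(num_devices))  # every device in the group
    raise ValueError(f"execute must be 'first' or 'all', got {execute!r}")


print(construction_ranks('first', 4))  # [0]
print(construction_ranks('all', 4))    # [0, 1, 2, 3]
```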
Parameters when remote_function decorates a method:
dispatch: How to distribute input data. Supports four types: slice/all/slice_dp/function. slice distributes list inputs evenly across workers (non-list inputs are sent in full to every worker); all sends the full input to every worker; slice_dp splits and distributes the input according to the dp group of the device_mesh so that model input data stays correct. function lets you distribute the input with your own implementation:
def _dispatcher(length, i, args, kwargs, device_mesh):
    # length is the number of workers and i is the current rank;
    # args and kwargs are the input data. Implement the distribution logic here.
    # device_mesh is the device_mesh that belongs to the target component.
    # Placeholder: pass everything through unchanged; replace with per-rank inputs.
    _args_rank, _kwargs_rank = args, kwargs
    return _args_rank, _kwargs_rank
execute: Supports first/all. Runs only on the first worker, or on all workers.
collect: How to collect returned data. Supports none/flatten/mean/sum/first/last_pp/function:
none: Do not process the results
flatten: Flatten the data from all workers to mimic the return structure of single-worker execution
mean/sum: Return the mean or the sum of the worker results
first: Return only the result of the first worker. Generally used when all workers need the input but produce the same output
last_pp: Return the result of the last pipeline stage; used for pipeline parallelism (pp)
function: Use a custom collection function:
from typing import List

def _collect(all_results: List, device_mesh):
    # device_mesh is the device_mesh that belongs to the target component.
    return ...
sync: Whether to execute synchronously using Ray’s method. Default is False.
lazy_collect: Default is True. In this case, results are not collected in the driver process; collection is deferred and the results are expanded in the workers that need them. Some methods do need their results in the driver, for example when collecting loss, metrics, or other values with a small network load; set lazy_collect to False for those.
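To make the dispatch and collect hooks concrete, the following sketch implements a custom dispatcher and a custom collector with the signatures shown above and runs them in a single process, with no Ray involved. The names `slice_dispatcher`, `flatten_collect`, `simulate_call`, and `double_all` are assumptions for illustration, as is the rule for spreading an uneven remainder across workers. How Twinkle invokes these functions internally is not shown here, and `device_mesh` is passed as None because the sketch does not use it.

```python
from typing import Any, Callable, Dict, List, Tuple


def slice_dispatcher(length: int, i: int, args: Tuple, kwargs: Dict,
                     device_mesh: Any) -> Tuple[Tuple, Dict]:
    """Give worker i its contiguous share of every list input.

    Non-list inputs are passed to every worker unchanged.
    """
    def take(value: Any) -> Any:
        if not isinstance(value, list):
            return value
        # Spread the remainder over the first workers so shares differ by at most 1.
        base, extra = divmod(len(value), length)
        start = i * base + min(i, extra)
        stop = start + base + (1 if i < extra else 0)
        return value[start:stop]

    return tuple(take(a) for a in args), {k: take(v) for k, v in kwargs.items()}


def flatten_collect(all_results: List, device_mesh: Any) -> List:
    """Concatenate per-worker lists into one list, in worker order."""
    return [item for worker_result in all_results for item in worker_result]


def simulate_call(func: Callable, num_workers: int, *args, **kwargs) -> List:
    """Stand-in for a remote call: dispatch, run each share locally, collect."""
    results = []
    for rank in range(num_workers):
        rank_args, rank_kwargs = slice_dispatcher(num_workers, rank, args, kwargs, None)
        results.append(func(*rank_args, **rank_kwargs))
    return flatten_collect(results, None)


def double_all(batch: List[int], scale: int = 2) -> List[int]:
    """Toy per-worker computation used only for this demonstration."""
    return [x * scale for x in batch]


print(simulate_call(double_all, 4, [1, 2, 3, 4, 5, 6], scale=2))
# [2, 4, 6, 8, 10, 12]
```

A mean or sum collector would follow the same shape, reducing all_results to a single value instead of concatenating it.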