Sampler
Sampler (采样器) 是 Twinkle 中用于生成模型输出的组件,主要用于 RLHF 训练中的样本生成。采样器支持多种推理引擎,包括 vLLM 和原生 PyTorch。
基本接口
class Sampler(ABC):
@abstractmethod
def sample(
self,
inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
sampling_params: Optional[SamplingParams] = None,
adapter_name: str = '',
*,
num_samples: int = 1,
) -> List[SampleResponse]:
"""对给定输入进行采样"""
...
def add_adapter_to_model(self, adapter_name: str, config_or_dir, **kwargs):
"""添加 LoRA 适配器"""
...
def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
"""设置模板"""
...
采样器的核心方法是 sample,它接受输入数据并返回生成的样本。
可用的采样器
Twinkle 提供了两种采样器实现:
vLLMSampler
vLLMSampler 使用 vLLM 引擎进行高效推理,支持高吞吐量的批量采样。
高性能: 使用 PagedAttention 和连续批处理
LoRA 支持: 支持动态加载和切换 LoRA 适配器
多样本生成: 可以为每个 prompt 生成多个样本
Tensor Parallel: 支持张量并行加速大模型推理
详见: vLLMSampler
TorchSampler
TorchSampler 使用原生 PyTorch 和 transformers 进行推理,适合小规模采样或调试。
简单易用: 基于 transformers 的标准接口
灵活性高: 容易定制和扩展
内存占用小: 适合小规模采样
详见: TorchSampler
如何选择
vLLMSampler: 适合生产环境和大规模训练,需要高吞吐量
TorchSampler: 适合调试、小规模实验或自定义需求
在 RLHF 训练中,采样器通常与 Actor 模型分离,使用不同的硬件资源,避免推理和训练相互干扰。