采样输出

采样输出是用于表示采样过程的输入参数和返回结果的数据格式。

SamplingParams

采样参数用于控制模型的采样行为。

@dataclass
class SamplingParams:
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Union[str, Sequence[str], Sequence[int], None] = None
    temperature: float = 1.0
    top_k: int = -1
    top_p: float = 1.0
    repetition_penalty: float = 1.0
  • max_tokens: 生成的最大 token 数量

  • seed: 随机种子

  • stop: 停止序列,可以是字符串、字符串序列或 token id 序列

  • temperature: 温度参数,控制采样的随机性。0 表示贪心采样

  • top_k: Top-K 采样参数,-1 表示不使用

  • top_p: Top-P (nucleus) 采样参数

  • repetition_penalty: 重复惩罚系数

转换方法

SamplingParams 提供了转换方法来适配不同的推理引擎:

# 转换为 vLLM 的 SamplingParams
vllm_params = params.to_vllm(num_samples=4, logprobs=True, prompt_logprobs=0)

# 转换为 transformers 的 generate 参数
gen_kwargs = params.to_transformers(tokenizer=tokenizer)

SampleResponse

采样响应是采样器返回的结果数据结构。

@dataclass
class SampleResponse:
    trajectories: List[Trajectory]
    logprobs: Optional[List[List[float]]] = None
    prompt_logprobs: Optional[List[List[float]]] = None
    stop_reason: Optional[List[StopReason]] = None
  • trajectories: 采样生成的轨迹列表

  • logprobs: 生成 token 的对数概率

  • prompt_logprobs: prompt token 的对数概率

  • stop_reason: 停止原因,可以是 "length" (达到最大长度) 或 "stop" (遇到停止序列)

使用示例:

from twinkle.data_format import SamplingParams, SampleResponse
from twinkle.sampler import vLLMSampler

sampler = vLLMSampler(model_id='ms://Qwen/Qwen3.5-4B')
params = SamplingParams(max_tokens=512, temperature=0.7, top_p=0.9)
response: SampleResponse = sampler.sample(trajectories, sampling_params=params, num_samples=4)

# 访问生成的轨迹
for traj in response.trajectories:
    print(traj.messages)