Sampling Output
The sampling output data formats represent the input parameters and the returned results of the sampling process.
SamplingParams
Sampling parameters are used to control the model’s sampling behavior.
@dataclass
class SamplingParams:
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Union[str, Sequence[str], Sequence[int], None] = None
    temperature: float = 1.0
    top_k: int = -1
    top_p: float = 1.0
    repetition_penalty: float = 1.0
max_tokens: Maximum number of tokens to generate
seed: Random seed
stop: Stop sequences, can be a string, sequence of strings, or sequence of token ids
temperature: Sampling temperature controlling randomness; 0 means greedy decoding
top_k: Top-K sampling parameter; -1 disables it
top_p: Top-P (nucleus) sampling parameter
repetition_penalty: Repetition penalty coefficient
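To make the interaction of these parameters concrete, the following is a minimal pure-Python sketch of how temperature, top_k, top_p and repetition_penalty are commonly applied to a vector of logits before a token is drawn. It is not Twinkle's implementation (the actual sampling is delegated to the inference engine). The function name sample_token, the CTRL-style form of the repetition penalty, and the order of the filters are assumptions made for illustration.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_token(logits: Sequence[float],
                 generated: Sequence[int] = (),
                 temperature: float = 1.0,
                 top_k: int = -1,
                 top_p: float = 1.0,
                 repetition_penalty: float = 1.0,
                 seed: Optional[int] = None) -> int:
    """Hypothetical helper: pick one token id from raw logits."""
    scores = list(logits)

    # Repetition penalty (CTRL-style, an assumption): lower the score of
    # every token that has already been generated.
    for tok in set(generated):
        if scores[tok] > 0:
            scores[tok] /= repetition_penalty
        else:
            scores[tok] *= repetition_penalty

    # temperature == 0 means greedy decoding: return the arg-max token.
    if temperature == 0:
        return max(range(len(scores)), key=scores.__getitem__)

    # Temperature scaling followed by a numerically stable softmax.
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank token ids by probability, most likely first.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)

    # Top-K: keep only the k most likely tokens (-1 disables the filter).
    if top_k > 0:
        order = order[:top_k]

    # Top-P: keep the smallest prefix whose renormalized cumulative
    # probability reaches top_p.
    mass = sum(probs[t] for t in order)
    kept: List[int] = []
    cumulative = 0.0
    for tok in order:
        kept.append(tok)
        cumulative += probs[tok] / mass
        if cumulative >= top_p:
            break

    # Draw from the surviving tokens; random.choices renormalizes the weights.
    rng = random.Random(seed)
    return rng.choices(kept, weights=[probs[t] for t in kept], k=1)[0]
```

With temperature=0 the call is deterministic, and passing the same seed to two otherwise identical calls yields the same token, which mirrors the role of the seed field above.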
Conversion Methods
SamplingParams provides conversion methods to adapt to different inference engines:
# Convert to vLLM's SamplingParams
vllm_params = params.to_vllm(num_samples=4, logprobs=True, prompt_logprobs=0)
# Convert to transformers' generate parameters
gen_kwargs = params.to_transformers(tokenizer=tokenizer)
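As a rough illustration of what such a conversion involves, the sketch below maps the fields onto keyword arguments accepted by Hugging Face generate (max_new_tokens, do_sample, temperature, top_k, top_p, repetition_penalty). It is not the body of to_transformers: the function to_generate_kwargs and the local SamplingParams stand-in are hypothetical, and the handling of stop sequences and the seed is deliberately left out.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class SamplingParams:
    """Local stand-in with a subset of the fields documented above."""
    max_tokens: Optional[int] = None
    temperature: float = 1.0
    top_k: int = -1
    top_p: float = 1.0
    repetition_penalty: float = 1.0


def to_generate_kwargs(params: SamplingParams) -> Dict[str, Any]:
    """Hypothetical mapping onto Hugging Face `generate` keyword arguments."""
    kwargs: Dict[str, Any] = {"repetition_penalty": params.repetition_penalty}
    if params.max_tokens is not None:
        kwargs["max_new_tokens"] = params.max_tokens
    if params.temperature == 0:
        # Greedy decoding: sampling is switched off entirely.
        kwargs["do_sample"] = False
    else:
        kwargs["do_sample"] = True
        kwargs["temperature"] = params.temperature
        kwargs["top_p"] = params.top_p
        # Assumption: transformers treats top_k=0 as "no top-k filter",
        # so the -1 sentinel used here is mapped to 0.
        kwargs["top_k"] = params.top_k if params.top_k > 0 else 0
    return kwargs
```

The main design point is that sentinel values differ between engines: -1 disables top-k in SamplingParams, so a conversion layer has to translate it to the target engine's convention.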
SampleResponse
SampleResponse is the data structure returned by the sampler.
@dataclass
class SampleResponse:
    trajectories: List[Trajectory]
    logprobs: Optional[List[List[float]]] = None
    prompt_logprobs: Optional[List[List[float]]] = None
    stop_reason: Optional[List[StopReason]] = None
trajectories: List of generated trajectories
logprobs: Log probabilities of generated tokens
prompt_logprobs: Log probabilities of prompt tokens
stop_reason: Stop reasons; each is either "length" (reached the maximum length) or "stop" (encountered a stop sequence)
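A consumer of these fields might look like the sketch below, which sums per-token log probabilities into a sequence score and flags truncated samples. It uses a local stand-in rather than the Twinkle classes, and it assumes that logprobs[i] and stop_reason[i] describe trajectories[i] and that stop reasons compare equal to the strings "length" and "stop"; neither assumption is confirmed by the definitions above. The helper name summarize is hypothetical.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SampleResponse:
    """Local stand-in mirroring the fields documented above."""
    trajectories: List[Any]
    logprobs: Optional[List[List[float]]] = None
    prompt_logprobs: Optional[List[List[float]]] = None
    stop_reason: Optional[List[str]] = None


def summarize(response: SampleResponse) -> List[dict]:
    """Hypothetical helper: one summary dict per generated trajectory."""
    rows = []
    for i, _ in enumerate(response.trajectories):
        row = {"index": i, "score": None, "truncated": None}
        if response.logprobs is not None:
            # The sum of token log probabilities is the log probability
            # of the whole generated sequence.
            row["score"] = sum(response.logprobs[i])
        if response.stop_reason is not None:
            # "length" means generation reached the maximum length
            # before a stop sequence was encountered.
            row["truncated"] = response.stop_reason[i] == "length"
        rows.append(row)
    return rows
```

Because logprobs and stop_reason are optional, the helper checks each for None before indexing into it.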
Usage example:
from twinkle.data_format import SamplingParams, SampleResponse
from twinkle.sampler import vLLMSampler
sampler = vLLMSampler(model_id='ms://Qwen/Qwen3.5-4B')
params = SamplingParams(max_tokens=512, temperature=0.7, top_p=0.9)
response: SampleResponse = sampler.sample(trajectories, sampling_params=params, num_samples=4)
# Access generated trajectories
for traj in response.trajectories:
    print(traj.messages)