快速开始

✨ Twinkle 是什么?

大模型训练组件库。基于 PyTorch,更简洁、更灵活、生产就绪。

🧩 松耦合架构 · 标准化接口
🚀 多运行模式 · torchrun / Ray / HTTP
🔌 多框架兼容 · Transformers / Megatron
👥 多租户支持 · 单基座模型部署

Twinkle 适配性

Twinkle 和 ms-swift 都是模型训练框架,但二者的特性有很大不同,开发者可以根据自己的需求选择。

何时选择 Twinkle

  • 如果你是大模型的初学者,希望更好地了解模型机制和模型训练方法

  • 如果你是大模型研究者,希望定制模型或训练方法

  • 如果你善于编写 training loop,希望定制训练过程

  • 如果你希望提供企业级或商业化训练平台

何时选择ms-swift

  • 如果你不关心训练过程,希望仅提供数据集便可完成训练

  • 如果你需要更多的模型支持和数据集种类

  • 如果你需要Embedding、Reranker、Classification等多种类型的训练

  • 如果你需要推理、部署、量化等其他能力

  • 如果你对新模型的训练支持敏感,Swift 会保证 day-0 的更新能力

模型训练与Twinkle

当你发现通用大模型无法满足你的需求时,训练就成为必选项:

  • 让模型认识你:通过自我认知训练,模型可以回答"你是谁"、"你的开发者是谁"等问题,成为专属于你的 AI 助手。

  • 让模型懂你的业务:使用私有数据微调,模型可以学会你的行业术语、业务流程、内部知识库,成为领域专家。

  • 让模型按你的方式思考:通过强化学习(RL),你可以定义奖励规则,引导模型生成符合你期望的输出格式、推理风格或价值观。

  • 让模型更强:蒸馏大模型的能力到小模型,或通过持续预训练注入新知识,让模型能力持续进化。

训练完成后,你可以将模型部署到自己的服务器,或发布到 ModelScope/Hugging Face 与社区分享,或者通过vLLM等部署架构部署你的服务进行使用。

现有的训练框架可以大致分为三类:

  • 底层框架(如原生 PyTorch):灵活性极高,但需要开发者从零搭建分布式、数据加载、checkpoint 等基础设施,开发成本高、周期长。

  • 高层框架(如 ms-swift、transformers Trainer):开箱即用,只需提供数据集和配置即可完成训练,但训练过程是黑盒,难以定制算法细节。

  • 重型框架(如 Megatron-LM):为超大规模模型设计,支持复杂的并行策略,但学习曲线陡峭,代码侵入性强。

Twinkle 的设计目标是在这三类框架之间找到平衡点:

  1. 保留 training loop 的控制权:开发者可以清晰看到并控制 forward、backward、step 的每一步,便于调试和定制算法。

  2. 提供高内聚的组件抽象:Dataset、Model、Sampler、Loss 等组件各司其职,可独立使用也可组合使用,无需整体接入。

  3. 屏蔽分布式复杂性:无论是单卡、torchrun 还是 Ray 集群,训练代码几乎相同,只需修改初始化参数。

  4. 支持生产级部署:内置多租户、HTTP 服务、权重同步等能力,可直接用于构建企业级训练平台。

使用模式

仅使用部分组件

开发者可以仅使用Twinkle的一部分组件,结合自己的已有代码来完成训练工作。例如,仅使用Dataset&DataLoader:

from twinkle.dataset import PackingDataset, DatasetMeta
from twinkle.dataloader import DataLoader
from twinkle.preprocessor import SelfCognitionProcessor

def train():
    dataset_meta = DatasetMeta(
        dataset_id='ms://swift/self-cognition',
    )

    dataset = PackingDataset(dataset_meta)
    dataset.map(SelfCognitionProcessor(model_name='Twinkle模型', model_author='ModelScope社区'))
    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=512)
    dataset.encode()
    dataset.pack_dataset()

    dataloader = DataLoader(dataset, batch_size=8)
    for data in dataloader:
        print(data)
        """
        {
            "input_ids": [...],
            "position_ids": [...],
            ...
        }
        """
        break

if __name__ == '__main__':
    train()

上面的代码中,使用PackingDataset加载了一个叫做swift/self-cognition的数据集。PackingDataset可以用于将数据进行装箱,保证每个batch的长度都与设置的最大长度相似。 我们在循环中简单地使用了print打印了输出,在实际使用中,你可以在下面继续编写你的自定义训练代码。

Twinkle的所有组件都支持单独拆分使用,可以参考下面章节的组件列表。

单GPU

Twinkle 支持单GPU运行训练。下面是一个例子:

from peft import LoraConfig

from twinkle import get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

logger = get_logger()


def train():
    # 1000 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    # Set template to prepare encoding
    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
    # Preprocess the dataset to standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode dataset
    dataset.encode()
    # Global batch size = 8, for GPUs, so 1 sample per GPU
    dataloader = DataLoader(dataset=dataset, batch_size=8)
    # Use a TransformersModel
    model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')

    lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')

    # Add a lora to model, with name `default`
    # Comment this to use full-parameter training
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
    # Add Optimizer for lora `default`
    model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
    # Add LRScheduler for lora `default`
    model.set_lr_scheduler(
        scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
    logger.info(get_device_placement())
    # Print the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')
    for step, batch in enumerate(dataloader):
        # Do forward and backward
        model.forward_backward(inputs=batch)
        # Step
        model.clip_grad_and_step()
        if step % 20 == 0:
            # Print metric
            metric = model.calculate_metric(is_training=True)
            logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
    model.save(f'last-checkpoint')


if __name__ == '__main__':
    train()

在这个训练代码中,我们构造了一个数据集并拉起了Qwen/Qwen3.5-4B模型,使用all-linear方式加载了lora,并完成了一次训练。在日志中,可以看到loss逐步收敛的过程。

torchrun

Twinkle 支持以 torchrun 模式运行训练。在这种场景下,不需要安装 Ray 相关的依赖。

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

# Construct a device_mesh, fsdp=4, dp=2
device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
# use torchrun mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

logger = get_logger()


def train():
    # 1000 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    # Set template to prepare encoding
    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
    # Preprocess the dataset to standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode dataset
    dataset.encode()
    # Global batch size = 8, for GPUs, so 1 sample per GPU
    dataloader = DataLoader(dataset=dataset, batch_size=8)
    # Use a TransformersModel
    model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')

    lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')

    # Add a lora to model, with name `default`
    # Comment this to use full-parameter training
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
    # Add Optimizer for lora `default`
    model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
    # Add LRScheduler for lora `default`
    model.set_lr_scheduler(
        scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
    logger.info(get_device_placement())
    # Print the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')
    for step, batch in enumerate(dataloader):
        # Do forward and backward
        model.forward_backward(inputs=batch)
        # Step
        model.clip_grad_and_step()
        if step % 20 == 0:
            # Print metric
            metric = model.calculate_metric(is_training=True)
            logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
    model.save(f'last-checkpoint')


if __name__ == '__main__':
    train()

上面的代码中,构造了fsdp2和dp的hybrid并行模式,并使用了八张卡进行训练。可以看到它和单卡训练的代码基本相同,只是使用了DeviceMesh来声明模型布局。

运行时,需要这样拉起训练:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py

断点续训

上面的训练循环可以扩展为支持断点续训。完整示例可直接参考 cookbook/transformers/fsdp2.py

保存检查点

model.save(
    checkpoint_name,
    output_dir='./output/fsdp2',
    adapter_name=ADAPTER_NAME,
    save_optimizer=True,                                    # 保存优化器状态
    consumed_train_samples=dataloader.get_state()['consumed_train_samples'],  # 落盘训练进度
)

DataLoader 内部自动追踪已消费样本数,通过 dataloader.get_state() 获取。

恢复训练

from pathlib import Path

RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint'
RESUME_ONLY_MODEL = False   # True: 仅恢复权重,不恢复优化器/调度器等训练状态
IGNORE_DATA_SKIP = False    # True: 不从 trainer_state.json 跳过已消费数据

if RESUME_FROM_CHECKPOINT:
    checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve())
    progress = model.resume_from_checkpoint(checkpoint_path, resume_only_model=RESUME_ONLY_MODEL)
    if not IGNORE_DATA_SKIP:
        dataloader.resume_from_checkpoint(progress['consumed_train_samples'])

两个开关的组合效果:

RESUME_ONLY_MODEL IGNORE_DATA_SKIP 效果
False(默认) False(默认) 完整续训:恢复权重 + 优化器 + 调度器 + RNG,并跳过已消费数据
True False 仅恢复权重,但仍跳过已消费数据(适合沿用权重、重新开始优化)
True True 仅恢复权重,从数据集开头重新训练

LoRA / adapter vs 全参训练

上述流程默认以 LoRA 为例。全参训练的恢复仅有一处不同——TransformersModel 初始化时,model_id 需要用 checkpoint 路径替代 base model ID:

# LoRA / adapter:base model 从 hub 加载,checkpoint 仅含 adapter 权重和训练状态
model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')
progress = model.resume_from_checkpoint(resume_path)

# 全参:模型权重已整体保存到 checkpoint,直接将其作为 model_id
model = TransformersModel(model_id=resume_path)
progress = model.resume_from_checkpoint(resume_path)

二者后续的 resume_from_checkpointdataloader.resume_from_checkpoint 调用完全一致。

Ray训练

Ray是多机模型训练和推理场景中常用的调度中间件框架。它针对多模型、多设备的执行和资源管理进行了额外优化, 并支持对接kubernetes系统进行生产化。这样的特性使得它尤其适用于RL、GKD等复杂训练场景中。

Twinkle 支持使用 Ray 进行训练和采样,并且它的代码和上面的训练 API 几乎一致:

import os
from typing import List, Tuple, Dict, Any
from peft import LoraConfig
import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model.megatron import MegatronModel
from twinkle.metric import CompletionRewardMetric
from twinkle.preprocessor.llm import GSM8KProcessor
from twinkle.processor import InputProcessor
from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward
from twinkle.sampler import vLLMSampler
from twinkle.template import Template

MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16)) # global completion-level mini-batch-size
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'

def create_gsm8k_dataset():
    dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048)
    dataset.map(GSM8KProcessor())
    dataset.encode(add_generation_prompt=True)
    return dataset

def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    accuracy_reward_fn = GSM8KAccuracyReward()
    format_reward_fn = GSM8KFormatReward()
    accuracy_rewards = accuracy_reward_fn(trajectories)
    format_rewards = format_reward_fn(trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards

def main():
    # set sampler and model separate to use different gpus
    device_groups = [
        DeviceGroup(name='model',ranks=list(range(MODEL_GPUS)),device_type='GPU'),
        DeviceGroup(name='sampler',ranks=list(range(MODEL_GPUS, NUM_GPUS)),device_type='GPU'),
    ]
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
    model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16')
    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
    model.set_optimizer('default', lr=LEARNING_RATE)
    model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
    model.set_loss('GRPOLoss', epsilon=0.2)
    model.set_processor(InputProcessor)
    model.set_template('Qwen3_5Template', model_id=MODEL_ID)

    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'gpu_memory_utilization': 0.8,
            'max_model_len': 4096,
            'max_lora_rank': 32, # save as lora_config
            'enable_lora': True,
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID)
    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
    dataloader = DataLoader(
        dataset=create_gsm8k_dataset,
        batch_size=BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
        min_batch_size=BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
        device_mesh=model_mesh,
        remote_group='model',
    )
    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()
    sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1)
    optim_step = 0
    print(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break
        metrics.reset()
        global_prompts = batch if isinstance(batch, list) else [batch]
        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()
        sample_responses = sampler.sample(
            global_prompts*NUM_GENERATIONS,
            sampling_params,
        )
        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []

        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))
        total_rewards, format_rewards, accuracy_rewards = compute_rewards(
            all_input_data
        )
        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                'format': format_rewards,
                'accuracy': accuracy_rewards,
            },
        )
        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
        # Split completions into mini-batches and run one optim step per mini-batch.
        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
            )
            model.clip_grad_and_step()
            optim_step += 1

            if optim_step >= MAX_STEPS:
                break
            log_dict = metrics.calculate()
            log_dict.update(model.calculate_metric(is_training=True))
            metrics.reset()
            print(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

    print(f'Training completed. optim_steps={optim_step}')
    model.save('grpo-gsm8k-checkpoint')

if __name__ == '__main__':
    main()

在上面的代码中,我们给出了一个RL的训练代码。我们可以在代码中清晰看到数据如何构造、sampler/model如何声明和传参,以及advantage和loss的构造过程。 这个过程没有任何显式引用 ray 的地方。我们仅在初始化时声明了 ray 模式:

twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

开发者可以定制模型等组件的构造和调用方式,所有 Transformers、Megatron 的模型参数都可以在构造模型时传入。

后面所有的 ray 调用和数据分发,都是隐式进行的。运行这个脚本需要提前安装好 Ray。之后这样运行:

python train.py

远程训练

client-server 训练场景同样支持断点续训。推荐流程是调用 model.resume_from_checkpoint(resume_path) 恢复权重和优化器状态,再调用 dataloader.resume_from_checkpoint(progress['consumed_train_samples']) 跳过已消费数据。详细示例可参考 Twinkle客户端self_cognition.py

Twinkle 的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行 LoRA 训练,这样可以极大减小服务端部署成本。

假设我们使用八卡开启一个服务。首先我们需要启动ray集群:

CUDA_VISIBLE_DEVICES=0,1 ray start --head --port=6379 --num-gpus=2
CUDA_VISIBLE_DEVICES=2,3 ray start --address=127.0.0.1:6379 --num-gpus=2
CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0

我们启动了一组包含三个 node 的 Ray 集群:

  • 0、1 两张卡作为一个 node

  • 2、3 两张卡作为一个 node

  • CPU 资源作为一个 node

如果在生产环境使用,可以启动更多 node,并部署更多 replica 以兼容更大的用户量。在这里我们仅以四卡作为例子。

下面,启动server:

cd cookbook/client/twinkle/transformer
python server.py

服务端会启动一个包含 Sampler 集群、模型集群、工具集群的三个服务。

下面可以进行client端训练:

import dotenv
dotenv.load_dotenv('.env')
import re
from twinkle.data_format import Trajectory
from twinkle.reward.base import Reward
import gc
from peft import LoraConfig
from typing import List, Tuple

from twinkle import get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.dataset import DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle_client import init_twinkle_client
from twinkle_client.dataloader import DataLoader
from twinkle_client.dataset import Dataset
from twinkle_client.model import MultiLoraTransformersModel
from twinkle_client.sampler import vLLMSampler

logger = get_logger()

# ========== Configuration ==========
MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 1024
LEARNING_RATE = 1e-5
MAX_STEPS = 10
BATCH_SIZE = 2
TEMPERATURE = 1.0
SYNC_INTERVAL = 1  # Save weights for sampler every N steps
GRADIENT_ACCUMULATION_STEPS = 4


def create_countdown_dataset():
    """Create Countdown Game dataset for GRPO training."""

    dataset = Dataset(dataset_meta=DatasetMeta('ms://zouxuhong/Countdown-Tasks-3to4', data_slice=range(500)))
    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=8192)
    dataset.map('CountdownProcessor')
    dataset.encode(add_generation_prompt=True, batched=True)
    return dataset


class CountDownAccuracy(Reward):

    @staticmethod
    def countdown_accuracy_reward(completion: str, target: int, nums: List[int]) -> float:
        """Accuracy reward: checks if equation is correct."""
        try:
            match = re.search(r'<answer>(.*?)<\/answer>', completion)
            if match is None:
                return 0.0
            equation = match.group(1).strip()
            if '=' in equation:
                equation = equation.split('=')[0]
            used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
            if sorted(used_numbers) != sorted(nums):
                return 0.0
            if not re.match(r'^[\d+\-*/().\s]+$', equation):
                return 0.0
            result = eval(equation, {'__builtins__': None}, {})
            return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0
        except Exception:  # noqa
            return 0.0

    def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
        rewards = []
        for trajectory in trajectories:
            messages = trajectory.get('messages', [])
            completion = ''
            for msg in reversed(messages):
                if msg.get('role') == 'assistant':
                    completion = msg.get('content', '')
                    break
            user_data = trajectory.get('user_data', [{}])
            data = user_data[0] if isinstance(user_data, list) and user_data else {}
            target = data.get('target', 0)
            nums = data.get('nums', [])
            acc_reward = self.countdown_accuracy_reward(completion, target, nums)
            rewards.append(acc_reward)
        return rewards


def compute_rewards(trajectories: List[dict], ) -> Tuple[List[float], List[float], List[float]]:
    """Compute format and accuracy rewards for Countdown game."""
    from twinkle.reward import FormatReward
    format_rewards = FormatReward()(trajectories, [])
    accuracy_rewards = CountDownAccuracy()(trajectories, [])
    total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards


def train():
    # Step 1: Initialize the Twinkle client
    client = init_twinkle_client(
        base_url='http://localhost:8000',
        api_key='',
    )

    # Step 2: Prepare dataset and dataloader
    dataset = create_countdown_dataset()
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

    # Step 3: Configure the training model
    model = MultiLoraTransformersModel(model_id=MODEL_ID)

    lora_config = LoraConfig(
        target_modules='all-linear',
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
    )
    model.add_adapter_to_model(
        'default',
        lora_config,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    )

    # Set GRPO loss (the key difference from SFT training)
    model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0)

    # Set optimizer and LR scheduler
    model.set_optimizer('AdamW', lr=LEARNING_RATE)
    model.set_lr_scheduler(
        'CosineWarmupScheduler',
        num_warmup_steps=500,
        num_training_steps=MAX_STEPS,
    )

    # Set processor and template for encoding inputs
    model.set_processor('InputProcessor')
    model.set_template('Qwen3_5Template', model_id=MODEL_ID)

    # Step 4: Configure the sampler
    sampler = vLLMSampler(model_id=MODEL_ID)
    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID)

    # Step 5: Setup metrics and advantage function
    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()

    sampling_params = {
        'max_tokens': MAX_NEW_TOKENS,
        'temperature': TEMPERATURE,
        'top_p': 0.95,
    }

    # Track the current adapter path for sampling
    current_adapter_uri = None

    step = 0
    for batch in dataloader:
        if step >= MAX_STEPS:
            break

        metrics.reset()
        prompts = batch if isinstance(batch, list) else [batch]

        # ========== 1. Save weights and update adapter_uri ==========
        # Instead of sync_weights, save the model checkpoint and pass
        # the resulting path to the sampler as adapter_uri
        if step % SYNC_INTERVAL == 0:
            logger.info(f'Step {step}: Saving weights for sampler...')
            twinkle_path = model.save(
                name=f'grpo-sampler-step-{step}',
                save_optimizer=False,
            )
            current_adapter_uri = twinkle_path
            logger.info(f'Step {step}: Saved weights to {current_adapter_uri}')

        # ========== 2. Sample completions ==========
        sample_response = sampler.sample(
            inputs=prompts,
            sampling_params=sampling_params,
            adapter_uri=current_adapter_uri,
            num_samples=NUM_GENERATIONS,
        )

        input_features = []
        old_logps_list = []
        completion_lengths = []

        sequences = sample_response.get('sequences', [])
        for seq in sequences:
            input_features.append(seq.get('new_input_feature', seq))
            old_logps_list.append(seq.get('logprobs', []))
            completion_lengths.append(len(seq.get('tokens', [])))

        if not input_features:
            logger.warning(f'Step {step}: No valid samples, skipping')
            step += 1
            continue

        # ========== 3. Compute rewards ==========
        total_rewards, format_rewards, accuracy_rewards = compute_rewards(input_features)
        metrics.accumulate(
            None,
            None,
            completion_lengths=completion_lengths,
            rewards={
                'total': total_rewards,
                'format': format_rewards,
                'accuracy': accuracy_rewards,
            })

        # ========== 4. Compute advantages ==========
        advantages = advantage_fn(
            total_rewards,
            num_generations=NUM_GENERATIONS,
            scale='group',
        ).tolist()

        frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0)
        if frac_zero_std == 1.0:
            logger.info(f'Step {step}: All advantages are zero, skipping training')
            step += 1
            continue

        # ========== 5. Training step (GRPO) ==========
        # forward_backward with GRPO loss: passes advantages and old_logps
        # to the server-side GRPOLoss for proper policy optimization
        model.forward_backward(
            inputs=input_features,
            advantages=advantages,
            old_logps=old_logps_list,
        )

        # Gradient clipping and optimizer step
        model.clip_grad_norm(1.0)
        model.step()
        model.zero_grad()
        model.lr_step()

        gc.collect()

        # ========== 6. Log ==========
        log_dict = metrics.calculate()
        log_dict.update(model.calculate_metric())
        log_dict['train/frac_reward_zero_std'] = frac_zero_std
        logger.info(f'Step {step}: {log_dict}')
        step += 1

    # Save final checkpoint
    twinkle_path = model.save(name='grpo-countdown-final', save_optimizer=True)
    logger.info(f'Saved final checkpoint: {twinkle_path}')


if __name__ == '__main__':
    train()

多个开发者可以并行使用这个服务的单个基模进行并行训练和采样。并且,他们进行的训练方式允许不同。例如,A 用户可以进行 SFT,B 用户可以进行 RL,C 用户可以进行采样。同样,Twinkle 也支持 Tinker-like API 进行远端训练:

from tinker import types
from tqdm import tqdm
from tinker import ServiceClient
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.preprocessor import SelfCognitionProcessor
from twinkle.server.common import input_feature_to_datum

# The base model to fine-tune / evaluate
base_model = 'Qwen/Qwen3.5-4B'


def train():
    # Step 1: Prepare the dataset

    # Load the self-cognition dataset from ModelScope (first 500 examples)
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))

    # Apply the chat template matching the base model (max 256 tokens per sample)
    dataset.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=256)

    # Replace placeholder names with custom model/author identity
    dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'), load_from_cache_file=False)

    # Tokenize and encode the dataset into model-ready input features
    dataset.encode(batched=True, load_from_cache_file=False)

    # Wrap the dataset into a DataLoader that yields batches of size 8
    dataloader = DataLoader(dataset=dataset, batch_size=8)

    # Step 2: Initialize the training client
    # Connect to the Twinkle server running locally
    service_client = ServiceClient(base_url='http://localhost:8000', api_key='your-api-key')
    # Create a LoRA training client for the base model (rank=16 for the LoRA adapter)
    training_client = service_client.create_lora_training_client(base_model=base_model, rank=16)

    # Step 3: Run the training loop
    for epoch in range(3):
        print(f'Epoch {epoch}')
        for step, batch in tqdm(enumerate(dataloader)):
            # Convert each InputFeature into a Datum for the Tinker API
            input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]

            # Send data to server: forward + backward pass (computes gradients)
            fwdbwd_future = training_client.forward_backward(input_datum, 'cross_entropy')

            # Optimizer step: update model weights with Adam
            optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))

            # Wait for both operations to complete
            fwdbwd_future.result()
            optim_result = optim_future.result()
            print(f'Training Metrics: {optim_result}')

        # Save a checkpoint after each epoch
        save_future = training_client.save_state(f'twinkle-lora-{epoch}')
        save_result = save_future.result()
        print(f'Saved checkpoint to {save_result.path}')


if __name__ == '__main__':
    train()

使用魔搭社区提供的TaaS化训练服务

在 Twinkle 框架开源的同时,我们依托ModelScope的后台服务,也提供了托管的模型训练服务(Training as a Service),开发者可以通过这一服务, 免费体验Twinkle的训练API。 该服务和上面叙述的Tinker API部分代码是相同的,唯一不同的是Endpoint和Token需要使用魔搭官方的对应信息。关于如何使用官方服务,请查看训练服务的详细描述。

Twinkle提供了采样 API,该 API 可以用于更灵活地控制采样方式以验证结果,或者参与到 RL 算法的采样流程中。

完整的训练模式示例请参考 cookbook 目录。

使用 Hugging Face 的模型

要从 Hugging Face 加载模型而不是 ModelScope,只需切换前缀即可:

ms://Qwen/Qwen3.5-4B -> hf://Qwen/Qwen3.5-4B

所有接受 model_id 参数的组件都支持这种基于前缀的路由。

🛠️ Twinkle✨ 模块化生态系统

Dataset
数据加载和预处理

Template
编码和解码

DataLoader
数据分发和批处理

Preprocessor
数据 ETL

InputProcessor
任务特定的输入处理

Model
大模型,支持多种框架

Sampler
采样逻辑

Loss
损失函数

Metric
训练指标收集

Reward
奖励函数

Advantage
优势函数

CheckpointEngine
权重同步

Patch
模型修复补丁

Module
组件,如优化器

Kernel
算子

Server
启动后端集群

Client
客户端代码

Infra
隔离 ray 和 torchrun 的差异

Plugin
使用 hub 组件

Hub
与 HF/MS 库对接

Twinkle 的可定制组件

在 Twinkle 的设计中,torchrun、Ray、HTTP 的训练使用同样的 API,并共享相同的组件和输入输出结构。因此,其很多组件可以由开发者自定义来实现新的算法开发。

下面我们列出推荐定制的组件列表:

组件名称 基类 说明
损失 twinkle.loss.Loss 用于定义模型训练的损失函数
指标 twinkle.metric.Metric 用于定义模型训练的评价体系
Optimizer/LRScheduler 基于PyTorch 用于定义模型训练的优化器和LR衰减器
补丁 twinkle.patch.Patch 用于修复模型训练过程中的问题
预处理器 twinkle.preprocessor.Preprocessor 用于对数据进行预处理(ETL),并返回 Template 可用的标准格式
过滤器 twinkle.preprocessor.Filter 用于对原始数据进行合理性过滤
任务数据处理器 twinkle.processor.InputProcessor 用于将模型输入转换为各任务需要的数据,并添加额外字段
模型 twinkle.model.TwinkleModel 大模型本身
采样器 twinkle.sampler.Sampler 采样器,例如 vLLM
奖励 twinkle.reward.Reward 用于实现不同 RL 训练的奖励
优势 twinkle.advantage.Advantage 用于实现不同 RL 训练的优势估计
模板 twinkle.template.Template 用于处理标准输入,并转换成模型需要的 token
权重同步 twinkle.checkpoint_engine.CheckpointEngine 用于 RL 训练中的权重同步

未在上表中列出的组件,如Dataset、DataLoader等也可以实现定制,只需要跟随基类API设计即可。

DeviceGroup 和 DeviceMesh

DeviceGroup 和 DeviceMesh 是 Twinkle 架构的核心。所有的代码构建均基于这两个设计。

import twinkle
from twinkle import DeviceMesh, DeviceGroup
device_group = [
        DeviceGroup(
            name='default',
            ranks=8,
            device_type='cuda',
        )
    ]

device_mesh = DeviceMesh.from_sizes(pp_size=2, tp_size=2, dp_size=2)
twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_group)

当 device_group 定义完成后,需要使用 twinkle.initialize 来初始化资源。

DeviceGroup:定义本次训练需要多少个资源组。定义后,组件可以通过选择资源组的方式将自己运行在远端:

from twinkle.model import TransformersModel
model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', remote_group='default', device_mesh=device_mesh)
# 或者
from twinkle.model import MegatronModel
model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B', remote_group='default', device_mesh=device_mesh)

DeviceMesh 指定了模型等组件在资源组中的拓扑结构。可以理解为如何进行并行。这会影响一系列的框架决策,例如数据获取、数据消费、数据返回等。

使用样例

from peft import LoraConfig
import twinkle
from twinkle import DeviceMesh, DeviceGroup
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

device_group = [DeviceGroup(name='default',ranks=8,device_type='cuda')]
device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
# local for torchrun
twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh)


def train():
    # 1000 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    # Set template to prepare encoding
    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
    # Preprocess the dataset to standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode dataset
    dataset.encode()
    # Global batch size = 8, for GPUs, so 1 sample per GPU
    dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8)
    # Use a TransformersModel
    model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', remote_group='default')

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules='all-linear'
    )

    # Add a lora to model, with name `default`
    # Comment this to use full-parameter training
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
    # Add Optimizer for lora `default`
    model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
    # Add LRScheduler for lora `default`
    model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5,
                           num_training_steps=len(dataloader))
    for step, batch in enumerate(dataloader):
        # Do forward and backward
        model.forward_backward(inputs=batch)
        # Step
        model.clip_grad_and_step()
        if step % 20 == 0:
            # Print metric
            metric = model.calculate_metric(is_training=True)
            print(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
    model.save(f'last-checkpoint')


if __name__ == '__main__':
    train()

这样启动训练:

python3 train.py