序列并行 (SP)

序列并行沿序列维度将长序列分割到多个 GPU 上,使训练能处理超出单卡显存的序列长度。Twinkle 实现了 Ulysses 风格的序列并行,并可选地支持派生环形注意力。

概览

概念 说明
SequenceParallelConfig SP 配置数据类
SequenceParallelStrategy 封装 SP 生命周期的策略类
SequenceParallel 核心实现,处理填充/分割/聚合

配置

from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelConfig

config = SequenceParallelConfig(
    enabled=True,           # 启用序列并行
    ulysses_size=None,      # Ulysses SP 并行度(若为 None 则从 DeviceMesh 自动推导)
    gather_logits=True,     # 前向后聚合 logits 用于损失计算
)

配合 DeviceMesh 使用

DeviceMesh.from_sizes() 中设置 ulysses_size 即可激活 SP:

from twinkle.utils import DeviceMesh

# 8 卡:4 路 Ulysses SP × 2 路数据并行
device_mesh = DeviceMesh.from_sizes(
    world_size=8,
    dp_size=2,
    ulysses_size=4,
)

工作原理

  1. 填充 — 输入序列被填充到可被 SP 并行度整除的长度

  2. 分割 — 填充后的输入沿序列维度均匀分配到各 SP rank

  3. 分布式注意力 — FlashAttention2 被 patch 为在注意力计算前后执行 Ulysses all-to-all 通信

  4. 聚合 — 前向传播后,logits 被聚合回完整序列长度用于损失计算

支持的注意力后端

后端 状态
FlashAttention2 完全支持(包括打包/padding-free 序列)
SDPA 支持(仅非打包批次)
派生环形注意力 仅支持 FlashAttention2(rp_world_size > 1

Qwen3.5 线性注意力

SP 自动检测 Qwen3.5 GatedDeltaNet 线性注意力层,并应用 Qwen3_5GatedDeltaNetUlyssesPatch,确保混合注意力架构下序列并行的正确性。

MoE 辅助损失

对于 MoE 模型,SP 自动安装前向 hook,在计算辅助损失前跨 SP rank 聚合路由 logits,确保负载均衡信号的正确性。

关键约束

  • num_key_value_heads 必须能被 ulysses_size 整除(Ulysses 模式),否则回退到环形注意力

  • 打包/padding-free 批次需要 FlashAttention2

  • 派生环形注意力要求 batch_size == 1(打包格式)

  • torch.distributed 必须已初始化