Sequence Parallel (SP)
Sequence Parallel splits long sequences across multiple GPUs along the sequence dimension, enabling training with sequence lengths that exceed single-GPU memory. Twinkle implements Ulysses-style sequence parallel with optional derived ring attention.
Overview
| Concept | Description |
|---|---|
| SequenceParallelConfig | Configuration dataclass for SP |
| SequenceParallelStrategy | Strategy class that wraps SP lifecycle |
| SequenceParallel | Core implementation handling pad/split/gather |
Configuration
from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelConfig

config = SequenceParallelConfig(
    enabled=True,        # Enable sequence parallel
    ulysses_size=None,   # Ulysses SP degree (auto-derived from DeviceMesh if None)
    gather_logits=True,  # Gather logits after forward for loss computation
)
Usage with DeviceMesh
SP is activated by setting ulysses_size in DeviceMesh.from_sizes():
from twinkle.utils import DeviceMesh

# 8 GPUs: 4-way Ulysses SP × 2-way data parallel
device_mesh = DeviceMesh.from_sizes(
    world_size=8,
    dp_size=2,
    ulysses_size=4,
)
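To make the 8 = 2 × 4 arithmetic concrete, the sketch below maps each global rank to a (data-parallel index, SP index) pair. The helper `mesh_coords` is hypothetical and not part of Twinkle. It assumes SP is the innermost (fastest-varying) mesh dimension, which may differ from the layout `DeviceMesh` actually produces.

```python
def mesh_coords(rank: int, dp_size: int, ulysses_size: int) -> tuple:
    """Return (dp_index, sp_index) for a global rank.

    Assumption: ranks are laid out row-major with SP as the innermost
    dimension. This is an illustration, not Twinkle's actual layout.
    """
    world_size = dp_size * ulysses_size
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside world of size {world_size}")
    return divmod(rank, ulysses_size)


dp_size, ulysses_size = 2, 4

# Group ranks that would share one sequence (same dp_index).
sp_groups = {}
for rank in range(dp_size * ulysses_size):
    dp_idx, _sp_idx = mesh_coords(rank, dp_size, ulysses_size)
    sp_groups.setdefault(dp_idx, []).append(rank)

print(sp_groups)  # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
```

Under this assumed layout, ranks 0 to 3 cooperate on one sequence and ranks 4 to 7 on another, while the two groups process different batches.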
How It Works
Pad — input sequences are padded to a length divisible by SP world size
Split — padded inputs are evenly split across SP ranks along the sequence dimension
Distributed Attention — FlashAttention2 is patched to perform Ulysses all-to-all communication before/after attention computation
Gather — after forward, logits are gathered back to full sequence length for loss computation
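The four steps above can be sketched in plain Python, using lists in place of tensors and a loop in place of real collective communication. All function names here are hypothetical and only illustrate the data movement. They are not Twinkle's API, and the real implementation operates on tensors through `torch.distributed`.

```python
def pad_to_multiple(tokens, sp_size, pad_id=0):
    """Pad: extend the sequence so its length is divisible by sp_size."""
    pad_len = (sp_size - len(tokens) % sp_size) % sp_size
    return tokens + [pad_id] * pad_len, pad_len


def split_for_ranks(tokens, sp_size):
    """Split: give each SP rank one contiguous, equally sized chunk."""
    chunk = len(tokens) // sp_size
    return [tokens[r * chunk:(r + 1) * chunk] for r in range(sp_size)]


def ulysses_all_to_all(per_rank):
    """Simulate the all-to-all that runs before attention.

    Input:  per_rank[r][h] is rank r's sequence shard for head h
            (every rank holds all heads, but only part of the sequence).
    Output: out[r][j] is the full sequence for one of rank r's heads
            (every rank holds the whole sequence, but only some heads).
    """
    sp_size = len(per_rank)
    num_heads = len(per_rank[0])
    if num_heads % sp_size != 0:
        raise ValueError("number of heads must be divisible by SP size")
    heads_per_rank = num_heads // sp_size
    out = []
    for dst in range(sp_size):
        first = dst * heads_per_rank
        out.append([
            [x for src in range(sp_size) for x in per_rank[src][h]]
            for h in range(first, first + heads_per_rank)
        ])
    return out


def gather_from_ranks(shards, pad_len):
    """Gather: concatenate the shards and drop the padding again."""
    full = [t for shard in shards for t in shard]
    return full[:len(full) - pad_len]


tokens = list(range(1, 11))                      # 10 tokens, SP size 4
padded, pad_len = pad_to_multiple(tokens, 4)     # length 12, pad_len 2
shards = split_for_ranks(padded, 4)              # 4 shards of 3 tokens
restored = gather_from_ranks(shards, pad_len)    # original 10 tokens

# 2 ranks, 4 heads, 4 positions: each entry is a (head, position) tag.
per_rank = [[[(h, p) for p in positions] for h in range(4)]
            for positions in ([0, 1], [2, 3])]
resharded = ulysses_all_to_all(per_rank)
```

After `ulysses_all_to_all`, rank 0 holds heads 0 and 1 over all four positions and rank 1 holds heads 2 and 3, so each rank can run ordinary attention on its heads. A second all-to-all after attention would restore the sequence-sharded layout.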
Supported Attention Backends
| Backend | Status |
|---|---|
| FlashAttention2 | Fully supported (including packed/padding-free sequences) |
| SDPA | Supported (non-packed batches only) |
| Derived Ring Attention | Supported with FlashAttention2 only (rp_world_size > 1) |
Qwen3.5 Linear Attention
SP automatically detects Qwen3.5 GatedDeltaNet linear attention layers and applies the Qwen3_5GatedDeltaNetUlyssesPatch for correct sequence-parallel behavior on hybrid attention architectures.
MoE Auxiliary Loss
For MoE models, SP automatically installs a forward hook that gathers router logits across SP ranks before auxiliary loss computation, ensuring correct load-balancing signals.
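A small numeric sketch shows why the gather matters. It assumes a Switch-Transformer-style load-balancing loss (number of experts times the sum over experts of routed-token fraction times mean router probability). The loss a given model actually uses may differ. For brevity the sketch takes router probabilities rather than logits. The function name is hypothetical.

```python
def load_balancing_loss(router_probs):
    """Switch-style auxiliary loss over a list of per-token probability rows.

    Assumed form: num_experts * sum_i(f_i * P_i), where f_i is the fraction
    of tokens whose top-1 expert is i and P_i is the mean probability of i.
    """
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])
    frac = [0.0] * num_experts
    mean_prob = [0.0] * num_experts
    for probs in router_probs:
        top = max(range(num_experts), key=probs.__getitem__)
        frac[top] += 1.0 / num_tokens
        for i, p in enumerate(probs):
            mean_prob[i] += p / num_tokens
    return num_experts * sum(f * p for f, p in zip(frac, mean_prob))


# Two SP ranks, each holding half of one sequence (2 tokens, 2 experts).
rank0 = [[0.9, 0.1], [0.8, 0.2]]   # this shard routes only to expert 0
rank1 = [[0.1, 0.9], [0.2, 0.8]]   # this shard routes only to expert 1

local_losses = [load_balancing_loss(rank0), load_balancing_loss(rank1)]
gathered_loss = load_balancing_loss(rank0 + rank1)

print(local_losses)    # both close to 1.7: each shard looks fully imbalanced
print(gathered_loss)   # close to 1.0: the full sequence is perfectly balanced
```

Because the loss multiplies two per-token averages, averaging the per-shard losses does not reproduce the full-sequence value. Here each shard alone reports heavy imbalance even though routing over the whole sequence is even.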
Key Constraints
num_key_value_heads must be divisible by ulysses_size (for Ulysses), or use the ring attention fallback
Packed/padding-free batches require FlashAttention2
Derived ring attention requires batch_size == 1 (packed format)
torch.distributed must be initialized
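The constraints in this section, together with the backend table, can be read as a pre-flight check. The sketch below is one interpretation, not Twinkle code. The function `check_sp_constraints`, its parameters, and the backend strings `"flash_attention_2"` and `"sdpa"` are assumptions chosen for illustration.

```python
def check_sp_constraints(num_key_value_heads, ulysses_size, attn_backend,
                         packed, rp_world_size=1, batch_size=1,
                         dist_initialized=True):
    """Return a list of human-readable constraint violations (empty if OK).

    Hypothetical helper: names and backend strings are illustrative only.
    """
    problems = []
    uses_ring = rp_world_size > 1

    if not dist_initialized:
        problems.append("torch.distributed must be initialized")

    # Ulysses shards attention heads, so KV heads must divide evenly
    # unless ring attention is available as a fallback.
    if num_key_value_heads % ulysses_size != 0 and not uses_ring:
        problems.append(
            "num_key_value_heads must be divisible by ulysses_size "
            "(or use the ring attention fallback)")

    if packed and attn_backend != "flash_attention_2":
        problems.append("packed/padding-free batches require FlashAttention2")

    if uses_ring:
        if attn_backend != "flash_attention_2":
            problems.append("derived ring attention requires FlashAttention2")
        if batch_size != 1:
            problems.append("derived ring attention requires batch_size == 1")

    return problems


ok = check_sp_constraints(8, 4, "flash_attention_2", packed=True)
bad = check_sp_constraints(6, 4, "sdpa", packed=True)
print(ok)        # []
print(len(bad))  # 2
```

Returning a list rather than raising on the first failure lets a caller report every misconfiguration at once.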