Embedding Training
Twinkle supports contrastive embedding model training with InfoNCE loss, in-batch negatives, and cross-rank gathering. This guide demonstrates how to train embedding models using Twinkle.
Overview
Embedding training in Twinkle uses the following core components:
| Component | Role |
|---|---|
| `InfonceLoss` | Contrastive loss with in-batch negatives |
| `EmbeddingMetric` | Tracks pos/neg similarity and loss |
| `TransformersModel` | Trainable embedding model (LoRA or full-parameter training) |
| `InputProcessor` | Processes anchor/positive pairs into features |
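To make the role of the contrastive loss concrete, here is a minimal pure-Python sketch of InfoNCE with in-batch negatives. It is an illustration of the general technique, not Twinkle's `InfonceLoss` implementation; the helper names `normalize` and `info_nce` are hypothetical. For each anchor, its own positive is the target and the positives of the other samples in the batch act as negatives.

```python
import math

def normalize(v):
    """Scale a vector to unit L2 norm so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def info_nce(anchors, positives, temperature=0.07):
    """Mean InfoNCE loss; positives[j] with j != i serve as negatives for anchors[i]."""
    a = [normalize(v) for v in anchors]
    p = [normalize(v) for v in positives]
    total = 0.0
    for i, anchor in enumerate(a):
        # Temperature-scaled cosine similarity against every positive in the batch.
        logits = [sum(x * y for x, y in zip(anchor, cand)) / temperature for cand in p]
        # Numerically stable log-sum-exp over all candidates.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        # Cross-entropy with the matching positive (index i) as the target class.
        total += log_z - logits[i]
    return total / len(a)

anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = info_nce(anchors, [[1.0, 0.1], [0.1, 1.0]])   # positives match their anchors
swapped = info_nce(anchors, [[0.1, 1.0], [1.0, 0.1]])   # positives are mismatched
print(f"aligned={aligned:.6f} swapped={swapped:.6f}")
```

The aligned batch yields a much smaller loss than the swapped one, which is the signal that pulls anchors toward their positives and away from the rest of the batch.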
Data Format
Each training sample is a group: an anchor followed by its positive (and, optionally, negatives). In the embedding feature tensor, groups are laid out back to back:
embeddings: [anchor_0, positive_0, anchor_1, positive_1, ...]
labels: [ 1, 0, 1, 0, ...]
- `labels=1` marks the start of a new group (anchor)
- `labels=0` marks positives/negatives within the group
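The label convention above can be decoded with a few lines of Python. The following sketch uses a hypothetical helper, `split_groups`, that is not part of Twinkle; it only shows how the 0/1 labels partition a flat list of embeddings into per-anchor groups.

```python
def split_groups(labels):
    """Return (start, end) index ranges, one per group, from 0/1 group-start labels."""
    starts = [i for i, flag in enumerate(labels) if flag == 1]
    ends = starts[1:] + [len(labels)]
    return list(zip(starts, ends))

# Three groups: the second one carries one extra candidate after its anchor.
labels = [1, 0, 1, 0, 0, 1, 0]
groups = split_groups(labels)

for start, end in groups:
    anchor_idx = start                          # the row flagged with labels=1
    candidate_idx = list(range(start + 1, end)) # the rows flagged with labels=0
    print(f"anchor={anchor_idx} candidates={candidate_idx}")
```

Each range starts at an anchor row and extends up to (but not including) the next anchor, so groups may have different sizes.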
Basic Embedding Training
A minimal embedding training script with DDP:
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_logger
from twinkle.dataloader import DataLoader
from twinkle.loss import InfonceLoss
from twinkle.metric import EmbeddingMetric
from twinkle.model import TransformersModel
from twinkle.processor import InputProcessor
from twinkle.template import Qwen3_5Template
logger = get_logger()
# --- Configuration ---
MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
MODEL_GPUS = 4
BATCH_SIZE = 32
LEARNING_RATE = 1e-5
TEMPERATURE = 0.07
EMB_MAX_LENGTH = 8192
# --- Initialize ---
device_groups = [
    DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
]
model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)
# --- Model ---
model = TransformersModel(
    model_id=MODEL_ID,
    device_mesh=model_mesh,
    remote_group='model',
    ddp_config={'find_unused_parameters': True},
)
model.set_processor(InputProcessor)
model.set_loss(InfonceLoss, temperature=TEMPERATURE, use_batch=True)
model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE)
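# NOTE: total_steps (the planned number of optimizer steps) is not defined in
# this snippet and must be set before this call.
model.set_lr_scheduler(
    scheduler_cls='CosineWarmupScheduler',
    num_warmup_steps=200,
    num_training_steps=total_steps,
)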
model.add_metric(EmbeddingMetric, is_training=True)
# --- Template ---
template = Qwen3_5Template(
    model_id=MODEL_ID,
    max_length=EMB_MAX_LENGTH,
    enable_thinking=False,
)
# --- Training Loop ---
# NOTE: dataloader is assumed to have been built beforehand (construction not shown).
for step, batch in enumerate(dataloader):
    # batch: list of features with anchor/positive pairs
    model.forward_backward(inputs=batch, task='embedding')
    model.clip_grad_and_step(gradient_accumulation_steps=1)
    if step % 10 == 0:
        metric = model.calculate_metric(is_training=True)
        logger.info(f'Step {step}: {metric}')
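The script above configures a scheduler with `num_warmup_steps` and `num_training_steps`. As a rough illustration of what such a schedule usually looks like, the sketch below implements the common linear-warmup plus cosine-decay multiplier. This is an assumption about the general shape only; Twinkle's `CosineWarmupScheduler` may differ in detail, and `cosine_warmup_factor` and the 2000-step horizon are hypothetical.

```python
import math

def cosine_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear ramp from 0 to 1, then cosine decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

base_lr = 1e-5  # matches LEARNING_RATE in the script above
for step in (0, 100, 200, 1100, 2000):
    lr = base_lr * cosine_warmup_factor(step, num_warmup_steps=200, num_training_steps=2000)
    print(f"step={step} lr={lr:.2e}")
```

Under this shape, the learning rate peaks exactly when warmup ends, so `num_training_steps` should reflect the real number of optimizer steps; otherwise the decay finishes too early or never completes.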
Key Parameters
| Parameter | Recommended | Description |
|---|---|---|
| `temperature` | 0.05–0.1 | Lower = sharper contrast. 0.07 keeps gradients flowing until cosine > 0.75 |
| `use_batch` | True | Enables cross-sample in-batch negatives for better efficiency |
| `hard_negatives` | None or 7 | Fix negative count per sample; None uses all in-batch |
| `find_unused_parameters` | True | Required for embedding models (only last hidden state contributes gradients) |
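To build intuition for the `temperature` row, the sketch below computes the softmax probability assigned to the positive under a simplified setting: one positive and a fixed number of negatives that all share the same similarity. For a cross-entropy loss, the gradient with respect to the positive logit is `p_pos - 1`, so `1 - p_pos` indicates how much training signal remains. The function name and the example similarities are hypothetical, chosen only for illustration.

```python
import math

def positive_probability(pos_sim, neg_sim, num_negatives, temperature):
    """Softmax probability of the positive against equal-similarity negatives."""
    pos = math.exp(pos_sim / temperature)
    neg = num_negatives * math.exp(neg_sim / temperature)
    return pos / (pos + neg)

# 31 negatives corresponds to one anchor in a batch of 32 with in-batch negatives.
for t in (0.05, 0.07, 0.1):
    p = positive_probability(pos_sim=0.6, neg_sim=0.3, num_negatives=31, temperature=t)
    print(f"temperature={t}: p_pos={p:.3f}, remaining signal={1 - p:.3f}")
```

For the same similarity gap, a lower temperature pushes `p_pos` closer to 1, which is what "sharper contrast" means in practice.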
Monitoring
The EmbeddingMetric reports key training signals:
| Metric | What it means |
|---|---|
| `pos_sim` | Average anchor-positive cosine similarity (target: > 0.8) |
| `neg_sim` | Average anchor-negative similarity (target: < 0.3) |
| `loss` | InfoNCE loss value |
| `grad_norm` | Gradient magnitude |
Healthy training shows pos_sim rising and neg_sim stable or falling. If pos_sim saturates near 1.0, lower the temperature.
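The targets above can be turned into a simple automated check. The sketch below is a hypothetical helper, not part of Twinkle: it takes the two similarity values as plain floats (how they are extracted from the metric output is left open), and the 0.98 saturation threshold is an assumed stand-in for "near 1.0".

```python
def check_health(pos_sim, neg_sim, pos_target=0.8, neg_target=0.3, saturation=0.98):
    """Return a list of warnings; an empty list means no issue was flagged."""
    warnings = []
    if pos_sim >= saturation:
        # Assumed threshold for "saturates near 1.0".
        warnings.append("pos_sim is saturating; consider lowering the temperature")
    elif pos_sim < pos_target:
        warnings.append(f"pos_sim {pos_sim:.2f} is below the {pos_target} target")
    if neg_sim > neg_target:
        warnings.append(f"neg_sim {neg_sim:.2f} is above the {neg_target} target")
    return warnings

print(check_health(pos_sim=0.62, neg_sim=0.35))  # early in training: both flagged
print(check_health(pos_sim=0.99, neg_sim=0.10))  # saturation flagged
```

Early in training both warnings are expected; they become meaningful only if they persist after the warmup phase.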