InfoNCE Loss
InfonceLoss implements an InfoNCE contrastive loss with in-batch negatives and optional cross-rank gathering. It is intended for training embedding and retrieval models.
Usage
from twinkle.loss import InfonceLoss

loss_fn = InfonceLoss(
    temperature=0.1,
    use_batch=True,           # Enable in-batch negatives
    hard_negatives=7,         # Fix negative count per sample
    mask_fake_negative=True,  # Mask false negatives
    fake_neg_margin=0.1,      # Margin for false negative detection
)
model.set_loss(loss_fn)
Input Format
Each sample is laid out as anchor(1) + positive(1) + negatives(n) in a flat embedding tensor. inputs['labels'] is a 1-D mask over that tensor: 1 marks the anchor that starts each group and 0 marks every other position. The layout below uses n = 2:
embeddings: [a0, p0, n0_1, n0_2, a1, p1, n1_1, n1_2, ...]
labels: [ 1, 0, 0, 0, 1, 0, 0, 0, ...]
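To make the layout concrete, here is a minimal sketch of how the flat sequence could be cut back into groups using the mask. It uses plain Python lists with strings standing in for embedding vectors, and `split_groups` is a hypothetical helper written for illustration, not part of twinkle.

```python
def split_groups(embeddings, labels):
    """Split a flat list into (anchor, positive, negatives) groups.

    A label of 1 opens a new group. The item right after the anchor is the
    positive, and everything up to the next 1 is treated as a negative.
    """
    if len(embeddings) != len(labels):
        raise ValueError("embeddings and labels must have the same length")
    starts = [i for i, flag in enumerate(labels) if flag == 1]
    groups = []
    for k, start in enumerate(starts):
        # A group ends where the next one starts, or at the end of the list.
        end = starts[k + 1] if k + 1 < len(starts) else len(embeddings)
        if end - start < 2:
            raise ValueError(f"group at index {start} has no positive")
        anchor = embeddings[start]
        positive = embeddings[start + 1]
        negatives = embeddings[start + 2:end]
        groups.append((anchor, positive, negatives))
    return groups


# Strings stand in for embedding vectors.
flat = ["a0", "p0", "n0_1", "n0_2", "a1", "p1", "n1_1", "n1_2"]
mask = [1, 0, 0, 0, 1, 0, 0, 0]
groups = split_groups(flat, mask)
```

Because group boundaries come only from the mask, groups with different numbers of negatives can share one flat tensor.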
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| temperature | float | 0.1 | Logit scaling factor |
| use_batch | bool | True | Use cross-sample in-batch negatives |
| hard_negatives | int | None | Fix per-sample negative count (truncate/upsample) |
| mask_fake_negative | bool | False | Mask logits > positive + margin |
| fake_neg_margin | float | 0.1 | Threshold for false negative masking |
| include_qq | bool | False | Add query-query similarity block |
| include_dd | bool | False | Add doc-doc similarity block |
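The hard_negatives and mask_fake_negative rows can be sketched in a few lines of plain Python. Both helpers below are hypothetical illustrations rather than twinkle code, and they rest on two assumptions the table does not settle: upsampling is done by cycling through the existing negatives in order, and the margin is compared against whatever scores are passed in (the library may apply it before or after temperature scaling).

```python
import math


def fix_negative_count(negatives, count):
    """Return exactly `count` negatives, or all of them when count is None.

    Too many negatives are truncated. Too few are upsampled by cycling
    through the existing ones (an assumption made for this sketch).
    """
    if count is None:
        return list(negatives)
    if count > 0 and not negatives:
        raise ValueError("cannot upsample an empty negative list")
    return [negatives[i % len(negatives)] for i in range(count)]


def mask_fake_negatives(pos_score, neg_scores, margin=0.1):
    """Replace negatives scoring above pos_score + margin with -inf.

    A -inf logit contributes exp(-inf) = 0 to the softmax denominator,
    so the suspected false negative no longer affects the loss.
    """
    threshold = pos_score + margin
    return [-math.inf if s > threshold else s for s in neg_scores]


truncated = fix_negative_count(["n1", "n2", "n3"], 2)
upsampled = fix_negative_count(["n1", "n2"], 5)
masked = mask_fake_negatives(0.8, [0.95, 0.85, 0.3], margin=0.1)
```

In this example only the first negative exceeds the threshold, so it is the only one replaced by -inf; the negative that scores slightly above the positive but within the margin is kept.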
Cross-Rank Gathering
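When use_batch=True and distributed training is active, embeddings are gathered from all data-parallel (DP) ranks so that each query is contrasted against negatives from the whole global batch rather than only its local batch. Only the local shard retains gradients; embeddings received from other ranks take part in the loss as constants.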
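The bookkeeping behind gathering can be pictured with a small simulation. The sketch below does not use a real process group: `gather_all_ranks` is a hypothetical stand-in that concatenates per-rank lists in rank order, and it assumes every rank holds the same local batch size. Gradient handling is not modeled. In PyTorch code the usual pattern is to write the local tensor back into its slot of the gathered list, because the outputs of `torch.distributed.all_gather` do not carry autograd history; whether twinkle does exactly this is an assumption.

```python
def gather_all_ranks(shards):
    """Stand-in for an all-gather: concatenate per-rank lists in rank order."""
    return [item for shard in shards for item in shard]


def positive_targets(rank, local_batch):
    """Index of each local query's positive inside the gathered document list.

    Assumes equal local batch sizes, so rank r's documents start at
    offset r * local_batch.
    """
    offset = rank * local_batch
    return [offset + i for i in range(local_batch)]


# Two hypothetical DP ranks, each holding two positive documents.
doc_shards = [["d0", "d1"], ["d2", "d3"]]
gathered_docs = gather_all_ranks(doc_shards)

# Rank 1 scores its two local queries against all four gathered documents;
# its positives sit at the offset positions, the rest act as negatives.
targets_rank1 = positive_targets(rank=1, local_batch=2)
```

With this layout each rank computes the loss only for its own queries, while the candidate set grows with the number of ranks.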
Similarity Blocks
The loss can combine up to three similarity blocks:
Q→D (default): Query to all documents. This is the primary contrastive signal.
Q→Q (include_qq=True): Query to all other queries. This helps prevent query collapse.
D→D (include_dd=True): Document to all other documents, in the style of Qwen3-Embedding.
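The three blocks can be read as extra columns in each query's logit row. The following framework-free sketch assumes docs[i] is the positive for queries[i], uses a plain dot product as the similarity, and treats the Q→Q and D→D entries purely as additional negatives in the softmax denominator. The function name `info_nce` and these choices are illustrative assumptions; the library's actual similarity function and block layout may differ.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def info_nce(queries, docs, temperature=0.1, include_qq=False, include_dd=False):
    """Mean InfoNCE loss where docs[i] is the positive for queries[i]."""
    losses = []
    for i, q in enumerate(queries):
        # Q->D block: query i against every document; column i is the positive.
        logits = [dot(q, d) / temperature for d in docs]
        if include_qq:
            # Q->Q block: query i against every other query, as extra negatives.
            logits += [dot(q, other) / temperature
                       for j, other in enumerate(queries) if j != i]
        if include_dd:
            # D->D block: positive doc i against every other doc, as extra negatives.
            logits += [dot(docs[i], other) / temperature
                       for j, other in enumerate(docs) if j != i]
        # Cross-entropy with target column i, using a stable log-sum-exp.
        peak = max(logits)
        log_denom = peak + math.log(sum(math.exp(x - peak) for x in logits))
        losses.append(log_denom - logits[i])
    return sum(losses) / len(losses)


queries = [[1.0, 0.0], [0.0, 1.0]]
docs = [[0.9, 0.1], [0.2, 0.8]]
base = info_nce(queries, docs)
with_qq = info_nce(queries, docs, include_qq=True)
with_both = info_nce(queries, docs, include_qq=True, include_dd=True)
```

Each enabled block only adds terms to the denominator, so for the same embeddings the loss grows as blocks are switched on, which is the extra pressure that pushes queries (or documents) apart from one another.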
Example: Embedding Training
from twinkle.loss import InfonceLoss
from twinkle.metric import EmbeddingMetric

# Configure the model for embedding training
model.set_loss(InfonceLoss(temperature=0.05, use_batch=True, include_qq=True))
model.set_metric(EmbeddingMetric(device_mesh=mesh, process_group=pg))

# Training loop
for batch in dataloader:
    model.forward_backward(batch)
    model.clip_grad_and_step()