Chunked Cross Entropy
A memory-efficient variant of cross-entropy loss that processes the vocabulary dimension in chunks to reduce peak GPU memory usage.
from twinkle.loss import ChunkedCrossEntropyLoss
loss_fn = ChunkedCrossEntropyLoss(
    chunk_size=1024,  # vocabulary chunk size
    reduction='mean',
)
model.set_loss(loss_fn)
Parameters:
chunk_size: Number of vocabulary tokens to process per chunk (default: 1024)
reduction: Reduction mode, one of sum, mean, or none
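As a rough illustration of the three reduction modes, the sketch below assumes they follow the usual PyTorch convention (none returns per-token losses, sum adds them, mean averages them). The helper reduce_losses is hypothetical and not part of twinkle; how the library treats ignored or padded tokens is not covered here.

```python
def reduce_losses(per_token_losses, reduction="mean"):
    """Hypothetical helper: combine per-token losses (assumed PyTorch-style modes)."""
    if reduction == "none":
        # Keep one loss value per token.
        return list(per_token_losses)
    total = sum(per_token_losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(per_token_losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


losses = [1.0, 2.0, 3.0]
print(reduce_losses(losses, "sum"))   # 6.0
print(reduce_losses(losses, "mean"))  # 2.0
print(reduce_losses(losses, "none"))  # [1.0, 2.0, 3.0]
```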
The implementation uses a custom autograd function that splits the logit-to-loss computation into chunks along the vocabulary dimension. This avoids materializing the full [batch*seq_len, vocab_size] probability tensor, which lowers peak memory when the vocabulary is large.
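One common way to compute cross-entropy without the full probability vector is a streaming log-sum-exp over vocabulary chunks. The sketch below shows only that arithmetic, in plain Python for a single token position; it is an assumption about the general technique, not twinkle's actual code, which operates on tensors inside a custom autograd function. The function names are hypothetical.

```python
import math


def chunked_cross_entropy(logits, target, chunk_size=1024):
    """Cross-entropy for one token, scanning the vocabulary chunk by chunk.

    logits: list of floats, one per vocabulary entry
    target: index of the correct token
    """
    running_max = -math.inf  # largest logit seen so far
    running_sum = 0.0        # sum of exp(logit - running_max) seen so far
    for start in range(0, len(logits), chunk_size):
        chunk = logits[start:start + chunk_size]
        new_max = max(running_max, max(chunk))
        # Rescale the old sum to the new reference point, then add this chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in chunk
        )
        running_max = new_max
    log_sum_exp = running_max + math.log(running_sum)
    # loss = -log softmax(logits)[target] = logsumexp(logits) - logits[target]
    return log_sum_exp - logits[target]


def reference_cross_entropy(logits, target):
    """Unchunked version, used only to check the result."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


logits = [0.5, -1.2, 3.0, 2.2, 0.0, -0.7, 1.1]
chunked = chunked_cross_entropy(logits, target=2, chunk_size=3)
full = reference_cross_entropy(logits, target=2)
print(abs(chunked - full) < 1e-9)  # True
```

Only one chunk of exponentials is alive at a time, so the working set scales with chunk_size rather than with the vocabulary size.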
This loss is useful when training large-vocabulary models where standard cross-entropy causes out-of-memory (OOM) errors.
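To get a feel for the scale involved, the following back-of-the-envelope estimate compares the size of a full probability tensor with a single vocabulary chunk. All numbers are illustrative assumptions (8192 token positions, a 128,000-entry vocabulary, float32 values), not measurements, and the estimate covers the probability tensor only, not logits or other activations.

```python
def tensor_bytes(rows, cols, bytes_per_value=4):
    """Size in bytes of a dense rows x cols tensor (float32 by default)."""
    return rows * cols * bytes_per_value


rows = 8192          # assumed batch * seq_len
vocab_size = 128000  # assumed vocabulary size
chunk_size = 1024    # default chunk_size from the parameters above

full = tensor_bytes(rows, vocab_size)
per_chunk = tensor_bytes(rows, chunk_size)

print(full)               # 4194304000 bytes, about 3.9 GiB
print(per_chunk)          # 33554432 bytes, exactly 32 MiB
print(full // per_chunk)  # 125
```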