GKD Loss
Generalized Knowledge Distillation (GKD) loss uses a Jensen-Shannon divergence (JSD) with a tunable interpolation weight to distill knowledge from a teacher model into a student model.
from twinkle.loss import GKDLoss
loss_fn = GKDLoss(
    teacher_mode='full',  # 'full', 'topk_local', or 'topk_remote'
    beta=0.5,             # interpolation weight for JSD
    temperature=1.0,      # softmax temperature
)
model.set_loss(loss_fn)
Parameters:
- teacher_mode: How teacher logits are obtained
  - full: Full vocabulary logits from a local teacher model
  - topk_local: Top-k logits from a local teacher with chunked computation for memory efficiency
  - topk_remote: Top-k logits from a remote API teacher
- beta: Interpolation weight between student and teacher distributions in JSD (0 = pure student, 1 = pure teacher)
- temperature: Softmax temperature for both student and teacher distributions
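To make the roles of beta and temperature concrete, here is a minimal pure-Python sketch of a beta-weighted JSD for a single token position. It is not twinkle's implementation: the helper names (softmax, kl, generalized_jsd) are hypothetical, and the mixture convention (beta weights the teacher, 1 - beta the student) is an assumption based on the parameter description above.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=1.0):
    """Beta-weighted JSD for one token position (illustrative sketch only)."""
    student = softmax(student_logits, temperature)
    teacher = softmax(teacher_logits, temperature)
    # Assumed convention: beta weights the teacher, (1 - beta) the student.
    mix = [beta * t + (1 - beta) * s for s, t in zip(student, teacher)]
    return beta * kl(teacher, mix) + (1 - beta) * kl(student, mix)

student_logits = [2.0, 0.5, -1.0]
teacher_logits = [0.1, 1.5, 0.3]
print(generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=1.0))
```

With beta=0.5 this reduces to the standard symmetric JSD, which is bounded by log 2. Under this formula the divergence is exactly zero at beta=0 and beta=1, so some implementations special-case those endpoints as a plain KL divergence; whether GKDLoss does so is not stated here.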
The GKD loss implements chunked computation internally to reduce peak memory usage when working with large vocabularies.
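The idea behind chunking can be sketched as follows. This is an illustration under assumptions, not the library's code: it assumes the chunks run over token positions (the text above does not say which dimension is chunked), uses a plain KL(teacher || student) as a stand-in for the per-token loss, and the names token_kl and chunked_mean_loss are hypothetical.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def token_kl(student_logits, teacher_logits):
    """KL(teacher || student) at one position; stand-in for the per-token loss."""
    s = log_softmax(student_logits)
    t = log_softmax(teacher_logits)
    return sum(math.exp(tl) * (tl - sl) for sl, tl in zip(s, t))

def chunked_mean_loss(student_rows, teacher_rows, chunk_size=2):
    """Average the per-token loss, processing chunk_size positions at a time."""
    total, count = 0.0, 0
    for start in range(0, len(student_rows), chunk_size):
        s_chunk = student_rows[start:start + chunk_size]
        t_chunk = teacher_rows[start:start + chunk_size]
        # Only this chunk's log-probabilities are materialized at once.
        total += sum(token_kl(s, t) for s, t in zip(s_chunk, t_chunk))
        count += len(s_chunk)
    return total / count

student_rows = [[2.0, 0.5, -1.0], [0.0, 0.0, 1.0], [1.0, -1.0, 0.5]]
teacher_rows = [[0.1, 1.5, 0.3], [0.2, 0.1, 0.9], [1.2, -0.8, 0.4]]
print(chunked_mean_loss(student_rows, teacher_rows, chunk_size=2))
```

The result does not depend on chunk_size (up to floating-point rounding); only the peak amount of intermediate data held at once changes.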
GKD is useful for training smaller student models that mimic the behavior of larger teacher models, and supports both local and remote teacher setups.