损失缩放
LossScale 组件控制训练过程中的损失缩放,确保数值稳定性,在混合精度训练中尤为重要。
from twinkle.loss_scale import LossScale
loss_scale = LossScale()
# 在反向传播前对损失进行缩放
scaled_loss = loss_scale(loss, num_tokens)
LossScale 通过有效 token 数量对损失值进行归一化,确保不同批次大小和序列长度下梯度幅度的一致性。
LossScale 在模型训练流水线中内部使用。使用
model.forward_backward()时会自动应用。