构建指标

指标用于衡量训练过程和训练结果。指标组件属于可定制组件的一部分。

class Metric:

    def __init__(self, device_mesh, process_group, **kwargs):
        self.process_group = process_group
        self.device_mesh = device_mesh

    # 由于 microbatch 的存在,输入到 Metric 的 inputs 可能是个 List
    def accumulate(self, inputs: 'Union[InputFeature, List[InputFeature]]', outputs: 'ModelOutput'):
        ...

    def calculate(self):
        ...

    def reset(self):
        ...

指标无法通过 Callable 传入。因为它包含了 accumulatecalculate 两个部分,并需要支持 reset 来归零。指标的构造中会自动传入 device_mesh 和隶属于当前 dp 组的 process_group,用以跨进程通信。 并且,在实际的实现中,基类提供了 gather_results 方法来辅助收集各个进程的输入结果。