构建指标
指标用于衡量训练过程和训练结果。指标组件属于可定制组件的一部分。
class Metric:
def __init__(self, device_mesh, process_group, **kwargs):
self.process_group = process_group
self.device_mesh = device_mesh
# 由于 microbatch 的存在,输入到 Metric 的 inputs 可能是个 List
def accumulate(self, inputs: 'Union[InputFeature, List[InputFeature]]', outputs: 'ModelOutput'):
...
def calculate(self):
...
def reset(self):
...
指标无法通过 Callable 传入。因为它包含了 accumulate 和 calculate 两个部分,并需要支持 reset 来归零。指标的构造中会自动传入 device_mesh 和隶属于当前 dp 组的 process_group,用以跨进程通信。
并且,在实际的实现中,基类提供了 gather_results 方法来辅助收集各个进程的输入结果。