Building Metrics

Metrics measure the training process and its results. Metric is one of the customizable components.

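```python
from typing import List, Union


class Metric:

    def __init__(self, device_mesh, process_group, **kwargs):
        # Passed in automatically at construction: the device mesh and the
        # process group of the current data-parallel (DP) group.
        self.process_group = process_group
        self.device_mesh = device_mesh

    # Because of micro-batching, `inputs` may be a list with one InputFeature per micro-batch.
    def accumulate(self, inputs: 'Union[InputFeature, List[InputFeature]]', outputs: 'ModelOutput'):
        """Update the internal statistics with one (micro-)batch of inputs and model outputs."""
        ...

    def calculate(self):
        """Compute the final metric value(s) from the accumulated statistics."""
        ...

    def reset(self):
        """Clear the accumulated statistics before the next measurement round."""
        ...
```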
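To make the interface concrete, here is a minimal sketch of an accuracy metric. It is not taken from the framework: it assumes, hypothetically, that each InputFeature behaves like a dict with a "labels" list and that outputs is a dict whose "preds" list lines up with the labels of all micro-batches concatenated in order. The real InputFeature and ModelOutput types may look different; the class name AccuracyMetric is also illustrative.

```python
from typing import Any, Dict, List, Union


class AccuracyMetric:
    """Sketch of a Metric subclass; the data layout is a hypothetical assumption."""

    def __init__(self, device_mesh=None, process_group=None, **kwargs):
        self.device_mesh = device_mesh
        self.process_group = process_group
        self.reset()

    def accumulate(self, inputs: Union[Dict[str, Any], List[Dict[str, Any]]],
                   outputs: Dict[str, Any]) -> None:
        # Normalize a single feature into a one-element list of micro-batches.
        if not isinstance(inputs, list):
            inputs = [inputs]
        # Assumption: labels from all micro-batches, concatenated in order,
        # correspond one-to-one with outputs["preds"].
        labels = [label for feature in inputs for label in feature["labels"]]
        preds = outputs["preds"]
        if len(preds) != len(labels):
            raise ValueError("preds and labels must have the same length")
        self.correct += sum(int(p == y) for p, y in zip(preds, labels))
        self.total += len(labels)

    def calculate(self) -> Dict[str, float]:
        # Avoid division by zero when nothing has been accumulated yet.
        acc = self.correct / self.total if self.total else 0.0
        return {"accuracy": acc}

    def reset(self) -> None:
        # Zero the running counts so the next evaluation starts fresh.
        self.correct = 0
        self.total = 0


metric = AccuracyMetric()
# One step with two micro-batches, then one step with a single feature.
metric.accumulate([{"labels": [1, 0]}, {"labels": [1, 1]}], {"preds": [1, 0, 0, 1]})
metric.accumulate({"labels": [0]}, {"preds": [0]})
print(metric.calculate())  # {'accuracy': 0.8}
metric.reset()
```

Storing raw counts instead of a running average keeps reset trivial and lets calculate be called at any point without losing information.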

A metric cannot be passed in as a plain Callable, because it has two stages, accumulate and calculate, and must also support reset to clear its accumulated state. When the metric is constructed, the framework automatically passes in the device_mesh and process_group of the current data-parallel (DP) group so that the metric can communicate across processes. In addition, the base class provides a gather_results method that helps collect results from all processes.
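The sketch below illustrates why a metric should gather raw statistics from every DP rank rather than per-rank results. The gather_results shown here is a hypothetical stand-in, not the base class's actual method: only its name comes from the text, and its signature (take the local value, return a list with one entry per rank) is an assumption. In a real multi-process run, a collective such as torch.distributed.all_gather_object over process_group could fill that role; here a hypothetical gather_fn parameter lets the example run in a single process, and accumulate takes plain lists for brevity.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class DistributedAccuracy:
    """Sketch only: cross-rank aggregation of accuracy counts."""

    def __init__(self, device_mesh=None, process_group=None,
                 gather_fn: Optional[Callable[[Any], List[Any]]] = None, **kwargs):
        self.device_mesh = device_mesh
        self.process_group = process_group
        # Hypothetical hook standing in for a real collective over process_group.
        self._gather_fn = gather_fn
        self.reset()

    def gather_results(self, local: Any) -> List[Any]:
        """Return one entry per DP rank (assumed signature)."""
        if self._gather_fn is None:
            return [local]  # single-process fallback: only this rank exists
        return self._gather_fn(local)

    def accumulate(self, preds: List[int], labels: List[int]) -> None:
        # Simplified signature; the interface above takes inputs/outputs instead.
        self.correct += sum(int(p == y) for p, y in zip(preds, labels))
        self.total += len(labels)

    def calculate(self) -> Dict[str, float]:
        # Gather (correct, total) from every rank, then divide the global sums.
        per_rank: List[Tuple[int, int]] = self.gather_results((self.correct, self.total))
        correct = sum(c for c, _ in per_rank)
        total = sum(t for _, t in per_rank)
        return {"accuracy": correct / total if total else 0.0}

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


# Simulate a second DP rank that saw 2 samples and got none right.
metric = DistributedAccuracy(gather_fn=lambda local: [local, (0, 2)])
metric.accumulate([1, 0, 1, 1], [1, 0, 1, 0])  # this rank: 3 of 4 correct
print(metric.calculate())  # {'accuracy': 0.5}
```

Averaging the per-rank accuracies instead would give (0.75 + 0.0) / 2 = 0.375, which is wrong whenever ranks process different numbers of samples; summing the counts first yields the true global value of 3 / 6.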