自定义度量

自定义度量#

要实现自定义度量,子类化基本 Metric 类并实现以下方法:

  • __init__():每个状态变量都应该使用 self.add_state(…) 来调用。

  • update():给定度量的任何输入,更新状态所需的任何代码。

  • compute():从度量的状态计算最终值。

reset() 可以确保正确重置使用 add_state 添加的所有度量状态。因此,不应该自己实现 reset()。此外,使用 add_state 添加度量状态将确保在分布式设置(DDP)中正确地同步状态。

简单的示例如下:

import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total

此外,可能希望设置类属性:is_differentiablehigher_is_betterfull_state_update。注意,这些度量都不是度量工作所必需的。

from torchmetrics import Metric

class MyMetric(Metric):
   # Set to True if the metric is differentiable else set to False
   is_differentiable: bool|None = None

   # Set to True if the metric reaches it optimal value when the metric is maximized.
   # Set to False if it when the metric is minimized.
   higher_is_better: bool|None = True

   # Set to True if the metric during 'update' requires access to the global metric
   # state for its calculations. If not, setting this to False indicates that all
   # batch states are independent and we will optimize the runtime of 'forward'
   full_state_update: bool = True

内部实现细节#

本节简要描述度量的内部工作方式。鼓励查看源代码以获得更多信息。在内部,TorchMetrics 封装了用户定义的 update()compute() 方法。这样做是为了自动同步和减少跨多个设备的度量状态。

更准确地说,调用 update() 在内部执行以下运算:

  • 清除缓存计算。

  • 调用用户定义的 update()

类似地,调用 compute() 在内部执行以下运算:

  • 同步进程之间的度量状态。

  • 规约收集的度量状态。

  • 在收集的度量状态上调用用户定义的 compute() 方法。

  • 缓存计算结果。

从用户的角度来看,这有一个重要的副作用:计算结果被缓存。这意味着无论在一个和另一个之后调用多少次 compute(),它都将继续返回相同的结果。在下一次调用 update() 时首先清空缓存。

forward() 有双重目的,既可以返回当前数据上的度量值,也可以更新内部度量状态,以便在多个批之间累积。forward() 方法通过组合调用 update()compute()reset() 来实现这一点。根据类属性 full_state_update 的不同,forward() 可以有两种表现方式:

  1. 如果 full_state_updateTrue,则表示 update() 期间的度量需要访问完整的度量状态,因此需要执行两次 update() 调用,以确保正确计算度量

    1. 调用 update() 来更新全局度量状态(用于多个批的累积)

    2. 缓存全局状态

    3. 调用 reset() 来清除全局度量状态

    4. 调用 update() 来更新局部度量状态

    5. 调用 compute() 来计算当前批处理的度量。

    6. 恢复全局状态。

  2. 如果 full_state_updateFalse (默认值),则一个批的度量状态完全独立于其他批的状态,这意味着只需要调用 update() 一次。

    1. 缓存全局状态

    2. 调用 reset() 将度量重置为默认状态

    3. 调用 update() 使用本地批处理统计信息更新状态

    4. 调用 compute() 为当前批处理计算度量

    5. 将全局状态和批处理状态缩减为单个状态,该状态将成为新的全局状态

如果实现您自己的度量,建议尝试使用 full_state_update 类属性同时设置为 TrueFalse 的度量。如果结果相等,则将其设置为 False 通常会获得最佳性能。