自定义度量#
要实现自定义度量,子类化基本 Metric
类并实现以下方法:
__init__()
:每个状态变量都应该使用self.add_state(…)
来调用。update()
:给定度量的任何输入,更新状态所需的任何代码。compute()
:从度量的状态计算最终值。
reset()
可以确保正确重置使用 add_state
添加的所有度量状态。因此,不应该自己实现 reset()
。此外,使用 add_state
添加度量状态将确保在分布式设置(DDP)中正确地同步状态。
简单的示例如下:
import torch
from torchmetrics import Metric
class MyAccuracy(Metric):
def __init__(self):
super().__init__()
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
preds, target = self._input_format(preds, target)
assert preds.shape == target.shape
self.correct += torch.sum(preds == target)
self.total += target.numel()
def compute(self):
return self.correct.float() / self.total
此外,可能希望设置类属性:is_differentiable
、higher_is_better
和 full_state_update
。注意,这些度量都不是度量工作所必需的。
from torchmetrics import Metric
class MyMetric(Metric):
# Set to True if the metric is differentiable else set to False
is_differentiable: bool|None = None
# Set to True if the metric reaches it optimal value when the metric is maximized.
# Set to False if it when the metric is minimized.
higher_is_better: bool|None = True
# Set to True if the metric during 'update' requires access to the global metric
# state for its calculations. If not, setting this to False indicates that all
# batch states are independent and we will optimize the runtime of 'forward'
full_state_update: bool = True
内部实现细节#
本节简要描述度量的内部工作方式。鼓励查看源代码以获得更多信息。在内部,TorchMetrics 封装了用户定义的 update()
和 compute()
方法。这样做是为了自动同步和减少跨多个设备的度量状态。
更准确地说,调用 update()
在内部执行以下运算:
清除缓存计算。
调用用户定义的
update()
。
类似地,调用 compute()
在内部执行以下运算:
同步进程之间的度量状态。
规约收集的度量状态。
在收集的度量状态上调用用户定义的
compute()
方法。缓存计算结果。
从用户的角度来看,这有一个重要的副作用:计算结果被缓存。这意味着无论在一个和另一个之后调用多少次 compute()
,它都将继续返回相同的结果。在下一次调用 update()
时首先清空缓存。
forward()
有双重目的,既可以返回当前数据上的度量值,也可以更新内部度量状态,以便在多个批之间累积。forward()
方法通过组合调用 update()
、compute()
和 reset()
来实现这一点。根据类属性 full_state_update
的不同,forward()
可以有两种表现方式:
如果
full_state_update
为True
,则表示update()
期间的度量需要访问完整的度量状态,因此需要执行两次update()
调用,以确保正确计算度量调用
update()
来更新全局度量状态(用于多个批的累积)缓存全局状态
调用
reset()
来清除全局度量状态调用
update()
来更新局部度量状态调用
compute()
来计算当前批处理的度量。恢复全局状态。
如果
full_state_update
为False
(默认值),则一个批的度量状态完全独立于其他批的状态,这意味着只需要调用update()
一次。缓存全局状态
调用
reset()
将度量重置为默认状态调用
update()
使用本地批处理统计信息更新状态调用
compute()
为当前批处理计算度量将全局状态和批处理状态缩减为单个状态,该状态将成为新的全局状态
如果实现您自己的度量,建议尝试使用 full_state_update
类属性同时设置为 True
和 False
的度量。如果结果相等,则将其设置为 False
通常会获得最佳性能。