TorchMetrics 概述#

度量 API 为用户提供 update()compute()reset() 函数。

from torchmetrics.classification import Accuracy

train_accuracy = Accuracy()
valid_accuracy = Accuracy()

for epoch in range(epochs):
    for x, y in train_data:
        y_hat = model(x)

        # training step accuracy
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch{i} is {batch_acc}")

    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy.update(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset()

备注

度量包含跟踪到目前为止看到的数据的内部状态。不要在训练、验证和测试中混合度量状态。强烈建议重新初始化每个模式的度量,如上面的示例所示。

默认情况下,度量状态不会添加到模型 state_dict 中。要改变这一点,在初始化度量之后,可以使用 .persistent(mode) 方法来启用 (mode=True) 或禁用 (mode=False) 此行为。

由于度量状态的特殊逻辑,通常不建议在其他度量中初始化度量(嵌套的度量),因为这可能导致奇怪的行为。相反,可以考虑子类化度量或使用 MetricCollection

度量和设备#

度量是 Module 的简单子类,它们的度量状态行为类似于模块的缓冲区和参数。这意味着度量状态应该被移动到与度量输入相同的设备上:

import torch
from torchmetrics import Accuracy

target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))

# 指标状态总是在 CPU 上初始化,需要移动到正确的设备
confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0))
out = confmat(preds, target)
print(out.device)
cuda:0

然而,当在 ModuleLightningModule 中正确定义时,当使用 .to(device) 时,度量将自动移动到与模块相同的设备上。正确定义意味着该度量被正确标识为模型的子模块(检查模型的 .children() 属性)。因此,度量不能放在原生 Python listdict 中,因为它们不会被正确地标识为子模块。用 ModuleList 代替 list,用 ModuleDict 代替 dict。此外,在使用多个度量时,还可以使用原生的 MetricCollection 模块包装多个度量。

from torchmetrics import Accuracy, MetricCollection

class MyModule(torch.nn.Module):
    def __init__(self):
        ...
        # valid ways metrics will be identified as child modules
        self.metric1 = Accuracy()
        self.metric2 = nn.ModuleList(Accuracy())
        self.metric3 = nn.ModuleDict({'accuracy': Accuracy()})
        self.metric4 = MetricCollection([Accuracy()]) # torchmetrics build-in collection class

    def forward(self, batch):
        data, target = batch
        preds = self(data)
        ...
        val1 = self.metric1(preds, target)
        val2 = self.metric2[0](preds, target)
        val3 = self.metric3['accuracy'](preds, target)
        val4 = self.metric4(preds, target)

您总是可以使用 .device 属性检查度量位于哪个设备上。

数据并行模式下的度量#

在数据并行(DataParallel DP)模式下使用度量时,应该注意 DP 将在单个 forward 过程中创建和清理度量对象的副本。这样做的结果是,副本的度量状态会在同步它们之前被默认销毁。因此,在 DP 模式下使用度量时,建议使用 dist_sync_on_step=True 初始化它们,以便在销毁度量状态之前在主进程和副本之间同步度量状态。

另外,如果度量与 LightningModule 一起使用,那么度量更新/日志记录应该在 <mode>_step_end 方法中完成(其中 <mode> 要么是 training,要么是 validation,要么是 test),否则会导致错误的积累。在实践中要做以下几点:

def training_step(self, batch, batch_idx):
    data, target = batch
    preds = self(data)
    ...
    return {'loss': loss, 'preds': preds, 'target': target}

def training_step_end(self, outputs):
    #update and log
    self.metric(outputs['preds'], outputs['target'])
    self.log('metric', self.metric)

分布式数据并行模式中的度量#

在分布式数据并行(DistributedDataParallel DDP)模式下使用度量时,应该注意,如果数据集的大小不能被 batch_size * num_processors 整除,DDP 将向数据集添加额外的样本。添加的样本将始终是数据集中已经存在的数据点的复制。这样做是为了确保所有进程的负载相等。然而,这导致计算的度量值会对复制的样本有轻微的偏差,从而导致错误的结果。

在训练和/或验证过程中,这可能不重要,但在评估测试数据集时,强烈建议只运行在单个 gpu 上,或使用连接上下文与 DDP 结合,以防止这种行为。

度量和 16 位精度#

集合中的大多数度量可以使用 16 位精度(torch.half)张量。然而,存在以下局限性:

您总是可以通过检查 .dtype 属性来检查度量的 precision/dtype。

度量算法#

Metrics 支持大多数用于算术、逻辑和位操作的 Python 内置算子。

例如,对于应该返回两个不同度量的 sum 的度量,实现新的度量是不必要的开销。现在可以这样做:

first_metric = MyFirstMetric()
second_metric = MySecondMetric()

new_metric = first_metric + second_metric

new_metric.update(*args, **kwargs) 现在调用更新的 first_metricsecond_metric。它 forward 所有位置参数,但只 forward 在各自度量的更新声明中可用的关键字参数。类似地,new_metric.compute() 现在调用 first_metricsecond_metriccompute 并将结果相加。重要的是要注意,所有实现的运算总是返回新的度量对象。这意味着 first_metric == second_metric 一行将不会返回 bool 值来指示 first_metricsecond_metric 是否相同,而是返回新的度量值来检查 first_metric.compute() == second_metric.compute()

该模式实现于以下算子(a 是度量,b 是度量,张量,整数或浮点数):

  • Addition (a + b)

  • Bitwise AND (a & b)

  • Equality (a == b)

  • Floordivision (a // b)

  • Greater Equal (a >= b)

  • Greater (a > b)

  • Less Equal (a <= b)

  • Less (a < b)

  • Matrix Multiplication (a @ b)

  • Modulo (a % b)

  • Multiplication (a * b)

  • Inequality (a != b)

  • Bitwise OR (a | b)

  • Power (a ** b)

  • Subtraction (a - b)

  • True Division (a / b)

  • Bitwise XOR (a ^ b)

  • Absolute Value (abs(a))

  • Inversion (~a)

  • Negative Value (neg(a))

  • Positive Value (pos(a))

  • Indexing (a[0])

度量和可微性#

如果度量计算中涉及的所有计算都是可微的,则度量支持反向传播。所有的模块度量类都具有属性 is_differentiable,该属性决定了度量是否可微。

但是,请注意,缓存的状态与计算图分离,不能反向传播。不这样做将意味着为每次更新调用存储计算图,这可能导致内存不足错误。在实践中,这意味着:

MyMetric.is_differentiable  # returns True if metric is differentiable
metric = MyMetric()
val = metric(pred, target)  # this value can be back-propagated
val = metric.compute()  # this value cannot be back-propagated

函数模式度量是可微的,如果它对应的模块模式度量是可微的。

度量和超参数优化#

如果您想直接优化度量,它需要支持反向传播。然而,如果你只是对使用度量进行超参数调优感兴趣,并且不确定度量应该最大化还是最小化,所有模块化度量类都有 higher_is_better 属性,可以用来确定这一点:

# returns True because accuracy is optimal when it is maximized
torchmetrics.Accuracy.higher_is_better

# returns False because the mean squared error is optimal when it is minimized
torchmetrics.MeanSquaredError.higher_is_better

MetricCollection#

在许多情况下,通过多个度量来评估模型输出是有益的。在这种情况下,MetricCollection 类可能会派上用场。它接受一系列度量,并将这些度量包装成可调用的度量类,具有与任何其他度量相同的接口。

例如:

from torchmetrics import MetricCollection, Accuracy, Precision, Recall
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metric_collection = MetricCollection([
    Accuracy(),
    Precision(num_classes=3, average='macro'),
    Recall(num_classes=3, average='macro')
])
print(metric_collection(preds, target))
{'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}

类似地,它还可以减少记录 LightningModule 内多个度量所需的代码量。在大多数情况下,只需要用 self.log_dict 替换 self.log

from torchmetrics import Accuracy, MetricCollection, Precision, Recall

class MyModule(LightningModule):
    def __init__(self):
        metrics = MetricCollection([Accuracy(), Precision(), Recall()])
        self.train_metrics = metrics.clone(prefix='train_')
        self.valid_metrics = metrics.clone(prefix='val_')

    def training_step(self, batch, batch_idx):
        logits = self(x)
        # ...
        output = self.train_metrics(logits, y)
        # use log_dict instead of log
        # metrics are logged with keys: train_Accuracy, train_Precision and train_Recall
        self.log_dict(output)

    def validation_step(self, batch, batch_idx):
        logits = self(x)
        # ...
        self.valid_metrics.update(logits, y)

    def validation_epoch_end(self, outputs):
        # use log_dict instead of log
        # metrics are logged with keys: val_Accuracy, val_Precision and val_Recall
        output = self.valid_metric.compute()
        self.log_dict(output)

备注

默认情况下,MetricCollection 假定集合中的所有度量都具有相同的调用签名。如果情况并非如此,则应该提供给不同度量的输入可以作为关键字参数提供给集合。

使用 MetricCollection 对象的另一个优点是,它将自动尝试通过查找共享相同底层度量状态的度量组来减少所需的计算。如果发现了这样一组度量,那么实际上只有其中一个度量被更新,并且更新的状态将被广播到组中的其他度量。在上面的例子中,与只调用更新的验证度量(此特性不能与 forward 结合使用)相比,禁用此特性将导致 2 -3 倍的计算成本。然而,这种加速有固定的前期成本,其中的状态组必须在第一次更新后确定。如果预先知道组,也可以手动设置这些组,以避免动态搜索的额外成本。

备注

计算组(compute groups)特性可以在适当的条件下显著加快度量的计算。首先,该特性只在调用 update 方法时可用,而在调用 forward 方法时不可用,因为 forward 的内部逻辑阻止了这一点。其次,由于计算组通过引用共享度量状态,在度量集合上调用 .items().values() 等将破坏该引用,在这种情况下反而返回状态的副本(引用将在下一次调用 update 时重新建立)。

度量集合可以在初始化时嵌套(参见上一个例子),但是集合的输出仍然是一个单一的扁平字典,它结合了来自嵌套集合的前缀和后缀参数。

作为列表输入:

import torch
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, MeanSquaredError

target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metrics = MetricCollection([Accuracy(),
                            Precision(num_classes=3, average='macro'),
                            Recall(num_classes=3, average='macro')])
metrics(preds, target)
{'Accuracy': tensor(0.1250),
 'Precision': tensor(0.0667),
 'Recall': tensor(0.1111)}

作为参数输入:

metrics = MetricCollection(Accuracy(), Precision(num_classes=3, average='macro'),
                           Recall(num_classes=3, average='macro'))
metrics(preds, target)
{'Accuracy': tensor(0.1250),
 'Precision': tensor(0.0667),
 'Recall': tensor(0.1111)}

作为字典输入:

metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
                            'macro_recall': Recall(num_classes=3, average='macro')})
same_metric = metrics.clone()
print(metrics(preds, target))
print(same_metric(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}

计算组规范:

metrics = MetricCollection(
    Recall(num_classes=3, average='macro'),
    Precision(num_classes=3, average='macro'),
    MeanSquaredError(),
    compute_groups=[['Recall', 'Precision'], ['MeanSquaredError']]
)
metrics.update(preds, target)
print(metrics.compute())

print(metrics.compute_groups)
{'Recall': tensor(0.1111), 'Precision': tensor(0.0667), 'MeanSquaredError': tensor(2.3750)}
{0: ['Recall', 'Precision'], 1: ['MeanSquaredError']}

嵌套的度量集合:

metrics = MetricCollection([
    MetricCollection([
        Accuracy(num_classes=3, average='macro'),
        Precision(num_classes=3, average='macro')
    ], postfix='_macro'),
    MetricCollection([
        Accuracy(num_classes=3, average='micro'),
        Precision(num_classes=3, average='micro')
    ], postfix='_micro'),
], prefix='valmetrics/')
print(metrics(preds, target))  
{'valmetrics/Accuracy_macro': tensor(0.1111), 'valmetrics/Precision_macro': tensor(0.0667), 'valmetrics/Accuracy_micro': tensor(0.1250), 'valmetrics/Precision_micro': tensor(0.1250)}

度量的高级设置#

下面是附加参数列表,可以用于任何度量类(在 **kwargs 参数中),这些参数将改变度量状态的存储和同步方式。

如果你在 GPU 上度量指标,并且遇到 GPU VRAM 即将耗尽的情况,那么下面的参数可能会有所帮助:

  • compute_on_cpu 将在调用 update 后自动将度量状态移动到 CPU,确保 GPU 内存没有被填满。结果是 compute 方法将在 CPU 而不是 GPU 上被调用。只适用于列表中的度量状态。

如果您在分布式环境中运行,TorchMetrics 将自动为您处理分布式同步。但是,以下三个关键字参数可以给任何度量类,以进一步控制分布式聚合:

  • dist_sync_on_step:这个参数是 bool,指示每次调用 forward 时,度量是否应该在不同设备之间同步。通常不建议将此设置为 True,因为在每个批处理之后进行同步是一项昂贵的运算。

  • process_group:默认情况下,在全局范围内同步,即所有正在计算的进程。您可以提供 ProcessGroup 指定应该在哪些设备上进行同步。

  • dist_sync_fn:默认情况下,使用 all_gather() 来执行设备之间的同步。为此参数提供另一个可调用函数,以执行自定义分布式同步。