TorchMetrics 简介

TorchMetrics 简介#

TorchMetrics 包含 80 多个 PyTorch 度量实现以及易于使用的 API 来创建自定义度量的集合。它提供了:

  • 标准化接口以增加再现性

  • 减少样板

  • 兼容分布训练

  • 严格测试

  • 批量自动累积

  • 多个设备之间的自动同步

你可以在任何 PyTorch 模型中使用 TorchMetrics,或者在 PyTorch Lightning 中使用 TorchMetrics 来享受以下额外的好处:

  • 数据将始终放在与 metrics 相同的设备上

  • 您可以在 Lightning 中直接记录 Metric 对象,以减少更多的样板

函数模式#

类似于 torch.nn 大多数度量既有基于类的版本,也有基于函数的版本。函数版本实现了计算每个度量所需的基本运算。它们是简单的 Python 函数,torch.Tensor 作为输入并返回对应的 torch.Tensor 作为度量。

下面的代码片段显示了使用函数接口计算精度的简单示例:

import torch
import torchmetrics

# 模拟分类问题
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
acc
tensor(0.1000)

模块模式#

几乎所有的函数度量都有对应的基于类的度量,该度量在下面称为函数对应。基于类的度量的特征是有一个或多个内部度量状态(类似于 PyTorch 模块的参数),允许它们提供额外的功能:

  • 多批次积累

  • 多台设备间自动同步

  • 度量算法

下面的代码展示了如何使用基于类的接口:

import torch
import torchmetrics

# 初始化 metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # 模拟分类问题
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # 度量当前 batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# 使用自定义累积对所有批次进行度量
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# 重置内部状态,以便度量为新数据做好准备
metric.reset()
Accuracy on batch 0: 0.4000000059604645
Accuracy on batch 1: 0.4000000059604645
Accuracy on batch 2: 0.10000000149011612
Accuracy on batch 3: 0.5
Accuracy on batch 4: 0.20000000298023224
Accuracy on batch 5: 0.30000001192092896
Accuracy on batch 6: 0.20000000298023224
Accuracy on batch 7: 0.4000000059604645
Accuracy on batch 8: 0.30000001192092896
Accuracy on batch 9: 0.4000000059604645
Accuracy on all data: 0.3199999928474426

自定义度量#

实现你自定义度量就像子类化 torch.nn.Module 一样简单。简单地说,子类化 Metric 并执行以下运算:

  • 在实现 __init__ 的地方调用 self.add_state 以用于度量计算所需的每个内部状态

  • 实现 update 方法,其中更新度量状态所需的所有逻辑都放在这里

  • 实现 compute 方法,在这里进行最终的度量计算