Accuracy

Accuracy#

计算公式:

(8)#\[ \operatorname{Accuracy} = \frac{1}{N} \sum_i^N \mathbf{1}(\mathbf{y}_i, \hat{\mathbf{y}_i}) \]

\(\mathbf{y}\) 是目标值的张量,\(\hat{\mathbf{y}}\) 是预测值的张量。

对于具有概率或 logits 预测的多类(multi-class)和多维多类(multi-dimensional multi-class)数据,参数 top_k 将此度量推广为 Top-K accuracy 度量:对于每个样本,考虑 Top-K 最高概率或 logits 评分项以找到正确的标签。

对于多标签(multi-label)和多维多类(multi-dimensional multi-class)输入,该度量默认计算“全局”accuracy,它分别计算所有标签或子样本。这可以通过设置 subset_accuracy=True 更改为子集 accuracy(要求正确预测样本中的所有标签或子样本)。

import torch
from torchmetrics import Accuracy

target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])
accuracy = Accuracy()
accuracy(preds, target)
tensor(0.5000)
target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
accuracy = Accuracy(top_k=2)
accuracy(preds, target)
tensor(0.6667)