Accuracy#
计算公式:
(8)#\[
\operatorname{Accuracy} = \frac{1}{N} \sum_i^N \mathbf{1}(\mathbf{y}_i, \hat{\mathbf{y}_i})
\]
\(\mathbf{y}\) 是目标值的张量,\(\hat{\mathbf{y}}\) 是预测值的张量。
对于具有概率或 logits 预测的多类(multi-class)和多维多类(multi-dimensional multi-class)数据,参数 top_k
将此度量推广为 Top-K accuracy 度量:对于每个样本,考虑 Top-K 最高概率或 logits 评分项以找到正确的标签。
对于多标签(multi-label)和多维多类(multi-dimensional multi-class)输入,该度量默认计算“全局”accuracy,它分别计算所有标签或子样本。这可以通过设置 subset_accuracy=True
更改为子集 accuracy(要求正确预测样本中的所有标签或子样本)。
import torch
from torchmetrics import Accuracy
target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])
accuracy = Accuracy()
accuracy(preds, target)
tensor(0.5000)
target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
accuracy = Accuracy(top_k=2)
accuracy(preds, target)
tensor(0.6667)