FLOPs

Reference: FLOPs

FLOPs (Floating Point Operations) and MACs (Multiply-Accumulate Operations) are metrics commonly used to measure the computational complexity of deep learning models. They are a fast, simple way to gauge the number of arithmetic operations a given computation requires. For example, when comparing model architectures such as MobileNet or DenseNet for edge devices, people use MACs or FLOPs to estimate model performance. The word "estimate" is deliberate: both metrics are approximations, not measurements of actual runtime performance. They nevertheless offer useful insight into energy consumption and compute requirements, which is particularly valuable in edge computing.

FLOPs refers specifically to the number of floating-point operations: additions, subtractions, multiplications, and divisions performed on floating-point numbers. Such operations are everywhere in the mathematics of machine learning, for example in matrix multiplication, activation functions, and gradient computation. FLOPs is typically used to quantify the computational cost or complexity of a model, or of a specific operation within it, and is useful whenever an estimate of the total number of arithmetic operations is needed, usually in the context of measuring computational efficiency.

MACs, on the other hand, count only multiply-accumulate operations, each of which multiplies two numbers and adds the product to a running sum. This operation is the building block of many linear-algebra routines, such as matrix multiplication, convolution, and dot products. In models that rely heavily on linear algebra, such as convolutional neural networks (CNNs), MACs is often the more specific measure of computational complexity.
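As a back-of-the-envelope illustration (the helper name and the numbers below are ours, not from any library): a 2-D convolution with a k × k kernel, c_in input channels, and c_out output channels performs k · k · c_in multiply-accumulates per output element, so its total MAC count is k² · c_in · c_out · h_out · w_out.

def conv2d_macs(k: int, c_in: int, c_out: int, h_out: int, w_out: int) -> int:
    """Each of the c_out * h_out * w_out output elements accumulates k*k*c_in products."""
    return k * k * c_in * c_out * h_out * w_out

print(conv2d_macs(3, 3, 16, 32, 32))  # 442368 MACs for a 3x3 conv, 3 -> 16 channels, 32x32 output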

Note

The all-uppercase FLOPS is short for "floating point operations per second" and refers to computation speed; it is typically used as a measure of hardware performance. The "S" in FLOPS stands for "second" and, together with the "P" (for "per"), denotes a rate.

The general consensus in the AI community is that one MAC is roughly equivalent to two FLOPs.
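To make the rule of thumb concrete, consider a fully connected layer with d_in inputs and d_out outputs: each output element needs d_in multiplications and d_in - 1 additions, plus one more addition if the layer has a bias. A minimal sketch (the helper names are ours):

def linear_macs(d_in: int, d_out: int) -> int:
    """Each output element accumulates d_in products: d_in * d_out MACs."""
    return d_in * d_out

def linear_flops(d_in: int, d_out: int, bias: bool = True) -> int:
    """(2*d_in - 1) FLOPs per output element, plus one extra add each for the bias."""
    flops = (2 * d_in - 1) * d_out
    if bias:
        flops += d_out
    return flops

print(linear_macs(5, 4), linear_flops(5, 4))  # 20 MACs vs. 40 FLOPs

With bias, the two counts differ by exactly a factor of two, which is where the rule of thumb comes from.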

import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # suppress user warnings

from typing import Any

import numpy as np
from torch import Tensor, nn
from torch.types import Number
from torch.fx._compatibility import compatibility
from torch_book.scan.common import FLOPsABC


class ElementwiseFLOPs(FLOPsABC):
    @compatibility(is_backward_compatible=True)
    def fetch_method_flops(self, self_obj: Tensor, result: Tensor, *args_tail, **kwargs):
        """Count FLOPs for a Tensor method: one FLOP per output element."""
        return np.prod(result.shape)

    @compatibility(is_backward_compatible=True)
    def fetch_function_flops(self, result: Tensor|Number, *args, **kwargs) -> Any:
        """Count FLOPs for a function call."""
        assert len(args) == 2, len(args)
        if isinstance(result, Number):  # scalar result: a single operation
            total_flops = 1
        elif isinstance(result, Tensor):  # tensor result: one FLOP per element
            total_flops = np.prod(result.shape)
        else:
            raise TypeError(type(result))
        return total_flops

    @compatibility(is_backward_compatible=True)
    def fetch_module_flops(self, module: nn.Module, result: Tensor, *args, **kwargs) -> Any:
        """Count FLOPs for a module call."""
        assert len(args) == 1
        assert isinstance(args[0], Tensor)
        assert isinstance(result, Tensor)
        input_shape = args[0].shape  # [..., d_in]
        result_shape = result.shape
        # Elementwise ops preserve shape, so input and output must match.
        assert input_shape == result_shape
        total_flops = np.prod(result_shape)
        return total_flops
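Every elementwise operation (ReLU, Sigmoid, and so on) is thus counted as one FLOP per output element. To exercise the counter, trace a tiny model built only from elementwise activations: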
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.ReLU()
        self.layer2 = nn.Sigmoid()

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

model = Demo()
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(1, 5, 32, 32)
ElementwiseFLOPs(gm).propagate(sample_input);
from torch_book.scan.show_flop import show_flops_table
show_flops_table(gm, sample_input)
╒═════════════╤═════════════╤═════════════╤═══════════════════════╤═════════╕
│ node_name   │ node_op     │ op_target   │ nn_module_stack[-1]   │   FLOPs │
╞═════════════╪═════════════╪═════════════╪═══════════════════════╪═════════╡
│ x           │ placeholder │ x           │                       │       0 │
├─────────────┼─────────────┼─────────────┼───────────────────────┼─────────┤
│ layer1      │ call_module │ layer1      │ ReLU                  │    5120 │
├─────────────┼─────────────┼─────────────┼───────────────────────┼─────────┤
│ layer2      │ call_module │ layer2      │ Sigmoid               │    5120 │
├─────────────┼─────────────┼─────────────┼───────────────────────┼─────────┤
│ output      │ output      │ output      │                       │       0 │
╘═════════════╧═════════════╧═════════════╧═══════════════════════╧═════════╛
total_flops = 10,240
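Each elementwise layer touches every element of the (1, 5, 32, 32) activation, i.e. 5 × 32 × 32 = 5120 FLOPs per layer, and the two layers sum to the reported total of 10,240.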
from torch import nn, Tensor, Size
from torch.types import Number

def flops_zero() -> int:
    return 0

def flops_elemwise(result_shape: Size) -> int:
    return result_shape.numel()

def flops_matmul(tensor1_shape: Size, tensor2_shape: Size, result_shape: Size) -> int:
    # Could be split into cases by input dimensionality; see
    # https://github.com/zhijian-liu/torchprofile/blob/6d80fe57bb8c6bc9f789da7925fac6547fa9502b/torchprofile/handlers.py#L35
    def get_reduce_dim_shape(_s: Size, is_first_mat: bool):
        # The contracted dimension: last dim of the first matrix, second-to-last of the second.
        return _s[0] if len(_s) == 1 else _s[-1 if is_first_mat else -2]
    reduce_dim_shape = get_reduce_dim_shape(tensor1_shape, True)
    assert reduce_dim_shape == get_reduce_dim_shape(tensor2_shape, False)
    return (2 * reduce_dim_shape - 1) * result_shape.numel()
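The formula follows directly from the definition of matrix multiplication: each element of the [m, n] result of an [m, k] × [k, n] product takes k multiplications and k - 1 additions, i.e. 2k - 1 FLOPs, for (2k - 1) · m · n in total. A quick sanity check:

from torch import Size

# A [1, 5] x [5, 4] matmul: (2*5 - 1) * (1*4) = 36 FLOPs.
assert flops_matmul(Size([1, 5]), Size([5, 4]), Size([1, 4])) == 36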

class LinearFLOPs(FLOPsABC):
    @compatibility(is_backward_compatible=True)
    def fetch_method_flops(self, self_obj: Any, result: Tensor, *args_tail, **kwargs):
        """Count FLOPs for a Tensor method (not needed for nn.Linear)."""
        ...

    @compatibility(is_backward_compatible=True)
    def fetch_function_flops(self, result: Tensor, *args, **kwargs) -> Any:
        """Count FLOPs for a function call (not needed for nn.Linear)."""
        ...

    @compatibility(is_backward_compatible=True)
    def fetch_module_flops(self, module: Any, result: Tensor, *args, **kwargs) -> Any:
        """Count FLOPs for an nn.Linear module: a matmul plus an optional bias add."""
        assert len(args) == 1
        assert isinstance(args[0], Tensor)
        assert isinstance(result, Tensor)
        input_shape = args[0].shape  # [..., d_in]
        weight_shape = module.weight.T.shape  # [d_out, d_in].T -> [d_in, d_out]
        result_shape = result.shape

        assert input_shape[-1] == weight_shape[0], f"{input_shape}, {weight_shape}"
        matmul_shape = Size(list(input_shape[:-1]) + list(weight_shape[-1:]))
        assert matmul_shape == result_shape

        total_flops = flops_matmul(input_shape, weight_shape, result_shape)
        if module.bias is not None:
            total_flops += flops_elemwise(result_shape)  # one add per output element
        return total_flops
from torch import nn

model = nn.Linear(5, 4, bias=True)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(1, 5)
result = LinearFLOPs(gm).propagate(sample_input)
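For Linear(5, 4) with bias on a (1, 5) input, fetch_module_flops returns (2 · 5 - 1) · 4 = 36 FLOPs for the matmul plus 4 for the bias adds, 40 FLOPs in total, exactly twice the 20 MACs, as the rule of thumb predicts.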
import numpy as np

a = torch.zeros(4, 4)
np.prod(a.shape)  # 16: np.prod works directly on a torch.Size
import torch
from torch import nn, Tensor
from torch.fx.node import Node
from tabulate import tabulate
from d2py.utils.log_config import config_logging
from torch_book.scan_temp.flop import get_FLOPs
import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # suppress user warnings
config_logging("flops.log", filter_mod_names={"torch"})  # configure logging output
model = nn.Linear(5, 4, bias=True)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(1, 5)
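The cell stops before the counter is actually invoked. Assuming get_FLOPs accepts the traced GraphModule together with the sample input (a guess at the signature, not confirmed here), the next step would look like:

flops = get_FLOPs(gm, sample_input)  # hypothetical call; check torch_book.scan_temp.flop for the real signature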