FLOPs
Reference: FLOPs
FLOPs (Floating Point Operations) and MACs (Multiply-Accumulate Operations) are metrics commonly used to measure the computational complexity of deep learning models. They offer a quick, simple way to estimate the number of arithmetic operations a given computation requires. For example, when comparing model architectures such as MobileNet or DenseNet for edge devices, people use MACs or FLOPs to estimate model performance. The word "estimate" is deliberate: both metrics are approximations and do not capture actual runtime behavior. Even so, they provide very useful insight into compute requirements and energy consumption, which matters a great deal in edge computing.
FLOPs counts floating-point operations: additions, subtractions, multiplications, and divisions performed on floating-point numbers. Such operations dominate the mathematics of machine learning, e.g. matrix multiplication, activation functions, and gradient computation. FLOPs is typically used to measure the computational cost or complexity of a model, or of a specific operation within it, whenever an estimate of the total number of arithmetic operations is needed in the context of computational efficiency.
MACs, on the other hand, counts only multiply-accumulate operations, each of which multiplies two numbers and adds the product to a running sum. This operation is the building block of many linear-algebra routines, such as matrix multiplication, convolution, and dot products. In models that rely heavily on linear algebra, such as convolutional neural networks (CNNs), MACs is often the more specific measure of computational complexity.
Note
The all-uppercase FLOPS stands for "floating point operations per second" and refers to computation speed; it is typically used as a measure of hardware performance. The "S" in FLOPS stands for "second" and, together with the "P" (for "per"), indicates a rate.
The general consensus in the AI community is that one MAC is roughly equivalent to two FLOPs (one multiplication plus one addition).
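To make the relationship concrete, here is a minimal sketch that counts both metrics for a fully connected layer (the helper names linear_macs and linear_flops are ours, not from any library): a Linear(d_in, d_out) layer applied to a batch of B samples performs B * d_in * d_out MACs, roughly twice as many FLOPs, plus B * d_out additions when a bias is present.
def linear_macs(batch: int, d_in: int, d_out: int) -> int:
    # One multiply-accumulate per (input feature, output feature) pair, per sample.
    return batch * d_in * d_out

def linear_flops(batch: int, d_in: int, d_out: int, bias: bool = True) -> int:
    # 1 MAC ~ 2 FLOPs (one multiply + one add); a bias adds one addition per output.
    flops = 2 * linear_macs(batch, d_in, d_out)
    if bias:
        flops += batch * d_out
    return flops

print(linear_macs(1, 5, 4))   # 20 MACs
print(linear_flops(1, 5, 4))  # 44 FLOPs
Note that the flops_matmul helper defined later in this section uses the tighter (2K - 1) convention, charging K - 1 additions per output element instead of K, so its totals come out slightly smaller (40 instead of 44 for this layer, with bias).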
import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # suppress user warnings
import numpy as np
from typing import Any
from torch import Tensor, nn
from torch.types import Number
from torch.fx._compatibility import compatibility
from torch_book.scan.common import FLOPsABC
class ElementwiseFLOPs(FLOPsABC):
    """Charge one FLOP per output element for elementwise ops."""

    @compatibility(is_backward_compatible=True)
    def fetch_method_flops(self, self_obj: Tensor, result: Tensor, *args_tail, **kwargs):
        """Count the FLOPs of a method call."""
        return np.prod(result.shape)

    @compatibility(is_backward_compatible=True)
    def fetch_function_flops(self, result: Tensor | Number, *args, **kwargs) -> Any:
        """Count the FLOPs of a function call."""
        assert len(args) == 2, len(args)  # binary elementwise function, e.g. add(a, b)
        total_flops = None
        if isinstance(result, Number):
            total_flops = 1
        elif isinstance(result, Tensor):
            total_flops = np.prod(result.shape)
        else:
            raise TypeError(type(result))
        return total_flops

    @compatibility(is_backward_compatible=True)
    def fetch_module_flops(self, module: nn.Module, result: Tensor, *args, **kwargs) -> Any:
        """Count the FLOPs of a module call."""
        assert len(args) == 1
        assert isinstance(args[0], Tensor)
        assert isinstance(result, Tensor)
        input_shape = args[0].shape  # [..., d_in]
        result_shape = result.shape
        assert input_shape == result_shape  # elementwise ops preserve shape
        total_flops = np.prod(result_shape)
        return total_flops
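Under this rule, an elementwise module such as ReLU or Sigmoid is charged one FLOP per output element; the count depends only on the output shape, not on the op itself. The Demo module below chains two such layers so the reported totals can be checked by hand.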
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.ReLU()
        self.layer2 = nn.Sigmoid()

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

model = Demo()
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(1, 5, 32, 32)
ElementwiseFLOPs(gm).propagate(sample_input);
from torch_book.scan.show_flop import show_flops_table
show_flops_table(gm, sample_input)
╒═════════════╤═════════════╤═════════════╤═══════════════════════╤═════════╕
│ node_name │ node_op │ op_target │ nn_module_stack[-1] │ FLOPs │
╞═════════════╪═════════════╪═════════════╪═══════════════════════╪═════════╡
│ x │ placeholder │ x │ │ 0 │
├─────────────┼─────────────┼─────────────┼───────────────────────┼─────────┤
│ layer1 │ call_module │ layer1 │ ReLU │ 5120 │
├─────────────┼─────────────┼─────────────┼───────────────────────┼─────────┤
│ layer2 │ call_module │ layer2 │ Sigmoid │ 5120 │
├─────────────┼─────────────┼─────────────┼───────────────────────┼─────────┤
│ output │ output │ output │ │ 0 │
╘═════════════╧═════════════╧═════════════╧═══════════════════════╧═════════╛
total_flops = 10,240
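Each activation layer sees a tensor of shape (1, 5, 32, 32), so each is charged 5 × 32 × 32 = 5120 FLOPs, and the two layers together account for the reported total of 10,240.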
from torch import nn, Tensor, Size
from torch.types import Number

def flops_zero() -> int:
    return 0

def flops_elemwise(result_shape: Size) -> int:
    return result_shape.numel()

def flops_matmul(tensor1_shape: Size, tensor2_shape: Size, result_shape: Size) -> int:
    # Could be split into per-rank cases; see
    # https://github.com/zhijian-liu/torchprofile/blob/6d80fe57bb8c6bc9f789da7925fac6547fa9502b/torchprofile/handlers.py#L35
    def get_reduce_dim_shape(_s: Size, is_first_mat: bool):
        # Contracted dimension: last dim of the first operand, second-to-last of the second.
        return _s[0] if len(_s) == 1 else _s[-1 if is_first_mat else -2]
    reduce_dim_shape = get_reduce_dim_shape(tensor1_shape, True)
    assert reduce_dim_shape == get_reduce_dim_shape(tensor2_shape, False)
    # Each output element takes K multiplications and K - 1 additions.
    return (2 * reduce_dim_shape - 1) * result_shape.numel()
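As a quick sanity check of flops_matmul using the helpers just defined: multiplying a [1, 5] input by a [5, 4] weight contracts over K = 5, so each of the 4 output elements costs 2 × 5 − 1 = 9 operations, 36 FLOPs in total; a bias adds one more FLOP per output element.
x_shape = Size([1, 5])   # input  [B, d_in]
w_shape = Size([5, 4])   # weight, already transposed to [d_in, d_out]
y_shape = Size([1, 4])   # result [B, d_out]
assert flops_matmul(x_shape, w_shape, y_shape) == 36  # (2 * 5 - 1) * 4
assert flops_matmul(x_shape, w_shape, y_shape) + flops_elemwise(y_shape) == 40  # with bias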
class LinearFLOPs(FLOPsABC):
    """Count FLOPs for nn.Linear modules."""

    @compatibility(is_backward_compatible=True)
    def fetch_method_flops(self, self_obj: Any, result: Tensor, *args_tail, **kwargs):
        """Count the FLOPs of a method call."""
        ...

    @compatibility(is_backward_compatible=True)
    def fetch_function_flops(self, result: Tensor, *args, **kwargs) -> Any:
        """Count the FLOPs of a function call."""
        ...

    @compatibility(is_backward_compatible=True)
    def fetch_module_flops(self, module: Any, result: Tensor, *args, **kwargs) -> Any:
        """Count the FLOPs of a module call."""
        assert len(args) == 1
        assert isinstance(args[0], Tensor)
        assert isinstance(result, Tensor)
        input_shape = args[0].shape  # [..., d_in]
        weight_shape = module.weight.T.shape  # [d_out, d_in].T -> [d_in, d_out]
        result_shape = result.shape
        assert input_shape[-1] == weight_shape[0], f"{input_shape}, {weight_shape}"
        matmul_shape = Size(list(input_shape[:-1]) + list(weight_shape[-1:]))
        assert matmul_shape == result_shape
        total_flops = flops_matmul(input_shape, weight_shape, result_shape)
        if module.bias is not None:
            total_flops += flops_elemwise(result_shape)  # one addition per output element
        return total_flops
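The cell below references SimpleModel, which is never defined in this section. A minimal definition, assumed here rather than taken from the original, wraps a single nn.Linear(5, 4) so that it is traced as a call_module node and handled by fetch_module_flops:
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(5, 4, bias=True)  # traced as a call_module node

    def forward(self, x):
        return self.layer(x)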
from torch import nn
model = SimpleModel()
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(1, 5)
result = LinearFLOPs(gm).propagate(sample_input)
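For Linear(5, 4) with bias on a (1, 5) input, propagate should therefore attribute (2 × 5 − 1) × 4 + 4 = 40 FLOPs to the single call_module node.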
import numpy as np
a = torch.zeros(4, 4)
np.prod(a.shape)  # 16: the number of elements in `a`
import torch
from torch import nn, Tensor
from torch.fx.node import Node
from tabulate import tabulate
from d2py.utils.log_config import config_logging
from torch_book.scan_temp.flop import get_FLOPs
import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # suppress user warnings
config_logging("flops.log", filter_mod_names={"torch"})  # configure logging
model = nn.Linear(5, 4, bias=True)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(1, 5)