分析提取算子

分析提取算子

import pytest
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload
from tvm.relay.testing import run_opt_pass


def get_conv_net():
    r"""Build a module with two chained conv2d ops whose outputs are added.

    The dataflow is::

              conv2d
              /  |
             /   |
        conv2d   |
            \    |
             \   |
            elemwise add
                 |

    The input ``x`` has shape (1, 1, 5, 1). Both convolutions use 3x3 kernels,
    padding (1, 1) and a single output channel; the weights ``w1`` and ``w2``
    are declared without explicit shapes.

    The docstring is a raw string because the diagram contains backslashes,
    which would otherwise be invalid escape sequences (SyntaxWarning on 3.12+).

    Returns
    -------
    tvm.IRModule
        Module whose body is ``add(y, conv2d(y, w2))`` with ``y = conv2d(x, w1)``.
    """
    dshape = (1, 1, 5, 1)
    x = relay.var("x", shape=dshape)
    y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
    x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)

    # ``y`` feeds both the second conv2d and the add, giving the diamond shape.
    z = relay.add(y, x1)

    return tvm.IRModule.from_expr(z)


def get_conv2d():
    """Return a module holding one NHWC conv2d with an HWIO kernel.

    The input ``x`` is (1, 56, 56, 64) and the kernel ``weight1`` is
    (3, 3, 64, 32), producing 32 output channels with 3x3 windows and
    padding (1, 1).
    """
    conv_attrs = dict(
        channels=32,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NHWC",
        kernel_layout="HWIO",
    )
    data = relay.var("x", shape=(1, 56, 56, 64))
    kernel = relay.var("weight1", shape=(3, 3, 64, 32))
    return tvm.IRModule.from_expr(relay.nn.conv2d(data, kernel, **conv_attrs))


def test_extract_identity():
    """A module built from a lone conv2d lists exactly one op, counted once."""
    freqs = relay.analysis.list_op_freqs(get_conv2d())
    assert len(freqs) == 1
    assert freqs["nn.conv2d"] == 1


def test_extract_conv_net():
    """The two-conv/one-add network reports exactly two distinct ops."""
    expected = {"add": 1, "nn.conv2d": 2}
    freqs = relay.analysis.list_op_freqs(get_conv_net())
    assert len(freqs) == len(expected)
    for name, count in expected.items():
        assert freqs[name] == count


def test_extract_fused():
    """Fusing the conv net must not change its per-op counts (1 add, 2 conv2d).

    The module is type-inferred first, then run through FuseOps at fuse level 3.
    """
    expected = {"add": 1, "nn.conv2d": 2}
    net = get_conv_net()
    typed = relay.transform.InferType()(net)
    fused = relay.transform.FuseOps(3)(typed)

    freqs = relay.analysis.list_op_freqs(fused)
    assert len(freqs) == len(expected)
    for name, count in expected.items():
        assert freqs[name] == count


def test_extract_resnet():
    """Operator frequencies of the default ``get_workload()`` ResNet match known counts.

    Only the operator counts are checked; the returned parameters are unused.
    """
    mod, _params = get_workload()
    expected_op_freqs = {
        "nn.batch_norm": 19,
        "nn.conv2d": 21,
        "nn.relu": 18,
        "nn.max_pool2d": 1,
        "add": 8,
        "nn.global_avg_pool2d": 1,
        "nn.batch_flatten": 1,
        "nn.dense": 1,
        "nn.bias_add": 1,
        "nn.softmax": 1,
    }
    op_freqs = relay.analysis.list_op_freqs(mod)

    # Equal sizes plus a lookup of every expected key means the reported op
    # set is exactly the expected one: nothing missing, nothing extra.
    assert len(op_freqs) == len(expected_op_freqs)

    # Check one operator at a time so a failure names the offending op and
    # both counts, instead of collapsing everything into a single all(...).
    for op_name, expected_count in expected_op_freqs.items():
        actual_count = op_freqs[op_name]
        assert actual_count == expected_count, (
            f"{op_name}: expected {expected_count}, got {actual_count}"
        )


if __name__ == "__main__":
    # The module-level imports only bring in ``tvm`` and ``tvm.relay.testing``;
    # import ``tvm.testing`` explicitly instead of relying on it being loaded
    # as a side effect of another import.
    import tvm.testing

    # NOTE(review): this forwards the process argv to pytest. Under a Jupyter
    # kernel that argv includes the kernel's own ``--f=...`` flag, which pytest
    # rejects with a usage error (SystemExit 4). Run this as a script or via
    # ``pytest`` instead of executing the cell.
    tvm.testing.main()
ERROR: usage: ipykernel_launcher.py [options] [file_or_dir] [file_or_dir] [...]
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/ai/.local/share/jupyter/runtime/kernel-v399c99730476352b7b4b0c36dfc978d956b870aea.json
  inifile: /media/pc/data/lxw/ai/tvm-book/pyproject.toml
  rootdir: /media/pc/data/lxw/ai/tvm-book

An exception has occurred, use %tb to see the full traceback.

SystemExit: 4
/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3585: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)