# 分析提取中间表达式

In [1]:
import tvm
from tvm import relay

In [2]:
def get_conv_net(dshape=(1, 1, 5, 1)):
    r"""Build a small Relay module for exercising intermediate-expression extraction.

    The dataflow graph is::

              conv2d
              /  |
             /   |
        conv2d   |
            \    |
             \   |
            elemwise add
                 |
                 |
                 |
               split
                 |
                 |
                 |
            elemwise add

    (The docstring is raw so the backslashes in the diagram are not parsed
    as invalid escape sequences.)

    Parameters
    ----------
    dshape : tuple of int, optional
        Shape of the input variable ``"x"``; defaults to ``(1, 1, 5, 1)``.
        NOTE(review): no ``data_layout`` is passed to ``conv2d``, so this is
        presumably interpreted in Relay's default layout (NCHW) -- confirm.

    Returns
    -------
    tvm.IRModule
        Module built with ``tvm.IRModule.from_expr`` from the final ``add``.
    """
    x = relay.var("x", shape=dshape)
    # The weights "w1"/"w2" are declared without a shape; presumably Relay's
    # type inference derives them from channels/kernel_size -- confirm.
    y = relay.nn.conv2d(
        x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1
    )
    x1 = relay.nn.conv2d(
        y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1
    )

    # y feeds both the second conv and this add: the diamond in the diagram.
    z = relay.add(y, x1)

    # Split into a single section; this mainly introduces a tuple (and, via
    # tuple_out[0], presumably a TupleGetItem node) into the graph.
    tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)

    tuple_0_add = relay.add(tuple_out[0], relay.const(1, dtype="float32"))

    return tvm.IRModule.from_expr(tuple_0_add)


def get_conv2d():
    """Return an IRModule wrapping a single 3x3 conv2d in NHWC / HWIO layout.

    Input ``"x"`` has shape (1, 56, 56, 64) and weight ``"weight1"`` has shape
    (3, 3, 64, 32); the convolution uses 32 output channels and padding (1, 1).
    """
    conv_attrs = dict(
        channels=32,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NHWC",
        kernel_layout="HWIO",
    )
    data = relay.var("x", shape=(1, 56, 56, 64))
    kernel = relay.var("weight1", shape=(3, 3, 64, 32))
    return tvm.IRModule.from_expr(relay.nn.conv2d(data, kernel, **conv_attrs))


In [3]:
# Same input shape that get_conv_net() uses (by default) for its "x" variable.
# NOTE(review): not referenced by the cells shown here -- remove it, or feed it
# into the graph builder if the shape is meant to be tweakable from this cell.
dshape = (1, 1, 5, 1)

In [5]:
# Build the conv2d -> conv2d -> add -> split -> add test module and display it.
# The %N bindings printed for main are presumably the expression IDs that the
# extraction cells refer to -- confirm against the printed output.
mod = get_conv_net()
mod.show()

In [7]:
# Extract the sub-expression with expression ID 0 from `mod` and display the
# result (presumably returned as a new IRModule, since .show() is called on it).
# NOTE(review): "intermdeiate" is how the TVM API itself spells this function;
# do not "fix" the spelling here without checking the installed TVM version.
relay.analysis.extract_intermdeiate_expr(mod, 0).show()

In [8]:
# Same as the previous cell, but for expression ID 1.
# NOTE(review): the misspelled name matches TVM's own API spelling -- keep it.
relay.analysis.extract_intermdeiate_expr(mod, 1).show()