简单的前向折叠

简单的前向折叠

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/tests/book/doc/tests
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform
# from tvm.relay.testing import create_workload
# from tvm.relay.build_module import bind_params_by_name


def initializer(_, param):
    """Zero-fill ``param`` in place.

    Parameters
    ----------
    _ : Any
        Ignored (first positional argument, unused by this initializer).
    param : numpy.ndarray
        Array to overwrite with zeros. Its shape and dtype are kept.

    Returns
    -------
    None
        The array is modified in place.
    """
    # Assign through the array: rebinding the local name ``param`` would
    # leave the caller's buffer untouched and make this function a no-op.
    param[...] = 0


def _get_positive_scale(size):
    """Draw a float32 array of shape ``size`` with values sampled from [0.5, 1).

    The samples come from NumPy's global random state via
    ``np.random.uniform`` and are then cast to ``float32``.
    """
    samples = np.random.uniform(low=0.5, high=1, size=size)
    return samples.astype("float32")


def run_opt_pass(expr, opt_pass):
    """Apply ``opt_pass`` to ``expr`` and return the transformed expression.

    ``expr`` is wrapped into an ``IRModule`` via ``tvm.IRModule.from_expr``,
    the pass is applied to that module, and the module's ``"main"`` entry is
    read back. If ``expr`` was a ``relay.Function`` the whole entry function
    is returned; otherwise only its body is returned.

    Raises
    ------
    AssertionError
        If ``opt_pass`` is not a ``tvm.transform.Pass``.
    """
    assert isinstance(opt_pass, tvm.transform.Pass)
    module = opt_pass(tvm.IRModule.from_expr(expr))
    main_fn = module["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
def before(x, conv_weight, in_bias, in_scale, channels, blocking):
    """Build the reference graph fed to the pass: scale -> relu -> add bias -> conv2d.

    Parameters
    ----------
    x, conv_weight, in_bias
        Relay expressions that become the parameters of the returned
        function, in this order.
    in_scale
        Relay expression multiplied into ``x`` before the ReLU.
    channels
        Forwarded to ``relay.nn.conv2d`` as ``channels``.
    blocking
        Falsy for plain ``"NCHW"`` / ``"OIHW"`` layouts. Otherwise an indexable
        pair: ``blocking[0]`` is formatted into the data layout and
        ``blocking[1]`` into the kernel layout (whose inner input-channel
        block is fixed to ``2i``).

    Returns
    -------
    relay.Function
        Function over ``[x, conv_weight, in_bias]`` whose body is the conv2d.
    """
    if blocking:
        data_layout = "NCHW{}c".format(blocking[0])
        kernel_layout = "OIHW2i{}o".format(blocking[1])
    else:
        data_layout = "NCHW"
        kernel_layout = "OIHW"

    scaled = relay.multiply(x, in_scale)
    activated = relay.nn.relu(scaled)
    biased = relay.add(activated, in_bias)
    conv = relay.nn.conv2d(
        biased,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout=data_layout,
        kernel_layout=kernel_layout,
    )

    return relay.Function([x, conv_weight, in_bias], conv)

def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking):
    """Build the graph the folded result is compared against.

    Instead of multiplying ``x`` by ``in_scale``, the bias is divided by the
    scale and the convolution weight is multiplied by it; the ReLU is applied
    directly to ``x``.

    Parameters
    ----------
    x, conv_weight, in_bias
        Relay expressions that become the parameters of the returned
        function, in this order.
    in_scale
        Scale expression; it is squeezed and then reshaped/expanded to
        broadcast against the bias and against the weight.
    in_channels
        Used only in the blocked branch to compute the reshape targets.
    channels
        Forwarded to ``relay.nn.conv2d`` as ``channels``.
    blocking
        Falsy for plain ``"NCHW"`` / ``"OIHW"`` layouts; otherwise an indexable
        pair formatted into the blocked layout strings.

    Returns
    -------
    relay.Function
        Function over ``[x, conv_weight, in_bias]`` whose body is the conv2d.
    """
    # Keep a fixed parameter order so the structural-equality check can pass.
    params = [x, conv_weight, in_bias]

    # The bias divisor and the weight factor are deliberately built by two
    # separate calls (not one shared node), exactly as many nodes as the
    # original construction produced.
    if blocking:
        flat_scale = relay.squeeze(in_scale, axis=[0, 2, 3])
        bias_divisor = relay.reshape(
            flat_scale, (1, in_channels // blocking[0], 1, 1, blocking[0])
        )  # NCHWc
        weight_factor = relay.reshape(
            flat_scale, (1, in_channels // 2, 1, 1, 2, 1)
        )  # OIHWio
        data_layout = "NCHW{}c".format(blocking[0])
        kernel_layout = "OIHW2i{}o".format(blocking[1])
    else:
        flat_scale = relay.squeeze(in_scale, axis=[1, 2])
        bias_divisor = relay.expand_dims(flat_scale, axis=1, num_newaxis=2)
        weight_factor = relay.expand_dims(flat_scale, axis=1, num_newaxis=2)
        data_layout = "NCHW"
        kernel_layout = "OIHW"

    activated = relay.nn.relu(x)
    rescaled_bias = relay.divide(in_bias, bias_divisor)
    biased = relay.add(activated, rescaled_bias)
    rescaled_weight = relay.multiply(conv_weight, weight_factor)

    conv = relay.nn.conv2d(
        biased,
        rescaled_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout=data_layout,
        kernel_layout=kernel_layout,
    )
    return relay.Function(params, conv)
# Each case is (input shape, conv output channels, blocking). ``blocking`` is
# None for the plain layouts, or a pair of block sizes for the packed layouts.
test_cases = [
    ((2, 4, 10, 10), 2, None),
    ((2, 2, 10, 10, 2), 8, (2, 4)),
]
for shape, channels, blocking in test_cases:
    x = relay.var("x", shape=shape)
    # Declared without a type here; the type is read back after InferType.
    weight = relay.var("weight")
    if blocking:
        in_channels = shape[1] * shape[4]
        scale_shape = (1, in_channels // blocking[0], 1, 1, blocking[0])
    else:
        in_channels = shape[1]
        scale_shape = (in_channels, 1, 1)
    in_bias = relay.var("in_bias", shape=scale_shape)
    in_scale = relay.const(_get_positive_scale(scale_shape))

    y1 = run_opt_pass(
        before(x, weight, in_bias, in_scale, channels, blocking),
        transform.InferType(),
    )
    print("FoldScaleAxis 前:")
    tvm.IRModule.from_expr(y1).show()

    # Rebuild ``weight`` with the type InferType assigned to it, so the
    # expected graph is constructed from a typed variable.
    type_dict = {p.name_hint: p.checked_type for p in y1.params}
    weight = relay.var("weight", type_dict["weight"])
    y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
    y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking)

    y1_folded = run_opt_pass(y1_folded, transform.InferType())
    print("FoldScaleAxis 后:")
    tvm.IRModule.from_expr(y1_folded).show()

    y1_expected = run_opt_pass(y1_expected, transform.InferType())
    tvm.ir.assert_structural_equal(y1_folded, y1_expected)

    # Disabled follow-up step: run transform.FoldConstant() on y1_folded and
    # display the resulting module.
FoldScaleAxis 前:
FoldScaleAxis 后:
FoldScaleAxis 前:
FoldScaleAxis 后:
def @main(%x: Tensor[(2, 4, 10, 10), float32] /* ty=Tensor[(2, 4, 10, 10), float32] */, %weight: Tensor[(2, 4, 3, 3), float32] /* ty=Tensor[(2, 4, 3, 3), float32] */, %in_bias: Tensor[(4, 1, 1), float32] /* ty=Tensor[(4, 1, 1), float32] */) -> Tensor[(2, 2, 10, 10), float32] {
  %0 = multiply(%x, meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %2 = add(%1, %in_bias) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) /* ty=Tensor[(2, 2, 10, 10), float32] */
}
def @main(%x: Tensor[(2, 4, 10, 10), float32] /* ty=Tensor[(2, 4, 10, 10), float32] */, %weight: Tensor[(2, 4, 3, 3), float32] /* ty=Tensor[(2, 4, 3, 3), float32] */, %in_bias: Tensor[(4, 1, 1), float32] /* ty=Tensor[(4, 1, 1), float32] */) -> Tensor[(2, 2, 10, 10), float32] {
  %0 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */, axis=[1, 2]) /* ty=Tensor[(4), float32] */;
  %1 = expand_dims(%0, axis=1, num_newaxis=2) /* ty=Tensor[(4, 1, 1), float32] */;
  %2 = nn.relu(%x) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %3 = divide(%in_bias, %1) /* ty=Tensor[(4, 1, 1), float32] */;
  %4 = expand_dims(%0, axis=1, num_newaxis=2) /* ty=Tensor[(4, 1, 1), float32] */;
  %5 = add(%2, %3) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %6 = multiply(%weight, %4) /* ty=Tensor[(2, 4, 3, 3), float32] */;
  nn.conv2d(%5, %6, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) /* ty=Tensor[(2, 2, 10, 10), float32] */
}
def @main(%x: Tensor[(2, 2, 10, 10, 2), float32] /* ty=Tensor[(2, 2, 10, 10, 2), float32] */, %weight: Tensor[(2, 2, 3, 3, 2, 4), float32] /* ty=Tensor[(2, 2, 3, 3, 2, 4), float32] */, %in_bias: Tensor[(1, 2, 1, 1, 2), float32] /* ty=Tensor[(1, 2, 1, 1, 2), float32] */) -> Tensor[(2, 4, 10, 10, 2), float32] {
  %0 = multiply(%x, meta[relay.Constant][0] /* ty=Tensor[(1, 2, 1, 1, 2), float32] */) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %2 = add(%1, %in_bias) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], channels=8, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW2i4o") /* ty=Tensor[(2, 4, 10, 10, 2), float32] */
}
def @main(%x: Tensor[(2, 2, 10, 10, 2), float32] /* ty=Tensor[(2, 2, 10, 10, 2), float32] */, %weight: Tensor[(2, 2, 3, 3, 2, 4), float32] /* ty=Tensor[(2, 2, 3, 3, 2, 4), float32] */, %in_bias: Tensor[(1, 2, 1, 1, 2), float32] /* ty=Tensor[(1, 2, 1, 1, 2), float32] */) -> Tensor[(2, 4, 10, 10, 2), float32] {
  %0 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(1, 2, 1, 1, 2), float32] */, axis=[0, 2, 3]) /* ty=Tensor[(2, 2), float32] */;
  %1 = reshape(%0, newshape=[1, 2, 1, 1, 2]) /* ty=Tensor[(1, 2, 1, 1, 2), float32] */;
  %2 = nn.relu(%x) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %3 = divide(%in_bias, %1) /* ty=Tensor[(1, 2, 1, 1, 2), float32] */;
  %4 = reshape(%0, newshape=[1, 2, 1, 1, 2, 1]) /* ty=Tensor[(1, 2, 1, 1, 2, 1), float32] */;
  %5 = add(%2, %3) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %6 = multiply(%weight, %4) /* ty=Tensor[(2, 2, 3, 3, 2, 4), float32] */;
  nn.conv2d(%5, %6, padding=[1, 1, 1, 1], channels=8, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW2i4o") /* ty=Tensor[(2, 4, 10, 10, 2), float32] */
}