Dual-Path Forward Fold

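This section exercises transform.ForwardFoldScaleAxis on a graph in which the scaled tensor feeds two consumers: the pass must fold the scale into the weights of both depthwise convolutions. First the environment and helpers:
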
%cd ..
import set_env
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform


def initializer(_, param):
    # Zero the parameter in place; rebinding the local name would be a no-op.
    param[:] = np.zeros(param.shape)


def _get_positive_scale(size):
    # Scales must stay positive so that relu(scale * x) == scale * relu(x),
    # which is what lets the forward fold move the scale past relu.
    return np.random.uniform(0.5, 1, size=size).astype("float32")


def run_opt_pass(expr, opt_pass):
    """Wrap `expr` in an IRModule, run `opt_pass` over it, and return the
    resulting main function (or its body if `expr` was not a Function)."""
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body

Here the axis scale is consumed by two consumers: both depthwise conv2d branches read the same scaled tensor, so the forward fold must push the scale into the weights of each branch. `before` constructs the unfolded graph (a NumPy sketch of the underlying identity follows its definition):

def before(x, conv_weight, in_bias, in_scale, channels, blocking):
    args = [x, conv_weight, in_bias]
    x = relay.multiply(in_scale, x)
    x = relay.nn.relu(x)
    x = relay.subtract(x, in_bias)
    y1 = relay.nn.conv2d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
        kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
        groups=channels,
        padding=(1, 1),
    )
    y2 = relay.nn.conv2d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
        kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
        groups=channels,
        padding=(1, 1),
    )
    z = relay.add(y1, y2)
    return relay.Function(args, z)
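
Why is the fold legal? For a positive scale s, relu(s * x) == s * relu(x), and the subtracted bias can be pre-divided by s, so the input-side multiply can be pushed all the way into the convolution weights. A minimal NumPy sketch of that identity (illustrative shapes, not part of the test):

import numpy as np

s = np.random.uniform(0.5, 1, size=(3,)).astype("float32")  # positive per-channel scale
x = np.random.randn(2, 4, 10, 3).astype("float32")
b = np.random.randn(3).astype("float32")

lhs = np.maximum(s * x, 0) - b        # original path: scale, relu, subtract bias
rhs = s * (np.maximum(x, 0) - b / s)  # folded path: relu, pre-divided bias, scale
np.testing.assert_allclose(lhs, rhs, rtol=1e-5, atol=1e-6)

`expected` below builds the reference graph after folding: the scale vanishes from the input and reappears on the weights of both consumers, reshaped first when the layout is blocked.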

def expected(x, conv_weight, in_bias, in_scale, channels, blocking):
    args = [x, conv_weight, in_bias]
    x = relay.nn.relu(x)
    if blocking:
        _in_scale = relay.reshape(
            in_scale, (1, 1, 1, channels // blocking[0], blocking[0])
        )  # NHWCc
    else:
        _in_scale = in_scale
    in_bias = relay.divide(in_bias, _in_scale)
    x = relay.subtract(x, in_bias)
    if blocking:
        _in_scale = relay.reshape(
            in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0])
        )  # HWIOio
    y1 = relay.nn.conv2d(
        x,
        relay.multiply(conv_weight, _in_scale),
        channels=channels,
        kernel_size=(3, 3),
        data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
        kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
        groups=channels,
        padding=(1, 1),
    )
    if blocking:
        _in_scale = relay.reshape(
            in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0])
        )  # HWIOio
    y2 = relay.nn.conv2d(
        x,
        relay.multiply(conv_weight, _in_scale),
        channels=channels,
        kernel_size=(3, 3),
        data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
        kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
        groups=channels,
        padding=(1, 1),
    )
    z = relay.add(y1, y2)
    return relay.Function(args, z)
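
# In the blocked case the same (C//c, c) scale constant is consumed under two
# layouts: expected() reshapes it to (1, 1, 1, C//c, c) so it broadcasts with
# the NHWC{c}c data (for the bias divide), and to (1, 1, 1, C//c, 1, c) so it
# broadcasts with the blocked HWIO kernel (for the weight multiply).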
test_cases = [
    ((2, 4, 10, 3), 3, None),
    ((2, 4, 10, 2, 2), 4, (2, 2))
]
for dshape, channels, blocking in test_cases:
    x = relay.var("x", shape=dshape)
    if blocking:
        in_channels = dshape[3] * dshape[4]
        wshape = (3, 3, 1, channels // blocking[1], 1, blocking[1])  # HWIOio
        weight = relay.var("weight", shape=wshape)
        in_bias = relay.var("in_bias", shape=(in_channels // blocking[0], blocking[0]))
        in_scale = relay.const(_get_positive_scale((in_channels // blocking[0], blocking[0])))
    else:
        in_channels = dshape[-1]
        wshape = (3, 3, 1, channels)  # HWIO
        weight = relay.var("weight", shape=wshape)
        in_bias = relay.var("in_bias", shape=(in_channels,))
        in_scale = relay.const(_get_positive_scale((in_channels,)))

    # depthwise case: the input channel count must equal the output channel count
    assert in_channels == channels

    y1 = before(x, weight, in_bias, in_scale, channels, blocking)
    y1 = run_opt_pass(y1, transform.InferType())
    print("FoldScaleAxis 前:")
    tvm.IRModule.from_expr(y1).show()

    y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
    print("After FoldScaleAxis:")
    tvm.IRModule.from_expr(y1_folded).show()

    type_dict = {x.name_hint: x.checked_type for x in y1.params}
    weight = relay.var("weight", type_dict["weight"])
    y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking)
    y1_expected = run_opt_pass(y1_expected, transform.InferType())
    tvm.ir.assert_structural_equal(y1_folded, y1_expected)
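
    # Optional numeric sanity check -- a sketch beyond the original test; run it
    # only for the unblocked layout, since stock schedules may not cover the
    # NHWC2c/HWIO1i2o depthwise conv2d (assumes an LLVM-enabled TVM build).
    if not blocking:
        x_np = np.random.randn(*dshape).astype("float32")
        w_np = np.random.randn(*wshape).astype("float32")
        b_np = np.random.randn(in_channels).astype("float32")
        run = lambda f: relay.create_executor(
            "graph", mod=tvm.IRModule.from_expr(f), device=tvm.cpu(), target="llvm"
        ).evaluate()(x_np, w_np, b_np).numpy()
        np.testing.assert_allclose(run(y1), run(y1_folded), rtol=1e-4, atol=1e-5)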

Before FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 3), float32] /* ty=Tensor[(2, 4, 10, 3), float32] */, %weight: Tensor[(3, 3, 1, 3), float32] /* ty=Tensor[(3, 3, 1, 3), float32] */, %in_bias: Tensor[(3), float32] /* ty=Tensor[(3), float32] */) -> Tensor[(2, 4, 10, 3), float32] {
  %0 = multiply(meta[relay.Constant][0] /* ty=Tensor[(3), float32] */, %x) /* ty=Tensor[(2, 4, 10, 3), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(2, 4, 10, 3), float32] */;
  %2 = subtract(%1, %in_bias) /* ty=Tensor[(2, 4, 10, 3), float32] */;
  %3 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(2, 4, 10, 3), float32] */;
  %4 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(2, 4, 10, 3), float32] */;
  add(%3, %4) /* ty=Tensor[(2, 4, 10, 3), float32] */
}

After FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 3), float32] /* ty=Tensor[(2, 4, 10, 3), float32] */, %weight: Tensor[(3, 3, 1, 3), float32] /* ty=Tensor[(3, 3, 1, 3), float32] */, %in_bias: Tensor[(3), float32] /* ty=Tensor[(3), float32] */) -> Tensor[(2, 4, 10, 3), float32] {
  %0 = nn.relu(%x) /* ty=Tensor[(2, 4, 10, 3), float32] */;
  %1 = divide(%in_bias, meta[relay.Constant][0] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(3), float32] */;
  %2 = subtract(%0, %1) /* ty=Tensor[(2, 4, 10, 3), float32] */;
  %3 = multiply(%weight, meta[relay.Constant][0] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(3, 3, 1, 3), float32] */;
  %4 = multiply(%weight, meta[relay.Constant][0] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(3, 3, 1, 3), float32] */;
  %5 = nn.conv2d(%2, %3, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(2, 4, 10, 3), float32] */;
  %6 = nn.conv2d(%2, %4, padding=[1, 1, 1, 1], groups=3, channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(2, 4, 10, 3), float32] */;
  add(%5, %6) /* ty=Tensor[(2, 4, 10, 3), float32] */
}

Before FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 2, 2), float32] /* ty=Tensor[(2, 4, 10, 2, 2), float32] */, %weight: Tensor[(3, 3, 1, 2, 1, 2), float32] /* ty=Tensor[(3, 3, 1, 2, 1, 2), float32] */, %in_bias: Tensor[(2, 2), float32] /* ty=Tensor[(2, 2), float32] */) -> Tensor[(2, 4, 10, 2, 2), float32] {
  %0 = multiply(meta[relay.Constant][0] /* ty=Tensor[(2, 2), float32] */, %x) /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  %2 = subtract(%1, %in_bias) /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  %3 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], groups=4, channels=4, kernel_size=[3, 3], data_layout="NHWC2c", kernel_layout="HWIO1i2o") /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  %4 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], groups=4, channels=4, kernel_size=[3, 3], data_layout="NHWC2c", kernel_layout="HWIO1i2o") /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  add(%3, %4) /* ty=Tensor[(2, 4, 10, 2, 2), float32] */
}

After FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 2, 2), float32] /* ty=Tensor[(2, 4, 10, 2, 2), float32] */, %weight: Tensor[(3, 3, 1, 2, 1, 2), float32] /* ty=Tensor[(3, 3, 1, 2, 1, 2), float32] */, %in_bias: Tensor[(2, 2), float32] /* ty=Tensor[(2, 2), float32] */) -> Tensor[(2, 4, 10, 2, 2), float32] {
  %0 = reshape(meta[relay.Constant][0] /* ty=Tensor[(2, 2), float32] */, newshape=[1, 1, 1, 2, 2]) /* ty=Tensor[(1, 1, 1, 2, 2), float32] */;
  %1 = nn.relu(%x) /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  %2 = divide(%in_bias, %0) /* ty=Tensor[(1, 1, 1, 2, 2), float32] */;
  %3 = reshape(meta[relay.Constant][0] /* ty=Tensor[(2, 2), float32] */, newshape=[1, 1, 1, 2, 1, 2]) /* ty=Tensor[(1, 1, 1, 2, 1, 2), float32] */;
  %4 = subtract(%1, %2) /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  %5 = multiply(%weight, %3) /* ty=Tensor[(3, 3, 1, 2, 1, 2), float32] */;
  %6 = reshape(meta[relay.Constant][0] /* ty=Tensor[(2, 2), float32] */, newshape=[1, 1, 1, 2, 1, 2]) /* ty=Tensor[(1, 1, 1, 2, 1, 2), float32] */;
  %7 = multiply(%weight, %6) /* ty=Tensor[(3, 3, 1, 2, 1, 2), float32] */;
  %8 = nn.conv2d(%4, %5, padding=[1, 1, 1, 1], groups=4, channels=4, kernel_size=[3, 3], data_layout="NHWC2c", kernel_layout="HWIO1i2o") /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  %9 = nn.conv2d(%4, %7, padding=[1, 1, 1, 1], groups=4, channels=4, kernel_size=[3, 3], data_layout="NHWC2c", kernel_layout="HWIO1i2o") /* ty=Tensor[(2, 4, 10, 2, 2), float32] */;
  add(%8, %9) /* ty=Tensor[(2, 4, 10, 2, 2), float32] */
}
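
In the folded IR the input-side multiply is gone: the scale constant now multiplies the weights of each conv2d consumer, and in_bias is divided by the scale so the subtract still sees the same values. In the blocked case the scale is first reshaped to (1, 1, 1, 2, 2) for the NHWC2c data and to (1, 1, 1, 2, 1, 2) for the HWIO1i2o kernels, matching `expected`.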