Dual-Path Backward Fold Failure

This section shows two graphs on which the BackwardFoldScaleAxis pass cannot fold the scale backward: one where the two paths reaching the multiply disagree on the scale axis, and one where an intermediate value also has a consumer outside the scaled path. In both cases the pass must leave the graph unchanged.

%cd ..
import set_env
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform


def initializer(_, param):
    # Zero the parameter in place; rebinding the local name would be a no-op.
    param[:] = np.zeros(param.shape)


def _get_positive_scale(size):
    # Strictly positive scales are required for the fold to move through relu.
    return np.random.uniform(0.5, 1, size=size).astype("float32")


def run_opt_pass(expr, opt_pass):
    """Wrap expr in an IRModule, run a single pass, and unwrap the result."""
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body

The first failing case, fail1, feeds the same input through two convolutions but gives the second one a different output layout, so the scale axis reaching the multiply differs between the two paths:

def fail1(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
    args = [x, conv_weight, out_bias]
    y1 = relay.nn.conv2d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y1 = relay.nn.relu(y1)
    y2 = relay.nn.conv2d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
        out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW",
    )
    # the fold will fail because the scale axis along the two paths
    # differs: y2 uses a CNHW out_layout while y1 stays NCHW
    y2 = relay.nn.relu(y2)
    y = relay.add(y1, y2)
    y = relay.multiply(y, out_scale)
    return relay.Function(args, y)
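
For contrast, here is a minimal sketch of the matching dual-path case (the name success_dual_path and its signature are ours, not part of the original tests; it mirrors the successful dual-path case in TVM's own test suite). With both branches on the default NCHW/OIHW layouts, the scale axis agrees along each path and the fold can succeed:

def success_dual_path(x, conv_weight, out_scale, channels):
    # Both branches keep the default layouts, so the scale axis is the
    # same along each path and BackwardFoldScaleAxis can fold out_scale
    # into conv_weight.
    y1 = relay.nn.conv2d(
        x, conv_weight, channels=channels, kernel_size=(3, 3), padding=(1, 1)
    )
    y1 = relay.nn.relu(y1)
    y2 = relay.nn.conv2d(
        x, conv_weight, channels=channels, kernel_size=(3, 3), padding=(1, 1)
    )
    y2 = relay.nn.relu(y2)
    y = relay.add(y1, y2)
    y = relay.multiply(y, out_scale)
    return relay.Function([x, conv_weight], y)

The second failing case shares a single convolution between a scaled and an unscaled consumer: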

def fail2(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
    args = [x, conv_weight, out_bias]
    y1 = relay.nn.conv2d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y2 = relay.nn.relu(y1)
    # the fold will fail because y1 is also referred to by y2, so the
    # scale does not reach the conv along every consumer of y1
    y1 = relay.multiply(y1, out_scale)
    y = relay.add(y1, y2)
    return relay.Function(args, y)
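
For contrast, if y1 fed only the scaled chain, the positive scale could be folded back through relu into the convolution weight (again a hypothetical sketch; single_consumer is our name):

def single_consumer(x, conv_weight, out_scale, channels):
    # conv -> relu -> multiply with one consumer at each step: the scale
    # reaches the conv along its only path, so the fold can succeed.
    y1 = relay.nn.conv2d(
        x, conv_weight, channels=channels, kernel_size=(3, 3), padding=(1, 1)
    )
    y1 = relay.nn.relu(y1)
    y1 = relay.multiply(y1, out_scale)
    return relay.Function([x, conv_weight], y1)

The check helper below type-checks a failing function, prints it, applies BackwardFoldScaleAxis, and asserts that the result is structurally identical to the input:
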
def check(shape, in_channels, channels, blocking, fbefore):
    x = relay.var("x", shape=shape)
    weight = relay.var("weight")
    if blocking:
        out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
        out_scale = relay.const(
            _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
        )
    else:
        out_bias = relay.var("out_bias", shape=(channels, 1, 1))
        out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
    y1 = fbefore(x, weight, out_bias, out_scale, in_channels, channels, blocking)
    y1 = run_opt_pass(y1, transform.InferType())
    tvm.IRModule.from_expr(y1).show()

    y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
    # The pass must be a no-op here: the "folded" graph is structurally
    # identical to the original.
    tvm.ir.assert_structural_equal(y1_folded, y1)

check((4, 4, 10, 10), 4, 4, None, fail1)
check((2, 2, 10, 10, 2), 4, 4, (2, 2), fail1)
check((4, 4, 10, 10), 4, 4, None, fail2)
check((4, 2, 10, 10, 2), 4, 4, (2, 2), fail2)

The four modules printed below are the type-checked inputs of these calls. In each one the multiply by the scale constant is still present, and assert_structural_equal confirms that BackwardFoldScaleAxis leaves the graph unchanged:
def @main(%x: Tensor[(4, 4, 10, 10), float32] /* ty=Tensor[(4, 4, 10, 10), float32] */, %weight: Tensor[(4, 4, 3, 3), float32] /* ty=Tensor[(4, 4, 3, 3), float32] */, %out_bias: Tensor[(4, 1, 1), float32] /* ty=Tensor[(4, 1, 1), float32] */) -> Tensor[(4, 4, 10, 10), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(4, 4, 10, 10), float32] */;
  %1 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], out_layout="CNHW") /* ty=Tensor[(4, 4, 10, 10), float32] */;
  %2 = nn.relu(%0) /* ty=Tensor[(4, 4, 10, 10), float32] */;
  %3 = nn.relu(%1) /* ty=Tensor[(4, 4, 10, 10), float32] */;
  %4 = add(%2, %3) /* ty=Tensor[(4, 4, 10, 10), float32] */;
  multiply(%4, meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */) /* ty=Tensor[(4, 4, 10, 10), float32] */
}

def @main(%x: Tensor[(2, 2, 10, 10, 2), float32] /* ty=Tensor[(2, 2, 10, 10, 2), float32] */, %weight: Tensor[(2, 4, 3, 3, 1, 2), float32] /* ty=Tensor[(2, 4, 3, 3, 1, 2), float32] */, %out_bias: Tensor[(2, 1, 1, 2), float32] /* ty=Tensor[(2, 1, 1, 2), float32] */) -> Tensor[(2, 2, 10, 10, 2), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %1 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o", out_layout="CNHW2c") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %2 = nn.relu(%0) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %3 = nn.relu(%1) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %4 = add(%2, %3) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  multiply(%4, meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */
}

def @main(%x: Tensor[(4, 4, 10, 10), float32] /* ty=Tensor[(4, 4, 10, 10), float32] */, %weight: Tensor[(4, 4, 3, 3), float32] /* ty=Tensor[(4, 4, 3, 3), float32] */, %out_bias: Tensor[(4, 1, 1), float32] /* ty=Tensor[(4, 1, 1), float32] */) -> Tensor[(4, 4, 10, 10), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(4, 4, 10, 10), float32] */;
  %1 = multiply(%0, meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */) /* ty=Tensor[(4, 4, 10, 10), float32] */;
  %2 = nn.relu(%0) /* ty=Tensor[(4, 4, 10, 10), float32] */;
  add(%1, %2) /* ty=Tensor[(4, 4, 10, 10), float32] */
}

def @main(%x: Tensor[(4, 2, 10, 10, 2), float32] /* ty=Tensor[(4, 2, 10, 10, 2), float32] */, %weight: Tensor[(2, 4, 3, 3, 1, 2), float32] /* ty=Tensor[(2, 4, 3, 3, 1, 2), float32] */, %out_bias: Tensor[(2, 1, 1, 2), float32] /* ty=Tensor[(2, 1, 1, 2), float32] */) -> Tensor[(4, 2, 10, 10, 2), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(4, 2, 10, 10, 2), float32] */;
  %1 = multiply(%0, meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */) /* ty=Tensor[(4, 2, 10, 10, 2), float32] */;
  %2 = nn.relu(%0) /* ty=Tensor[(4, 2, 10, 10, 2), float32] */;
  add(%1, %2) /* ty=Tensor[(4, 2, 10, 10, 2), float32] */
}
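
As a final sanity check (a hypothetical snippet building on the success_dual_path sketch above, not part of the original tests), one can confirm that the pass does rewrite a dual-path graph once the scale axes agree:

x = relay.var("x", shape=(4, 4, 10, 10))
weight = relay.var("weight")
out_scale = relay.const(_get_positive_scale((4, 1, 1)))
y = success_dual_path(x, weight, out_scale, 4)
y = run_opt_pass(y, transform.InferType())
y_folded = run_opt_pass(y, transform.BackwardFoldScaleAxis())
# The graphs should now differ: the scale is folded into the conv weights.
assert not tvm.ir.structural_equal(y_folded, y)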