Test case for backward folding of conv3d

This example exercises TVM's BackwardFoldScaleAxis pass on nn.conv3d: when a convolution is followed by a bias add, a relu, and a multiplication by a positive per-channel constant, the pass folds the multiply backward into the convolution weights and the bias. The fold is valid because conv3d and the bias add are linear in each output channel, and relu commutes with multiplication by a positive constant. Both the plain NCDHW layout and the blocked NCDHW16c/OIDHW1i16o layout are covered.

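The identity behind the pass can be checked numerically. Below is a minimal NumPy sketch (not part of the original test; a plain matrix product stands in for the convolution, which is likewise linear per output channel):

import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 4)).astype("float32")    # stand-in for conv weights
x = rng.standard_normal((4, 5)).astype("float32")    # stand-in for the input
b = rng.standard_normal((8, 1)).astype("float32")    # per-channel bias
s = rng.uniform(0.5, 1.0, (8, 1)).astype("float32")  # positive per-channel scale

lhs = s * np.maximum(W @ x + b, 0)         # relu(conv(x, W) + b) * s
rhs = np.maximum((s * W) @ x + s * b, 0)   # relu(conv(x, W * s) + b * s)
np.testing.assert_allclose(lhs, rhs, rtol=1e-5)

For a negative s the two sides differ, which is why the test only draws positive scales.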
%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/tests/book/doc/tests
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform


def initializer(_, param):
    # Parameter initializer kept from the test scaffolding (unused below).
    # Zero the array in place; rebinding the local name would be a no-op.
    param[:] = np.zeros(param.shape)


def _get_positive_scale(size):
    # The pass only folds a scale past relu when it is provably positive,
    # so draw scales from [0.5, 1).
    return np.random.uniform(0.5, 1, size=size).astype("float32")


def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body

# The pattern before folding: conv3d -> bias add -> relu -> multiply by a
# per-channel scale (reshaped/expanded to broadcast over the output layout).
def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
    args = [x, conv_weight, out_bias]
    if blocking:
        out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
    else:
        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=3)
    y = relay.nn.conv3d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3, 3),
        padding=(1, 1, 1),
        data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
        kernel_layout="OIDHW1i{}o".format(blocking[1]) if blocking else "OIDHW",
    )
    y = relay.add(y, out_bias)
    y = relay.nn.relu(y)
    if blocking:
        out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
    y = relay.multiply(y, out_scale)
    return relay.Function(args, y)

# The expected result after folding: the scale is pre-multiplied into the
# conv3d weights and into the bias, and the trailing multiply disappears.
def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
    # Use a fixed order of args so the alpha-equivalence check can pass.
    args = [x, conv_weight, out_bias]
    if blocking:
        out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
        out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
        # Collapse the scale to (channels // bn, bn) so it aligns with the O and
        # {bn}o axes of the OIDHW1i{bn}o kernel layout.
        squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3, 4])
        conv_weight = relay.multiply(
            conv_weight,
            relay.reshape(
                squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, 1, blocking[1])
            ),
        )
    else:
        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=3)
        squeezed_scale = relay.squeeze(out_scale, axis=[1, 2, 3])
        conv_weight = relay.multiply(
            conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=4)
        )

    y = relay.nn.conv3d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3, 3),
        padding=(1, 1, 1),
        data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
        kernel_layout="OIDHW1i{}o".format(blocking[1]) if blocking else "OIDHW",
    )
    if blocking:
        out_bias = relay.multiply(
            out_bias,
            relay.reshape(squeezed_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1])),
        )
    else:
        out_bias = relay.multiply(
            out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
        )
    y = relay.add(y, out_bias)
    y = relay.nn.relu(y)
    return relay.Function(args, y)

# Build the unfolded function, run BackwardFoldScaleAxis, and compare the
# result structurally against the hand-written expected form.
def check(shape, in_channels, channels, blocking):
    x = relay.var("x", shape=shape)
    weight = relay.var("weight")
    out_bias = relay.var("out_bias", shape=(channels,))
    if blocking:
        out_scale = relay.const(_get_positive_scale((channels,)))
    else:
        out_scale = relay.const(_get_positive_scale((channels, 1, 1, 1)))
    y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
    y1 = run_opt_pass(y1, transform.InferType())
    print("FoldScaleAxis 前:")
    tvm.IRModule.from_expr(y1).show()

    type_dict = {x.name_hint: x.checked_type for x in y1.params}
    weight = relay.var("weight", type_dict["weight"])
    y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
    print("FoldScaleAxis 后:")
    tvm.IRModule.from_expr(y1_folded).show()
    
    y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
    y1_expected = run_opt_pass(y1_expected, transform.InferType())
    tvm.ir.assert_structural_equal(y1_folded, y1_expected)

check((2, 4, 10, 10, 10), 4, 8, None)
check((2, 2, 10, 10, 10, 16), 32, 64, (16, 16))
Before FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 10, 10), float32] /* ty=Tensor[(2, 4, 10, 10, 10), float32] */, %weight: Tensor[(8, 4, 3, 3, 3), float32] /* ty=Tensor[(8, 4, 3, 3, 3), float32] */, %out_bias: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(2, 8, 10, 10, 10), float32] {
  %0 = nn.conv3d(%x, %weight, padding=[1, 1, 1, 1, 1, 1], channels=8, kernel_size=[3, 3, 3]) /* ty=Tensor[(2, 8, 10, 10, 10), float32] */;
  %1 = expand_dims(%out_bias, axis=1, num_newaxis=3) /* ty=Tensor[(8, 1, 1, 1), float32] */;
  %2 = add(%0, %1) /* ty=Tensor[(2, 8, 10, 10, 10), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(2, 8, 10, 10, 10), float32] */;
  multiply(%3, meta[relay.Constant][0] /* ty=Tensor[(8, 1, 1, 1), float32] */) /* ty=Tensor[(2, 8, 10, 10, 10), float32] */
}
After FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 10, 10), float32] /* ty=Tensor[(2, 4, 10, 10, 10), float32] */, %weight: Tensor[(8, 4, 3, 3, 3), float32] /* ty=Tensor[(8, 4, 3, 3, 3), float32] */, %out_bias: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(2, 8, 10, 10, 10), float32] {
  %0 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(8, 1, 1, 1), float32] */, axis=[1, 2, 3]) /* ty=Tensor[(8), float32] */;
  %1 = expand_dims(%0, axis=1, num_newaxis=4) /* ty=Tensor[(8, 1, 1, 1, 1), float32] */;
  %2 = multiply(%weight, %1) /* ty=Tensor[(8, 4, 3, 3, 3), float32] */;
  %3 = expand_dims(%out_bias, axis=1, num_newaxis=3) /* ty=Tensor[(8, 1, 1, 1), float32] */;
  %4 = expand_dims(%0, axis=1, num_newaxis=3) /* ty=Tensor[(8, 1, 1, 1), float32] */;
  %5 = nn.conv3d(%x, %2, padding=[1, 1, 1, 1, 1, 1], channels=8, kernel_size=[3, 3, 3]) /* ty=Tensor[(2, 8, 10, 10, 10), float32] */;
  %6 = multiply(%3, %4) /* ty=Tensor[(8, 1, 1, 1), float32] */;
  %7 = add(%5, %6) /* ty=Tensor[(2, 8, 10, 10, 10), float32] */;
  nn.relu(%7) /* ty=Tensor[(2, 8, 10, 10, 10), float32] */
}
Before FoldScaleAxis:
def @main(%x: Tensor[(2, 2, 10, 10, 10, 16), float32] /* ty=Tensor[(2, 2, 10, 10, 10, 16), float32] */, %weight: Tensor[(4, 32, 3, 3, 3, 1, 16), float32] /* ty=Tensor[(4, 32, 3, 3, 3, 1, 16), float32] */, %out_bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] */) -> Tensor[(2, 4, 10, 10, 10, 16), float32] {
  %0 = nn.conv3d(%x, %weight, padding=[1, 1, 1, 1, 1, 1], channels=64, kernel_size=[3, 3, 3], data_layout="NCDHW16c", kernel_layout="OIDHW1i16o") /* ty=Tensor[(2, 4, 10, 10, 10, 16), float32] */;
  %1 = reshape(%out_bias, newshape=[1, 4, 1, 1, 1, 16]) /* ty=Tensor[(1, 4, 1, 1, 1, 16), float32] */;
  %2 = add(%0, %1) /* ty=Tensor[(2, 4, 10, 10, 10, 16), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(2, 4, 10, 10, 10, 16), float32] */;
  %4 = reshape(meta[relay.Constant][0] /* ty=Tensor[(64), float32] */, newshape=[1, 4, 1, 1, 1, 16]) /* ty=Tensor[(1, 4, 1, 1, 1, 16), float32] */;
  multiply(%3, %4) /* ty=Tensor[(2, 4, 10, 10, 10, 16), float32] */
}
After FoldScaleAxis:
def @main(%x: Tensor[(2, 2, 10, 10, 10, 16), float32] /* ty=Tensor[(2, 2, 10, 10, 10, 16), float32] */, %weight: Tensor[(4, 32, 3, 3, 3, 1, 16), float32] /* ty=Tensor[(4, 32, 3, 3, 3, 1, 16), float32] */, %out_bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] */) -> Tensor[(2, 4, 10, 10, 10, 16), float32] {
  %0 = reshape(meta[relay.Constant][0] /* ty=Tensor[(64), float32] */, newshape=[1, 4, 1, 1, 1, 16]) /* ty=Tensor[(1, 4, 1, 1, 1, 16), float32] */;
  %1 = squeeze(%0, axis=[0, 2, 3, 4]) /* ty=Tensor[(4, 16), float32] */;
  %2 = reshape(%1, newshape=[4, 1, 1, 1, 1, 1, 16]) /* ty=Tensor[(4, 1, 1, 1, 1, 1, 16), float32] */;
  %3 = multiply(%weight, %2) /* ty=Tensor[(4, 32, 3, 3, 3, 1, 16), float32] */;
  %4 = reshape(%out_bias, newshape=[1, 4, 1, 1, 1, 16]) /* ty=Tensor[(1, 4, 1, 1, 1, 16), float32] */;
  %5 = reshape(%1, newshape=[1, 4, 1, 1, 1, 16]) /* ty=Tensor[(1, 4, 1, 1, 1, 16), float32] */;
  %6 = nn.conv3d(%x, %3, padding=[1, 1, 1, 1, 1, 1], channels=64, kernel_size=[3, 3, 3], data_layout="NCDHW16c", kernel_layout="OIDHW1i16o") /* ty=Tensor[(2, 4, 10, 10, 10, 16), float32] */;
  %7 = multiply(%4, %5) /* ty=Tensor[(1, 4, 1, 1, 1, 16), float32] */;
  %8 = add(%6, %7) /* ty=Tensor[(2, 4, 10, 10, 10, 16), float32] */;
  nn.relu(%8) /* ty=Tensor[(2, 4, 10, 10, 10, 16), float32] */
}
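
The test asserts structural equality; the folded function can also be checked numerically against the original. A sketch (not part of the original test, assuming a TVM build with the LLVM backend enabled) for the unblocked case, reusing the helpers defined above:

x = relay.var("x", shape=(2, 4, 10, 10, 10))
weight = relay.var("weight", shape=(8, 4, 3, 3, 3))
out_bias = relay.var("out_bias", shape=(8,))
out_scale = relay.const(_get_positive_scale((8, 1, 1, 1)))
f = run_opt_pass(
    before(x, weight, out_bias, out_scale, 4, 8, None), transform.InferType()
)
f_folded = run_opt_pass(f, transform.BackwardFoldScaleAxis())

def run_func(func, *args):
    # Compile and run a relay.Function on CPU, returning a NumPy array.
    mod = tvm.IRModule.from_expr(func)
    return relay.create_executor("graph", mod=mod, target="llvm").evaluate()(*args).numpy()

x_np = np.random.uniform(size=(2, 4, 10, 10, 10)).astype("float32")
w_np = np.random.uniform(size=(8, 4, 3, 3, 3)).astype("float32")
b_np = np.random.uniform(size=(8,)).astype("float32")
np.testing.assert_allclose(
    run_func(f, x_np, w_np, b_np),
    run_func(f_folded, x_np, w_np, b_np),
    rtol=1e-5, atol=1e-5,
)

If the fold is correct, both functions compute the same values: the pass only re-associates the positive scale into the weights and the bias.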