后向折叠 ReLU 失败

后向折叠 ReLU 失败

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/tests/book/doc/tests
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform
# from tvm.relay.testing import create_workload
# from tvm.relay.build_module import bind_params_by_name


def initializer(_, param):
    """Zero-fill ``param`` in place.

    The original body did ``param = np.zeros(param.shape)``, which only
    rebinds the local name — the caller's array was never touched, so the
    function was a silent no-op.  Write through the array instead.

    Parameters
    ----------
    _ : Any
        Parameter name/key; unused (kept for the initializer callback
        signature).
    param : numpy.ndarray
        Array to be zeroed in place.  NOTE(review): assumed to support
        ndarray-style item assignment — confirm against callers.
    """
    param[...] = np.zeros(param.shape)


# def _get_positive_scale(size):
#     return np.random.uniform(0.5, 1, size=size).astype("float32")


def run_opt_pass(expr, opt_pass):
    """Run a single Relay pass over ``expr`` and return the transformed result.

    ``expr`` is wrapped in a fresh IRModule, ``opt_pass`` is applied, and the
    resulting ``main`` function is returned.  If the input was already a
    ``relay.Function`` the whole function is returned; otherwise only its body
    (mirroring the shape of the original expression).
    """
    assert isinstance(opt_pass, tvm.transform.Pass)
    module = opt_pass(tvm.IRModule.from_expr(expr))
    main_fn = module["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body

下面是无法完成折叠的测试用例:因为缩放因子无法穿过 ReLU:

def before(x, conv_weight, out_scale, channels, blocking):
    """Build a conv2d -> relu -> multiply graph whose trailing scale must NOT fold.

    BUG FIX: the original multiplied ``x`` (the network input) by
    ``out_scale`` instead of the ReLU output, so the conv2d/relu chain became
    dead code and was dropped from the function (the printed IR contained only
    the multiply).  Multiplying the ReLU output puts the scale behind the
    ReLU, which is the scenario this document actually tests: backward scale
    folding cannot push a (possibly negative) scale through ReLU.

    Parameters
    ----------
    x : relay.Expr
        Input tensor, NCHW or blocked NCHW{n}c layout depending on ``blocking``.
    conv_weight : relay.Expr
        Convolution weight (shape left to type inference).
    out_scale : relay.Expr
        Per-channel scale applied after the ReLU.
    channels : int
        Number of convolution output channels.
    blocking : tuple[int, int] | None
        ``(ic_block, oc_block)`` for blocked layouts, or ``None`` for plain
        NCHW/OIHW.
    """
    y = relay.nn.conv2d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y = relay.nn.relu(y)
    # Scale sits *after* the ReLU, so BackwardFoldScaleAxis must fail to fold it.
    y = relay.multiply(y, out_scale)
    return relay.Function(relay.analysis.free_vars(y), y)
def check(shape, channels, blocking, out_scale):
    """Build the test graph, show its IR, and assert that backward folding is a no-op.

    Removed the unused local ``in_channels = shape[1]`` from the original.

    Parameters
    ----------
    shape : tuple
        Input tensor shape (NCHW, or blocked NCHW{n}c when ``blocking`` is set).
    channels : int
        Convolution output channels.
    blocking : tuple[int, int] | None
        Layout blocking factors, or ``None`` for plain NCHW.
    out_scale : relay.Expr
        Scale expression multiplied in after the ReLU.
    """
    x = relay.var("x", shape=shape)
    weight = relay.var("weight")
    y1 = before(x, weight, out_scale, channels, blocking)
    y1 = run_opt_pass(y1, transform.InferType())
    tvm.IRModule.from_expr(y1).show()

    # Folding must fail here, so the pass is expected to leave the IR unchanged.
    y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
    tvm.ir.assert_structural_equal(y1, y1_folded)

# --- Unblocked NCHW case: per-output-channel scale of shape (4, 1, 1). ---
# 1) Scale is a free variable: folding cannot happen at all.
var_scale = relay.var("in_scale", shape=(4, 1, 1))
check((4, 4, 10, 10), 4, None, var_scale)
# 2) Scale is a strictly negative constant: ReLU blocks backward folding.
neg_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0)).astype("float32")
check((4, 4, 10, 10), 4, None, neg_scale)

# --- Blocked NCHW2c case: scale of shape (1, 2, 1, 1, 2). ---
var_scale = relay.var("in_scale", shape=(1, 2, 1, 1, 2))
check((4, 2, 10, 10, 2), 4, (2, 2), var_scale)
neg_scale = relay.const(
    np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0)
).astype("float32")
check((4, 2, 10, 10, 2), 4, (2, 2), neg_scale)
def @main(%x: Tensor[(4, 4, 10, 10), float32] /* ty=Tensor[(4, 4, 10, 10), float32] */, %in_scale: Tensor[(4, 1, 1), float32] /* ty=Tensor[(4, 1, 1), float32] */) -> Tensor[(4, 4, 10, 10), float32] {
  multiply(%x, %in_scale) /* ty=Tensor[(4, 4, 10, 10), float32] */
}
def @main(%x: Tensor[(4, 4, 10, 10), float32] /* ty=Tensor[(4, 4, 10, 10), float32] */) -> Tensor[(4, 4, 10, 10), float32] {
  %0 = cast(meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */, dtype="float32") /* ty=Tensor[(4, 1, 1), float32] */;
  multiply(%x, %0) /* ty=Tensor[(4, 4, 10, 10), float32] */
}
def @main(%x: Tensor[(4, 2, 10, 10, 2), float32] /* ty=Tensor[(4, 2, 10, 10, 2), float32] */, %in_scale: Tensor[(1, 2, 1, 1, 2), float32] /* ty=Tensor[(1, 2, 1, 1, 2), float32] */) -> Tensor[(4, 2, 10, 10, 2), float32] {
  multiply(%x, %in_scale) /* ty=Tensor[(4, 2, 10, 10, 2), float32] */
}
def @main(%x: Tensor[(4, 2, 10, 10, 2), float32] /* ty=Tensor[(4, 2, 10, 10, 2), float32] */) -> Tensor[(4, 2, 10, 10, 2), float32] {
  %0 = cast(meta[relay.Constant][0] /* ty=Tensor[(1, 2, 1, 1, 2), float32] */, dtype="float32") /* ty=Tensor[(1, 2, 1, 1, 2), float32] */;
  multiply(%x, %0) /* ty=Tensor[(4, 2, 10, 10, 2), float32] */
}