后向折叠 dense 的测试用例

后向折叠 dense 的测试用例

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/tests/book/doc/tests
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform
# from tvm.relay.testing import create_workload
# from tvm.relay.build_module import bind_params_by_name


def initializer(_, param):
    """Zero-initialize *param* in place.

    Parameters
    ----------
    _ : str
        Parameter name (unused).
    param : numpy.ndarray
        Array to fill with zeros.  NOTE(review): the original body was
        ``param = np.zeros(param.shape)``, which only rebinds the local
        name and never changes the caller's array; in-place assignment
        performs the intended zero-fill.  (Unused here since
        ``create_workload`` is commented out.)
    """
    param[...] = np.zeros(param.shape, dtype=param.dtype)


def _get_positive_scale(size):
    return np.random.uniform(0.5, 1, size=size).astype("float32")


def run_opt_pass(expr, opt_pass):
    """Apply *opt_pass* to *expr* and return the transformed result.

    Wraps *expr* in a fresh IRModule, runs the pass, and hands back the
    same kind of object the caller supplied: the whole rewritten
    function when *expr* was a ``relay.Function``, otherwise just the
    body of ``main``.
    """
    assert isinstance(opt_pass, tvm.transform.Pass)
    module = tvm.IRModule.from_expr(expr)
    transformed = opt_pass(module)["main"]
    if isinstance(expr, relay.Function):
        return transformed
    return transformed.body
def before(x, weight, in_bias, in_scale):
    """Build the pre-folding graph: dense -> add(bias) -> relu -> multiply(scale)."""
    params = [x, weight, in_bias]
    dense_out = relay.nn.dense(x, weight)
    biased = relay.add(dense_out, in_bias)
    activated = relay.nn.relu(biased)
    scaled = relay.multiply(activated, in_scale)
    return relay.Function(params, scaled)

def expected(x, weight, in_bias, in_scale):
    """Build the graph expected after BackwardFoldScaleAxis.

    The trailing multiply is folded into the dense weight (scale
    broadcast along axis 1) and into the bias.  Args are listed in a
    fixed order so the alpha-equality check against the folded graph
    can pass.
    """
    params = [x, weight, in_bias]
    folded_weight = relay.multiply(weight, relay.expand_dims(in_scale, axis=1))
    dense_out = relay.nn.dense(x, folded_weight)
    folded_bias = relay.multiply(in_bias, in_scale)
    out = relay.nn.relu(relay.add(dense_out, folded_bias))
    return relay.Function(params, out)
def check(data_shape, weight_shape):
    """Verify BackwardFoldScaleAxis on a dense graph of the given shapes.

    Builds the unfolded graph, runs the pass, prints both IR modules,
    and asserts structural equality with the hand-written expected
    graph.  ``weight_shape[0]`` is the output-channel count used for
    both the bias and the scale.
    """
    out_channels = weight_shape[0]
    x = relay.var("x", shape=data_shape)
    weight = relay.var("weight", shape=weight_shape)
    in_bias = relay.var("in_bias", shape=(out_channels,))
    in_scale = relay.const(_get_positive_scale((out_channels,)))

    func = before(x, weight, in_bias, in_scale)
    func = run_opt_pass(func, transform.InferType())
    print("FoldScaleAxis 前:")
    tvm.IRModule.from_expr(func).show()

    folded = run_opt_pass(func, transform.BackwardFoldScaleAxis())
    folded = run_opt_pass(folded, transform.InferType())
    print("FoldScaleAxis 后:")
    tvm.IRModule.from_expr(folded).show()

    reference = expected(x, weight, in_bias, in_scale)
    reference = run_opt_pass(reference, transform.InferType())
    tvm.ir.assert_structural_equal(folded, reference)

# Exercise the fold with two different (data_shape, weight_shape) pairs.
check((2, 4), (3, 4))
check((3, 5), (4, 5))
FoldScaleAxis 前:
FoldScaleAxis 后:
FoldScaleAxis 前:
FoldScaleAxis 后:
def @main(%x: Tensor[(2, 4), float32] /* ty=Tensor[(2, 4), float32] */, %weight: Tensor[(3, 4), float32] /* ty=Tensor[(3, 4), float32] */, %in_bias: Tensor[(3), float32] /* ty=Tensor[(3), float32] */) -> Tensor[(2, 3), float32] {
  %0 = nn.dense(%x, %weight, units=None) /* ty=Tensor[(2, 3), float32] */;
  %1 = add(%0, %in_bias) /* ty=Tensor[(2, 3), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(2, 3), float32] */;
  multiply(%2, meta[relay.Constant][0] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(2, 3), float32] */
}
def @main(%x: Tensor[(2, 4), float32] /* ty=Tensor[(2, 4), float32] */, %weight: Tensor[(3, 4), float32] /* ty=Tensor[(3, 4), float32] */, %in_bias: Tensor[(3), float32] /* ty=Tensor[(3), float32] */) -> Tensor[(2, 3), float32] {
  %0 = expand_dims(meta[relay.Constant][0] /* ty=Tensor[(3), float32] */, axis=1) /* ty=Tensor[(3, 1), float32] */;
  %1 = multiply(%weight, %0) /* ty=Tensor[(3, 4), float32] */;
  %2 = nn.dense(%x, %1, units=None) /* ty=Tensor[(2, 3), float32] */;
  %3 = multiply(%in_bias, meta[relay.Constant][0] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(3), float32] */;
  %4 = add(%2, %3) /* ty=Tensor[(2, 3), float32] */;
  nn.relu(%4) /* ty=Tensor[(2, 3), float32] */
}
def @main(%x: Tensor[(3, 5), float32] /* ty=Tensor[(3, 5), float32] */, %weight: Tensor[(4, 5), float32] /* ty=Tensor[(4, 5), float32] */, %in_bias: Tensor[(4), float32] /* ty=Tensor[(4), float32] */) -> Tensor[(3, 4), float32] {
  %0 = nn.dense(%x, %weight, units=None) /* ty=Tensor[(3, 4), float32] */;
  %1 = add(%0, %in_bias) /* ty=Tensor[(3, 4), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(3, 4), float32] */;
  multiply(%2, meta[relay.Constant][0] /* ty=Tensor[(4), float32] */) /* ty=Tensor[(3, 4), float32] */
}
def @main(%x: Tensor[(3, 5), float32] /* ty=Tensor[(3, 5), float32] */, %weight: Tensor[(4, 5), float32] /* ty=Tensor[(4, 5), float32] */, %in_bias: Tensor[(4), float32] /* ty=Tensor[(4), float32] */) -> Tensor[(3, 4), float32] {
  %0 = expand_dims(meta[relay.Constant][0] /* ty=Tensor[(4), float32] */, axis=1) /* ty=Tensor[(4, 1), float32] */;
  %1 = multiply(%weight, %0) /* ty=Tensor[(4, 5), float32] */;
  %2 = nn.dense(%x, %1, units=None) /* ty=Tensor[(3, 4), float32] */;
  %3 = multiply(%in_bias, meta[relay.Constant][0] /* ty=Tensor[(4), float32] */) /* ty=Tensor[(4), float32] */;
  %4 = add(%2, %3) /* ty=Tensor[(3, 4), float32] */;
  nn.relu(%4) /* ty=Tensor[(3, 4), float32] */
}