简单后向常量折叠

简单后向常量折叠

%cd ..
import set_env
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import create_workload
from tvm.relay.build_module import bind_params_by_name


def initializer(_, param):
    """Zero-fill a parameter array in place.

    Used as the ``initializer`` callback passed to ``create_workload``.

    Args:
        _: Parameter name (unused).
        param: NumPy array to overwrite; its shape and dtype are preserved.

    Rebinding ``param`` to a new array would leave the caller's array
    untouched, so the values are written into the existing buffer instead.
    NOTE(review): the FoldScaleAxis comparison in ``check`` presumably relies
    on these parameters being zero (scaling zeros by 2 leaves them unchanged)
    — confirm.
    """
    param.fill(0)


def _get_positive_scale(size):
    """Return float32 samples drawn uniformly from the range [0.5, 1).

    Args:
        size: Output shape, forwarded to ``np.random.uniform``.
    """
    samples = np.random.uniform(low=0.5, high=1, size=size)
    return samples.astype("float32")


def run_opt_pass(expr, opt_pass):
    """Apply ``opt_pass`` to ``expr`` and return the transformed result.

    ``expr`` is wrapped in an IRModule, and the pass runs on that module.

    Returns:
        The optimized ``main`` function if ``expr`` was a ``relay.Function``;
        otherwise the body of ``main``.
    """
    assert isinstance(opt_pass, tvm.transform.Pass)
    optimized = opt_pass(tvm.IRModule.from_expr(expr))
    main_fn = optimized["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
/media/pc/data/lxw/ai/tvm-book/tests/book/doc/tests

轴缩放（`relu` 之后乘以常量 2）沿后向被两个消费者吸收——`conv2d` 的权重与 `add` 的偏置:

def before(data, weight, out_bias, channels):
    """Build the input module: conv2d -> add(bias) -> relu -> multiply(2.0).

    The parameters produced by ``create_workload`` (using ``initializer``) are
    bound into ``main`` as constants. Binding them presumably lets
    FoldScaleAxis fold the trailing scalar multiply into them.

    Args:
        data: Relay input variable.
        weight: Relay variable for the conv2d kernel. ``check`` creates it
            without a shape; conv2d type inference presumably derives it
            from ``channels`` and ``kernel_size``.
        out_bias: Relay variable added after the convolution.
        channels: Number of conv2d output channels.

    Returns:
        The ``tvm.IRModule`` with parameters bound via ``bind_params_by_name``.
    """
    y = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=(3, 3),
        # Use the argument (not a fixed 16) so the conv output always matches
        # out_bias, which callers shape as (channels, 1, 1).
        channels=channels,
        padding=(1, 1),
    )
    y = relay.add(y, out_bias)
    y = relay.nn.relu(y)
    # Scalar scale after relu: the op that backward scale-axis folding removes.
    y = relay.multiply(y, relay.const(2.0))
    mod, params = create_workload(y, initializer)
    mod["main"] = bind_params_by_name(mod["main"], params)
    return mod

def expected(data, weight, out_bias, channels):
    """Build the reference module: conv2d -> add(bias) -> relu, no multiply.

    This is the graph ``check`` expects after FoldScaleAxis. The parameters
    from ``create_workload`` (using ``initializer``) are bound as constants,
    the same way ``before`` binds them.

    Args:
        data: Relay input variable.
        weight: Relay variable for the conv2d kernel.
        out_bias: Relay variable added after the convolution.
        channels: Number of conv2d output channels.

    Returns:
        The ``tvm.IRModule`` with parameters bound via ``bind_params_by_name``.
    """
    y0 = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=(3, 3),
        # Honor the argument so the reference graph matches any channel count.
        channels=channels,
        padding=(1, 1),
    )
    y0 = relay.add(y0, out_bias)
    y0 = relay.nn.relu(y0)
    mod, params = create_workload(y0, initializer)
    mod["main"] = bind_params_by_name(mod["main"], params)
    return mod
def check(shape, channels):
    """Run FoldScaleAxis on ``before(...)`` and compare with ``expected(...)``.

    Prints and shows the module before and after the pass. Then asserts
    structural equality with the reference graph.

    NOTE(review): ``expected`` has no multiply, yet constants are compared
    structurally. This presumably works only because the bound parameters
    are zero, so folding the factor 2 into them leaves their values
    unchanged — confirm.

    Args:
        shape: Shape of the float32 ``data`` input.
        channels: Number of conv2d output channels; also sizes the bias.
    """
    data_var = relay.var("data", relay.TensorType(shape, "float32"))
    weight_var = relay.var("weight")
    bias_var = relay.var("in_bias", shape=(channels, 1, 1))

    mod = before(data_var, weight_var, bias_var, channels)
    print("FoldScaleAxis 前:")
    mod.show()

    fold_pipeline = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.FoldScaleAxis()]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = fold_pipeline(mod)
    print("FoldScaleAxis 后:")
    mod.show()

    reference = expected(data_var, weight_var, bias_var, channels)
    tvm.ir.assert_structural_equal(mod, reference)

# Run the example: NCHW input with 3 channels, conv2d producing 16 channels.
check(shape=(1, 3, 200, 200), channels=16)
FoldScaleAxis 前:
FoldScaleAxis 后:
def @main(%data: Tensor[(1, 3, 200, 200), float32] /* ty=Tensor[(1, 3, 200, 200), float32] */) -> Tensor[(1, 16, 200, 200), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 200, 200), float32] */;
  %1 = add(%0, meta[relay.Constant][1]) /* ty=Tensor[(1, 16, 200, 200), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 16, 200, 200), float32] */;
  multiply(%2, 2f /* ty=float32 */) /* ty=Tensor[(1, 16, 200, 200), float32] */
}
def @main(%data: Tensor[(1, 3, 200, 200), float32] /* ty=Tensor[(1, 3, 200, 200), float32] */) -> Tensor[(1, 16, 200, 200), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 200, 200), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 200, 200), float32] */;
  nn.relu(%1) /* ty=Tensor[(1, 16, 200, 200), float32] */
}