算子融合测试

算子融合测试

import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import tvm.testing
import tvm.topi.testing
def before():
    """Build the unfused Relay function ``squeeze(exp(x + 1))``.

    The single parameter is a variable named "x" with shape (10, 20);
    the added constant is the float32 scalar 1.

    Returns:
        A ``relay.Function`` taking that variable and returning the
        squeezed result.
    """
    data = relay.var("x", shape=(10, 20))
    one = relay.const(1, "float32")
    body = relay.squeeze(relay.exp(relay.add(data, one)))
    return relay.Function([data], body)

def expected():
    """Build the hand-written reference for the fused form of ``before()``.

    The add -> exp -> squeeze chain is wrapped in an inner function over a
    variable named "p", and that inner function is tagged with the
    "Primitive" attribute set to an int32 ``IntImm`` of 1. An outer function
    over a variable named "x" simply calls the inner one.

    Returns:
        The outer ``relay.Function``.
    """
    inner_param = relay.var("p", shape=(10, 20))
    inner_body = relay.squeeze(
        relay.exp(relay.add(inner_param, relay.const(1, "float32")))
    )
    primitive_fn = relay.Function([inner_param], inner_body).with_attr(
        "Primitive", tvm.tir.IntImm("int32", 1)
    )

    outer_param = relay.var("x", shape=(10, 20))
    return relay.Function([outer_param], relay.Call(primitive_fn, [outer_param]))

# Run the FuseOps pass over the unfused function and InferType over the
# hand-built reference, then require the two results to be structurally equal.
# `zz` and `after` are kept at module level; the test functions below read them.
z = before()
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, after)
def test_conv2d_fuse():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only checks the shared module-level fixtures and
    builds no graph of its own; the test-specific setup appears to be
    missing -- confirm against the intended test.
    """
    graphs_match = tvm.ir.structural_equal(zz, after)
    assert graphs_match


def test_concatenate():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): no concatenate-specific graph is built here; the body only
    re-checks the shared module-level fixtures -- confirm this is intended.
    """
    is_equal = tvm.ir.structural_equal(zz, after)
    assert is_equal


def test_tuple_root():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only re-checks the shared module-level fixtures;
    any tuple-specific setup appears to be missing -- confirm.
    """
    same_structure = tvm.ir.structural_equal(zz, after)
    assert same_structure


def test_stop_fusion():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only re-checks the shared module-level fixtures;
    the test-specific graph construction appears to be missing -- confirm.
    """
    fused_matches_reference = tvm.ir.structural_equal(zz, after)
    assert fused_matches_reference


def test_fuse_myia_regression():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only re-checks the shared module-level fixtures;
    the regression-specific graph appears to be missing -- confirm.
    """
    outcome = tvm.ir.structural_equal(zz, after)
    assert outcome


def test_fuse_tuple_get_elemwise():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only re-checks the shared module-level fixtures;
    the test-specific graph construction appears to be missing -- confirm.
    """
    equal = tvm.ir.structural_equal(zz, after)
    assert equal


def test_tuple_get_root():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only re-checks the shared module-level fixtures;
    the test-specific graph construction appears to be missing -- confirm.
    """
    structurally_same = tvm.ir.structural_equal(zz, after)
    assert structurally_same


def fuse0(mod):
    """Run InferType on ``mod`` and then FuseOps with ``fuse_opt_level=0``.

    Args:
        mod: The object handed to the ``InferType`` pass.

    Returns:
        Whatever the ``FuseOps(fuse_opt_level=0)`` pass returns for the
        type-inferred input.
    """
    typed_mod = relay.transform.InferType()(mod)
    fuse_pass = relay.transform.FuseOps(fuse_opt_level=0)
    return fuse_pass(typed_mod)


def fuse2(mod):
    """Run InferType on ``mod`` and then FuseOps with ``fuse_opt_level=2``.

    Args:
        mod: The object handed to the ``InferType`` pass.

    Returns:
        Whatever the ``FuseOps(fuse_opt_level=2)`` pass returns for the
        type-inferred input.
    """
    typed_mod = relay.transform.InferType()(mod)
    fuse_pass = relay.transform.FuseOps(fuse_opt_level=2)
    return fuse_pass(typed_mod)


def test_tuple_intermediate():
    """Assert that ``m["main"]`` is structurally equal to module-level ``after``.

    NOTE(review): ``m`` is not defined anywhere in the visible source, so as
    written this raises NameError; the setup that builds ``m`` appears to be
    missing -- confirm.
    """
    fused_main = m["main"]
    assert tvm.ir.structural_equal(fused_main, after)


def test_tuple_consecutive():
    """Assert that ``m["main"]`` is structurally equal to module-level ``after``.

    NOTE(review): ``m`` is not defined anywhere in the visible source, so as
    written this raises NameError; the setup that builds ``m`` appears to be
    missing -- confirm.
    """
    main_fn = m["main"]
    assert tvm.ir.structural_equal(main_fn, after)


def test_inception_like():
    """Assert that ``m["main"]`` is structurally equal to module-level ``after``.

    NOTE(review): ``m`` is not defined anywhere in the visible source, so as
    written this raises NameError; the setup that builds ``m`` appears to be
    missing -- confirm.
    """
    entry = m["main"]
    assert tvm.ir.structural_equal(entry, after)


def test_fuse_parallel_injective():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only re-checks the shared module-level fixtures;
    the test-specific graph construction appears to be missing -- confirm.
    """
    matches = tvm.ir.structural_equal(zz, after)
    assert matches


def test_immutable():
    """Assert that ``new_mod`` equals ``expected()`` after an InferType pass.

    NOTE(review): ``new_mod`` is not defined anywhere in the visible source,
    so as written this raises NameError before the reference is built; the
    setup that produces ``new_mod`` appears to be missing -- confirm.
    """
    infer_type = transform.InferType()
    assert tvm.ir.structural_equal(new_mod, infer_type(expected()))


def test_split():
    """Apply a default-constructed FuseOps pass to ``mod`` and rebind ``mod``.

    NOTE(review): ``mod`` is a local that is read before it is ever assigned,
    so as written the call raises UnboundLocalError; the code that builds the
    module appears to be missing -- confirm.
    """
    fuse_pass = transform.FuseOps()
    mod = fuse_pass(mod)


def test_fuse_max():
    """Assert that module-level ``zz`` and ``after`` are structurally equal.

    NOTE(review): the body only re-checks the shared module-level fixtures;
    the test-specific graph construction appears to be missing -- confirm.
    """
    identical = tvm.ir.structural_equal(zz, after)
    assert identical


# Declares the values False and True under the name `link_params`, which the
# tests below take as an argument.
# NOTE(review): presumably `tvm.testing.parameter` exposes this as a
# parametrized pytest fixture so each such test runs once per value -- confirm.
link_params = tvm.testing.parameter(False, True)


def test_fuse_take(link_params):
    """Call ``relay.build`` on ``m`` for the "llvm" target.

    NOTE(review): ``m`` is not defined anywhere in the visible source (so this
    raises NameError as written) and the ``link_params`` argument is unused;
    the module setup appears to be missing -- confirm.
    """
    target = "llvm"
    relay.build(m, target)


def test_fuse_gather_nd(link_params):
    """Call ``relay.build`` on ``m`` for the "llvm" target.

    NOTE(review): ``m`` is not defined anywhere in the visible source (so this
    raises NameError as written) and the ``link_params`` argument is unused;
    the module setup appears to be missing -- confirm.
    """
    target = "llvm"
    relay.build(m, target)


@tvm.testing.uses_gpu
def test_fuse_bcast_reduce_scalar():
    """Assert that ``m["main"]`` is structurally equal to module-level ``after``.

    NOTE(review): ``m`` is not defined anywhere in the visible source, so as
    written this raises NameError; the setup that builds ``m`` appears to be
    missing -- confirm.
    """
    fused_main = m["main"]
    assert tvm.ir.structural_equal(fused_main, after)


def test_fuse_max_diamond():
    """Assert that ``fused`` is structurally equal to ``expected``.

    NOTE(review): ``fused`` is not defined anywhere in the visible source, so
    as written this raises NameError. ``expected`` is not assigned locally, so
    it resolves to the module-level ``expected`` function rather than a built
    graph -- confirm what was intended.
    """
    graphs_equal = tvm.ir.structural_equal(fused, expected)
    assert graphs_equal


def test_fuse_dynamic_squeeze_slice_take():
    """Assert that ``result.numpy()`` is element-wise close to ``np_result``.

    NOTE(review): neither ``result`` nor ``np_result`` is defined anywhere in
    the visible source, so as written this raises NameError; the code that
    computes them appears to be missing -- confirm.
    """
    actual = result.numpy()
    assert np.allclose(actual, np_result)


@tvm.testing.uses_gpu
def test_fuse_softmax():
    """Check ``result`` against ``ref`` with rtol and atol both set to 1e-4.

    NOTE(review): neither ``result`` nor ``ref`` is defined anywhere in the
    visible source, so as written this raises NameError; the code that
    computes them appears to be missing -- confirm.
    """
    tolerance = 1e-4
    tvm.testing.assert_allclose(result, ref, rtol=tolerance, atol=tolerance)
fn (%x: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
  %2 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %0 = add(%p0, 1f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %1 = exp(%0) /* ty=Tensor[(10, 20), float32] */;
    squeeze(%1) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %2(%x) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */