# Operator fusion tests (算子融合测试)
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import tvm.testing
import tvm.topi.testing
def before():
    """Build the unfused graph ``squeeze(exp(x + 1))`` over a (10, 20) input ``x``."""
    data = relay.var("x", shape=(10, 20))
    one = relay.const(1, "float32")
    out = relay.squeeze(relay.exp(relay.add(data, one)))
    return relay.Function([data], out)
def expected():
    """Build the graph that the top-level check expects FuseOps to produce.

    The add/exp/squeeze chain is wrapped in an inner function over ``p`` and
    tagged with the ``"Primitive"`` attribute (int32 value 1). An outer
    function over ``x`` does nothing but call that inner function.
    """
    # Inner function holding the fused add -> exp -> squeeze chain.
    inner_param = relay.var("p", shape=(10, 20))
    one = relay.const(1, "float32")
    fused_body = relay.squeeze(relay.exp(relay.add(inner_param, one)))
    primitive = relay.Function([inner_param], fused_body).with_attr(
        "Primitive", tvm.tir.IntImm("int32", 1)
    )
    # Outer wrapper that invokes the primitive on its own input.
    outer_param = relay.var("x", shape=(10, 20))
    return relay.Function([outer_param], relay.Call(primitive, [outer_param]))
# Module-level check (no enclosing test function is visible here): run FuseOps
# on the graph from before() and compare it with the type-inferred expected().
z = before()
zz = run_opt_pass(z, transform.FuseOps())
# InferType is run on the hand-built graph so both sides carry type info.
after = run_opt_pass(expected(), transform.InferType())
# FuseOps should yield exactly the hand-built fused function from expected().
assert tvm.ir.structural_equal(zz, after)
# NOTE(review): this executes at import time, and zz/after stay in module
# scope, where several assertion-only test stubs later in the file read them.
def test_conv2d_fuse():
    """Check FuseOps on a conv2d graph (purpose inferred from the name).

    NOTE(review): the body appears truncated; only the final assertion remains.
    ``zz`` and ``after`` are not assigned here, so they resolve to the
    module-level globals from the add/exp/squeeze check. This assertion
    therefore does not exercise conv2d fusion. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def test_concatenate():
    """Check FuseOps on a concatenate graph (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check, so nothing
    concatenate-specific is tested. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def test_tuple_root():
    """Check FuseOps when a tuple is the graph root (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check, so no tuple
    graph is exercised. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def test_stop_fusion():
    """Check that fusion stops where expected (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check, so the assertion
    passes or fails independently of this test's intent. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def test_fuse_myia_regression():
    """Regression check for a fusion case (name refers to "myia"; details not visible).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check, so the regression
    scenario is not exercised. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def test_fuse_tuple_get_elemwise():
    """Check fusing tuple-get with elementwise ops (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def test_tuple_get_root():
    """Check FuseOps when a tuple-get is the root (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def fuse0(mod):
    """Type-check ``mod`` and then run FuseOps with ``fuse_opt_level=0``.

    ``mod`` is presumably a ``tvm.IRModule`` (TODO confirm against callers).
    The output of the FuseOps pass is returned.
    """
    infer = relay.transform.InferType()
    typed = infer(mod)
    fuse = relay.transform.FuseOps(fuse_opt_level=0)
    return fuse(typed)
def fuse2(mod):
    """Type-check ``mod`` and then run FuseOps with ``fuse_opt_level=2``.

    ``mod`` is presumably a ``tvm.IRModule`` (TODO confirm against callers).
    The output of the FuseOps pass is returned.
    """
    type_pass = relay.transform.InferType()
    checked = type_pass(mod)
    fuse_pass = relay.transform.FuseOps(fuse_opt_level=2)
    return fuse_pass(checked)
def test_tuple_intermediate():
    """Check FuseOps with an intermediate tuple (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``m`` is not defined anywhere
    visible, so this raises NameError as written. ``after`` would resolve to
    the module-level global. TODO restore the body.
    """
    assert tvm.ir.structural_equal(m["main"], after)
def test_tuple_consecutive():
    """Check FuseOps with consecutive tuples (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``m`` is not defined anywhere
    visible, so this raises NameError as written. TODO restore the body.
    """
    assert tvm.ir.structural_equal(m["main"], after)
def test_inception_like():
    """Check FuseOps on an inception-style graph (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``m`` is not defined anywhere
    visible, so this raises NameError as written. TODO restore the body.
    """
    assert tvm.ir.structural_equal(m["main"], after)
def test_fuse_parallel_injective():
    """Check fusing parallel injective ops (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
def test_immutable():
    """Check that fusion leaves its input unchanged (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``new_mod`` is not defined
    anywhere visible (NameError as written). ``expected()`` resolves to the
    module-level builder, which returns a relay.Function. Presumably the
    original test built a module here instead. TODO confirm and restore the body.
    """
    assert tvm.ir.structural_equal(new_mod, transform.InferType()(expected()))
def test_split():
    """Run FuseOps on a split graph (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``mod`` is assigned in this
    function, which makes it local, so reading it on the right-hand side
    raises UnboundLocalError as written. TODO restore the body.
    """
    mod = transform.FuseOps()(mod)
def test_fuse_max():
    """Check the fusion depth limit (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``zz`` and ``after`` resolve to
    the module-level globals from the add/exp/squeeze check. TODO restore the body.
    """
    assert tvm.ir.structural_equal(zz, after)
# Test parameter with values False and True. The name must stay "link_params":
# test_fuse_take and test_fuse_gather_nd take an argument of that name, and
# tvm.testing.parameter presumably runs each such test once per value.
link_params = tvm.testing.parameter(False, True)
def test_fuse_take(link_params):
    """Build a graph using ``take`` for the llvm target (purpose inferred from the name).

    ``link_params`` presumably comes from the module-level ``link_params``
    parameter, but the surviving line does not use it.
    NOTE(review): the body appears truncated. ``m`` is not defined anywhere
    visible, so this raises NameError as written. TODO restore the body.
    """
    relay.build(m, "llvm")
def test_fuse_gather_nd(link_params):
    """Build a graph using ``gather_nd`` for the llvm target (purpose inferred from the name).

    ``link_params`` presumably comes from the module-level ``link_params``
    parameter, but the surviving line does not use it.
    NOTE(review): the body appears truncated. ``m`` is not defined anywhere
    visible, so this raises NameError as written. TODO restore the body.
    """
    relay.build(m, "llvm")
# Decorated with tvm.testing.uses_gpu (marks the test as GPU-related; exact
# selection behavior is defined by tvm.testing).
@tvm.testing.uses_gpu
def test_fuse_bcast_reduce_scalar():
    """Check fusing broadcast and reduce with a scalar (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``m`` is not defined anywhere
    visible, so this raises NameError as written. TODO restore the body.
    """
    assert tvm.ir.structural_equal(m["main"], after)
def test_fuse_max_diamond():
    """Check the fusion depth limit on a diamond graph (purpose inferred from the name).

    NOTE(review): the body appears truncated. ``fused`` is not defined
    anywhere visible (NameError as written). ``expected`` here is the
    module-level builder function object itself, not a graph, because it is
    never called. TODO restore the body.
    """
    assert tvm.ir.structural_equal(fused, expected)
def test_fuse_dynamic_squeeze_slice_take():
    """Compare a fused dynamic squeeze/slice/take result with NumPy (inferred from the name).

    NOTE(review): the body appears truncated. ``result`` and ``np_result``
    are not defined anywhere visible, so this raises NameError as written.
    TODO restore the body.
    """
    assert np.allclose(result.numpy(), np_result)
# Decorated with tvm.testing.uses_gpu (marks the test as GPU-related; exact
# selection behavior is defined by tvm.testing).
@tvm.testing.uses_gpu
def test_fuse_softmax():
    """Compare a fused softmax result with a reference, rtol=atol=1e-4 (name-inferred purpose).

    NOTE(review): the body appears truncated. ``result`` and ``ref`` are not
    defined anywhere visible, so this raises NameError as written.
    TODO restore the body.
    """
    tvm.testing.assert_allclose(result, ref, rtol=1e-4, atol=1e-4)
/* Relay text form of a fused function. The outer fn takes %x (10x20 float32)
   and calls %2, an inner fn tagged Primitive=1 whose body is
   add(%p0, 1f) -> exp -> squeeze. This has the same structure as the graph
   built by expected() in this file.
   NOTE(review): this text is not Python. If it is meant as reference output
   for the FuseOps check, move it into a comment or docstring. */
fn (%x: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
%2 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
%0 = add(%p0, 1f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
%1 = exp(%0) /* ty=Tensor[(10, 20), float32] */;
squeeze(%1) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%2(%x) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */