def test_fuse_batchnorm():
    """Rewrite the hand-written batch-norm arithmetic into ``nn.batch_norm``.

    Builds ``gamma * (x - mean) / sqrt(var + 1e-5) + beta`` from free Relay
    variables, runs it through ``rewrite`` with ``BatchnormCallback``, and
    checks the result is structurally equal to output 0 of
    ``relay.op.nn.batch_norm`` called on the same variables with
    ``epsilon=1e-5``.
    """
    data, variance, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    centered = data - mean
    denominator = relay.op.sqrt(variance + relay.const(1e-5))
    unfused = gamma * centered / denominator + beta

    rewritten = rewrite(BatchnormCallback(), unfused)

    # batch_norm returns a tuple-like value; only element 0 is compared.
    expected = relay.op.nn.batch_norm(data, gamma, beta, mean, variance, epsilon=1e-5)[0]
    assert tvm.ir.structural_equal(rewritten, expected)
def test_no_fuse_batchnorm():
    """Leave an expression alone when it only resembles batch-norm.

    The expression matches the arithmetic used in the fusing tests except
    that ``beta`` is subtracted instead of added. After ``rewrite`` with
    ``BatchnormCallback`` the output must be structurally equal to the
    input expression.
    """
    data, variance, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    denominator = relay.op.sqrt(variance + relay.const(1e-5))
    # Trailing "- beta" (not "+ beta") is the deliberate mismatch under test.
    lookalike = gamma * (data - mean) / denominator - beta

    rewritten = rewrite(BatchnormCallback(), lookalike)

    assert tvm.ir.structural_equal(rewritten, lookalike)
def test_fuse_double_batchnorm():
    """Rewrite two stacked batch-norm arithmetic patterns in a single pass.

    The outer pattern takes the inner pattern as its input. After
    ``rewrite`` with ``BatchnormCallback`` the result must be structurally
    equal to two nested ``relay.op.nn.batch_norm`` calls (output 0 of each).
    """
    data, variance, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    def arithmetic_form(inp):
        # Fresh nodes are created on every call, as in a hand-written expression.
        return gamma * (inp - mean) / relay.op.sqrt(variance + relay.const(1e-5)) + beta

    def fused_form(inp):
        return relay.op.nn.batch_norm(inp, gamma, beta, mean, variance, epsilon=1e-5)[0]

    stacked = arithmetic_form(arithmetic_form(data))

    rewritten = rewrite(BatchnormCallback(), stacked)

    assert tvm.ir.structural_equal(rewritten, fused_form(fused_form(data)))
def test_partial_fuse_double_batchnorm():
    """Rewrite only the outer layer when the inner one is not batch-norm.

    The inner expression subtracts ``beta`` and the outer one adds it. The
    expected result is ``relay.op.nn.batch_norm`` (output 0) applied to the
    original inner expression, i.e. the inner expression is expected to
    come through the rewrite unchanged.
    """
    data, variance, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    def scaled(inp):
        # Shared prefix of both layers: everything except the final +/- beta.
        return gamma * (inp - mean) / relay.op.sqrt(variance + relay.const(1e-5))

    inner_lookalike = scaled(data) - beta
    outer_pattern = scaled(inner_lookalike) + beta

    rewritten = rewrite(BatchnormCallback(), outer_pattern)

    expected = relay.op.nn.batch_norm(
        inner_lookalike, gamma, beta, mean, variance, epsilon=1e-5
    )[0]
    assert tvm.ir.structural_equal(rewritten, expected)
def test_fuse_batchnorm_commutation():
    """Rewrite batch-norm arithmetic written with different operand groupings.

    Three spellings are tried: ``beta`` as the left operand of the add, the
    multiply grouped before the divide, and the divide grouped before the
    multiply. Each must rewrite to something structurally equal to output 0
    of ``relay.op.nn.batch_norm`` on the same variables with
    ``epsilon=1e-5``.
    """
    data, variance, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    def centered():
        return data - mean

    def denominator():
        return relay.op.sqrt(variance + relay.const(1e-5))

    def check_fuses(pattern):
        rewritten = rewrite(BatchnormCallback(), pattern)
        expected = relay.op.nn.batch_norm(data, gamma, beta, mean, variance, epsilon=1e-5)[0]
        assert tvm.ir.structural_equal(rewritten, expected)

    # commute add
    check_fuses(beta + gamma * centered() / denominator())

    # associate divide/multiply
    check_fuses((gamma * centered()) / denominator() + beta)

    # associate multiply/divide
    check_fuses(gamma * (centered() / denominator()) + beta)