BN 融合模式

BN 融合模式

import set_env
# import numpy as np

import tvm
from tvm import relay
# from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
# from tvm.relay.testing import run_opt_pass
class BatchnormCallback(DFPatternCallback):
    """Rewrite an explicitly spelled-out batch-norm expression into ``nn.batch_norm``.

    The pattern registered here is
    ``gamma * (x - mean) / sqrt(var + eps) + beta``, where ``eps`` must be a
    constant and every other operand is a wildcard.
    """

    def __init__(self):
        super().__init__()
        # One sub-pattern per operand; they are kept on the instance so that
        # ``callback`` can use them as keys into ``node_map``.
        self.x = wildcard()
        self.var = wildcard()
        self.mean = wildcard()
        self.beta = wildcard()
        self.gamma = wildcard()
        self.eps = is_constant()

        centered = self.x - self.mean
        std_dev = is_op("sqrt")(self.var + self.eps)
        self.pattern = self.gamma * centered / std_dev + self.beta

    def callback(self, pre, post, node_map):
        """Build the replacement expression for one match of ``self.pattern``.

        Parameters
        ----------
        pre, post
            Received from the rewriter; not used by this callback.
        node_map
            Mapping from each sub-pattern to the list of nodes it matched.

        Returns
        -------
        Element ``0`` of the ``relay.op.nn.batch_norm`` call built from the
        matched operands.
        """

        def first_match(sub_pattern):
            # Only the first node recorded for a sub-pattern is used.
            return node_map[sub_pattern][0]

        # batch_norm takes epsilon as a plain Python scalar, so unwrap the
        # matched constant's data.
        epsilon = first_match(self.eps).data.numpy().item()
        fused = relay.op.nn.batch_norm(
            first_match(self.x),
            first_match(self.gamma),
            first_match(self.beta),
            first_match(self.mean),
            first_match(self.var),
            epsilon=epsilon,
        )
        return fused[0]
def test_fuse_batchnorm():
    """Assert that the spelled-out batch-norm expression is rewritten to ``nn.batch_norm``.

    The input uses the same operator arrangement as the pattern in
    ``BatchnormCallback``; the rewritten result must be structurally equal to
    element ``0`` of ``relay.op.nn.batch_norm`` over the same variables.
    """
    x, var, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    std_dev = relay.op.sqrt(var + relay.const(1e-5))
    unfused = gamma * (x - mean) / std_dev + beta

    rewritten = rewrite(BatchnormCallback(), unfused)
    expected = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
    assert tvm.ir.structural_equal(rewritten, expected)


def test_no_fuse_batchnorm():
    """Assert that an expression ending in ``- beta`` is left unchanged by the rewrite.

    The expression differs from the ``BatchnormCallback`` pattern only in its
    final operator (subtract instead of add); the rewritten result must be
    structurally equal to the input.
    """
    x, var, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    scaled = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5))
    not_batchnorm = scaled - beta

    rewritten = rewrite(BatchnormCallback(), not_batchnorm)
    assert tvm.ir.structural_equal(rewritten, not_batchnorm)


def test_fuse_double_batchnorm():
    """Assert that two nested batch-norm expressions are both rewritten.

    The outer expression takes the inner one as its data operand; the expected
    result is ``nn.batch_norm`` applied twice, the second call consuming
    element ``0`` of the first.
    """
    x, var, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    inner = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
    outer = gamma * (inner - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta

    rewritten = rewrite(BatchnormCallback(), outer)

    # Build the expected nesting from the inside out: batch_norm(batch_norm(x)).
    expected = x
    for _ in range(2):
        expected = relay.op.nn.batch_norm(
            expected, gamma, beta, mean, var, epsilon=1e-5
        )[0]

    assert tvm.ir.structural_equal(rewritten, expected)


def test_partial_fuse_double_batchnorm():
    """Assert that only the outer of two nested expressions is rewritten.

    The inner expression ends in ``- beta`` and the outer one in ``+ beta``.
    The expected result is a single ``nn.batch_norm`` whose data operand is
    the inner expression, unchanged.
    """
    x, var, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    # Inner expression: final operator is a subtraction.
    inner = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
    # Outer expression: same arrangement as the callback's pattern.
    outer = gamma * (inner - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta

    rewritten = rewrite(BatchnormCallback(), outer)

    expected = relay.op.nn.batch_norm(inner, gamma, beta, mean, var, epsilon=1e-5)[0]

    assert tvm.ir.structural_equal(rewritten, expected)


def test_fuse_batchnorm_commutation():
    """Assert that reordered/regrouped batch-norm expressions are still rewritten.

    Three variants are checked in turn: the addition with its operands
    swapped, the multiply grouped before the divide, and the divide grouped
    before the multiply. Each must rewrite to the same ``nn.batch_norm`` form.
    """
    x, var, mean, beta, gamma = (
        relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
    )

    def assert_fuses(expr):
        # Rewrite one candidate and compare it with a freshly built
        # batch_norm over the same variables.
        rewritten = rewrite(BatchnormCallback(), expr)
        expected = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
        assert tvm.ir.structural_equal(rewritten, expected)

    # commute add
    assert_fuses(beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)))

    # associate divide/multiply
    assert_fuses((gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta)

    # associate multiply/divide
    assert_fuses(gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta)