Simplify Inference

  1. SimplifyInference (see tvm/src/relay/transforms/simplify_inference.cc)

InferenceSimplifier()
      : batch_norm_op_(Op::Get("nn.batch_norm")),
        dropout_op_(Op::Get("nn.dropout")),
        instance_norm_op_(Op::Get("nn.instance_norm")),
        layer_norm_op_(Op::Get("nn.layer_norm")),
        group_norm_op_(Op::Get("nn.group_norm")),
        l2_norm_op_(Op::Get("nn.l2_normalize")) {}
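The constructor above lists the operators the pass looks for. For `nn.batch_norm`, the inference-time idea can be sketched in NumPy: with fixed moving statistics, normalization folds into one per-channel multiply and one add. This is a sketch under stated assumptions, not the pass's actual code: the helper names `batch_norm_infer` and `batch_norm_folded` are hypothetical, and TVM is not called.

```python
import numpy as np

def _per_channel(v, ndim, axis):
    """Reshape a 1-D per-channel vector so it broadcasts along `axis`."""
    shape = [1] * ndim
    shape[axis] = -1
    return v.reshape(shape)

def batch_norm_infer(x, gamma, beta, moving_mean, moving_var, eps, axis=1):
    """Reference inference-time batch norm using stored moving statistics."""
    g, b, m, v = (_per_channel(p, x.ndim, axis)
                  for p in (gamma, beta, moving_mean, moving_var))
    return g * (x - m) / np.sqrt(v + eps) + b

def batch_norm_folded(x, gamma, beta, moving_mean, moving_var, eps, axis=1):
    """Same computation, folded into a single scale and a single shift."""
    scale = gamma / np.sqrt(moving_var + eps)   # per-channel multiplier
    shift = beta - moving_mean * scale          # per-channel offset
    return x * _per_channel(scale, x.ndim, axis) + _per_channel(shift, x.ndim, axis)

# Illustrative data only: 3 channels on axis 1.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 4))
gamma, beta, mean = (rng.normal(size=3) for _ in range(3))
var = rng.uniform(0.5, 2.0, size=3)

ref = batch_norm_infer(x, gamma, beta, mean, var, eps=0.01)
folded = batch_norm_folded(x, gamma, beta, mean, var, eps=0.01)
```

The two results agree up to floating-point rounding. Dropout has no effect at inference time, so it can be treated as the identity.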
from tvm.ir import IRModule, structural_equal
from tvm import relay as rly
from tvm.relay.transform import SimplifyInference, InferType

Define a simple batch norm (for background, see: batch-norm):

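$$\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}}}{\hat{\boldsymbol{\sigma}}_{\mathcal{B}}} + \boldsymbol{\beta}.$$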

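Here $\hat{\boldsymbol{\mu}}_{\mathcal{B}}$ and $\hat{\boldsymbol{\sigma}}_{\mathcal{B}}$ are the sample mean and sample standard deviation of the minibatch $\mathcal{B}$, respectively: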

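$$\hat{\boldsymbol{\mu}}_{\mathcal{B}} = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x}, \qquad \hat{\boldsymbol{\sigma}}_{\mathcal{B}}^2 = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.$$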

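After standardization, the resulting minibatch has zero mean and unit variance. Because unit variance is an arbitrary choice, it is common to include a scale parameter $\boldsymbol{\gamma}$ and a shift parameter $\boldsymbol{\beta}$, which have the same shape as $\mathbf{x}$. Note that $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are parameters learned jointly with the other model parameters.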
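The training-time formula above can be checked numerically. The following is a minimal NumPy sketch, assuming a 2-D minibatch of shape `(batch, features)` with statistics taken over the batch axis. The function name `batch_norm_train` and the sample data are illustrative and not part of TVM.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Normalize a minibatch of shape (batch, features) with its own statistics."""
    mu = x.mean(axis=0)                                   # sample mean per feature
    sigma = np.sqrt(((x - mu) ** 2).mean(axis=0) + eps)   # sample std, eps added to the variance
    return gamma * (x - mu) / sigma + beta                # scale and shift

# Illustrative data: 64 samples, 4 features, deliberately not centered.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 4))
out = batch_norm_train(x, gamma=np.ones(4), beta=np.zeros(4))
```

With `gamma` set to ones and `beta` to zeros, each feature of `out` has mean close to 0 and variance close to 1 (slightly below 1 because of `eps`).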

dim = 4
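nstep = 1  # number of batch_norm + dropout stages to stack (e.g. 1 or 3)
axis = 1  # axis that batch_norm treats as the channel axis (e.g. 0 or 1)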
dtype = "float16"
eps = 0.01
ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
ttype2 = rly.TensorType((10,), dtype)
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
moving_var = rly.var("moving_var", ttype2)
moving_mean = rly.var("moving_mean", ttype2)
y1, y2 = x, x
for _ in range(nstep):
    y1, _, _ = rly.nn.batch_norm(
        y1 + rly.const(1, dtype),
        gamma,
        beta,
        moving_mean,
        moving_var,
        epsilon=eps,
        axis=axis,
    )
    y1 = rly.nn.dropout(y1)
mod = IRModule.from_expr(y1)
simplify = SimplifyInference()
mod = InferType()(mod)