Simplify Inference


  1. SimplifyInference (see tvm/src/relay/transforms/simplify_inference.cc)

InferenceSimplifier()
      : batch_norm_op_(Op::Get("nn.batch_norm")),
        dropout_op_(Op::Get("nn.dropout")),
        instance_norm_op_(Op::Get("nn.instance_norm")),
        layer_norm_op_(Op::Get("nn.layer_norm")),
        group_norm_op_(Op::Get("nn.group_norm")),
        l2_norm_op_(Op::Get("nn.l2_normalize")) {}

from tvm.ir import IRModule, structural_equal
from tvm import relay as rly
from tvm.relay.transform import SimplifyInference, InferType

Define a simple batch-norm (for reference, see batch-norm):

\[ \mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}. \]

Here \(\hat{\boldsymbol{\mu}}_\mathcal{B}\) and \(\hat{\boldsymbol{\sigma}}_\mathcal{B}\) are the sample mean and sample standard deviation of the minibatch \(\mathcal{B}\), respectively.

\[\begin{aligned} \hat{\boldsymbol{\mu}}_\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\ \hat{\boldsymbol{\sigma}}_\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned}\]

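After standardization, the resulting minibatch has zero mean and unit variance (up to the small constant \(\epsilon\) added to the variance). Because unit variance is an arbitrary choice, a scale parameter \(\boldsymbol{\gamma}\) and a shift parameter \(\boldsymbol{\beta}\) are usually included, both with the same shape as \(\mathbf{x}\). Note that \(\boldsymbol{\gamma}\) and \(\boldsymbol{\beta}\) are parameters that must be learned together with the other model parameters.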
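The formulas above can be checked with a minimal sketch in plain Python, written for a single scalar feature over a small minibatch. The helper names (`batch_stats`, `batch_norm`, `fold_to_affine`) and the numeric values are illustrative assumptions, not part of TVM. The last helper shows that, once the mean and standard deviation are fixed, batch-norm reduces to an affine map `x * scale + shift`; this is presumably the kind of rewrite that makes the operator cheap at inference time.

```python
import math


def batch_stats(batch, eps):
    """Return the minibatch mean and standard deviation (eps is added to the variance)."""
    n = len(batch)
    mean = sum(batch) / n
    var = sum((v - mean) ** 2 for v in batch) / n + eps
    return mean, math.sqrt(var)


def batch_norm(batch, gamma, beta, mean, std):
    """Apply BN(x) = gamma * (x - mean) / std + beta to every sample."""
    return [gamma * (v - mean) / std + beta for v in batch]


def fold_to_affine(gamma, beta, mean, std):
    """With fixed statistics, BN(x) equals x * scale + shift."""
    scale = gamma / std
    shift = beta - mean * scale
    return scale, shift


# Illustrative values (assumptions): four samples of one scalar feature.
batch = [1.0, 2.0, 3.0, 4.0]
gamma, beta, eps = 2.0, 0.5, 0.01

mean, std = batch_stats(batch, eps)

# Standardization only (gamma = 1, beta = 0): the result has zero mean.
normalized = batch_norm(batch, 1.0, 0.0, mean, std)

# Full batch-norm, and the equivalent affine form with precomputed constants.
out = batch_norm(batch, gamma, beta, mean, std)
scale, shift = fold_to_affine(gamma, beta, mean, std)
folded = [v * scale + shift for v in batch]
```

Here `out` and `folded` agree up to floating-point rounding, so the division and subtraction can be replaced by one multiply and one add whenever the statistics are constants.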

dim = 4
nstep = 1 # 1, 3
axis = 1 # 0, 1
dtype = "float16"
eps = 0.01
ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
ttype2 = rly.TensorType((10,), dtype)
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
moving_var = rly.var("moving_var", ttype2)
moving_mean = rly.var("moving_mean", ttype2)
y1, y2 = x, x
for _ in range(nstep):
    y1, _, _ = rly.nn.batch_norm(
        y1 + rly.const(1, dtype),
        gamma,
        beta,
        moving_mean,
        moving_var,
        epsilon=eps,
        axis=axis,
    )
    y1 = rly.nn.dropout(y1)
mod = IRModule.from_expr(y1)
simplify = SimplifyInference()
mod = InferType()(mod)