Simplify Inference
SimplifyInference (see tvm/src/relay/transforms/simplify_inference.cc)
InferenceSimplifier()
    : batch_norm_op_(Op::Get("nn.batch_norm")),
      dropout_op_(Op::Get("nn.dropout")),
      instance_norm_op_(Op::Get("nn.instance_norm")),
      layer_norm_op_(Op::Get("nn.layer_norm")),
      group_norm_op_(Op::Get("nn.group_norm")),
      l2_norm_op_(Op::Get("nn.l2_normalize")) {}
from tvm.ir import IRModule, structural_equal
from tvm import relay as rly
from tvm.relay.transform import SimplifyInference, InferType
Define a simple batch norm (for reference, see batch-norm):
\[
\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}.
\]
Here \(\hat{\boldsymbol{\mu}}_\mathcal{B}\) and \(\hat{\boldsymbol{\sigma}}_\mathcal{B}\) are the sample mean and sample standard deviation of the minibatch \(\mathcal{B}\), respectively:
\[
\begin{aligned}
\hat{\boldsymbol{\mu}}_\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\
\hat{\boldsymbol{\sigma}}_\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.
\end{aligned}
\]
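After standardization, the resulting minibatch has zero mean and unit variance. Because unit variance is an arbitrary choice, a scale parameter \(\boldsymbol{\gamma}\) and a shift parameter \(\boldsymbol{\beta}\) are usually included; they have the same shape as \(\mathbf{x}\). Note that \(\boldsymbol{\gamma}\) and \(\boldsymbol{\beta}\) are parameters that must be learned together with the other model parameters.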
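The formulas above can be checked numerically. The following is a minimal plain-Python sketch, assuming a minibatch of scalar features; the helper name `batch_norm_train` and the sample values are illustrative and are not part of TVM.

```python
import math


def batch_norm_train(batch, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize a minibatch of scalars with its own statistics.

    Follows the formulas above: eps is added to the variance before
    taking the square root. `batch_norm_train` is a hypothetical helper.
    """
    n = len(batch)
    mu = sum(batch) / n                                 # sample mean
    var = sum((x - mu) ** 2 for x in batch) / n + eps   # sample variance plus eps
    sigma = math.sqrt(var)                              # sample standard deviation
    return [gamma * (x - mu) / sigma + beta for x in batch]


batch = [1.0, 2.0, 3.0, 4.0]
normalized = batch_norm_train(batch)

# With gamma = 1 and beta = 0, the output mean is zero and the output
# variance is close to (slightly below) one because of eps.
out_mean = sum(normalized) / len(normalized)
out_var = sum((y - out_mean) ** 2 for y in normalized) / len(normalized)
print(out_mean, out_var)
```

Passing other values for `gamma` and `beta` moves the output to an arbitrary standard deviation and mean, which is the role the learned scale and shift play.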
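dim = 4  # rank of the input tensor
nstep = 1  # number of loop iterations; options: 1, 3
axis = 1  # axis passed to batch_norm; options: 0, 1
dtype = "float16"
eps = 0.01  # epsilon passed to batch_norm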
ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
ttype2 = rly.TensorType((10,), dtype)
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
moving_var = rly.var("moving_var", ttype2)
moving_mean = rly.var("moving_mean", ttype2)
y1, y2 = x, x
for _ in range(nstep):
    y1, _, _ = rly.nn.batch_norm(
        y1 + rly.const(1, dtype),
        gamma,
        beta,
        moving_mean,
        moving_var,
        epsilon=eps,
        axis=axis,
    )
    y1 = rly.nn.dropout(y1)
mod = IRModule.from_expr(y1)
simplify = SimplifyInference()
mod = InferType()(mod)
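At inference time `nn.batch_norm` uses the fixed `moving_mean` and `moving_var` rather than minibatch statistics, so it reduces to a per-channel affine map, and `nn.dropout` acts as the identity. That is the kind of rewrite an inference-simplification pass can make. The sketch below checks the algebra in plain Python for a single channel. It doesn't call TVM, the function names and sample numbers are hypothetical, and it isn't meant to mirror the exact expression that `SimplifyInference` emits.

```python
import math


def batch_norm_infer(x, gamma, beta, moving_mean, moving_var, eps):
    """Reference form: normalize with stored statistics, then scale and shift."""
    return gamma * (x - moving_mean) / math.sqrt(moving_var + eps) + beta


def fold_batch_norm(gamma, beta, moving_mean, moving_var, eps):
    """Fold the stored statistics into a single (scale, shift) pair."""
    scale = gamma / math.sqrt(moving_var + eps)
    shift = beta - moving_mean * scale
    return scale, shift


# Illustrative per-channel values; eps matches the Relay example above.
gamma, beta, moving_mean, moving_var, eps = 1.5, 0.25, 0.4, 2.0, 0.01
scale, shift = fold_batch_norm(gamma, beta, moving_mean, moving_var, eps)

xs = [-1.0, 0.0, 0.5, 3.0]
reference = [batch_norm_infer(x, gamma, beta, moving_mean, moving_var, eps) for x in xs]
folded = [x * scale + shift for x in xs]  # one multiply and one add per element

# The two forms agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(reference, folded))
print(max_diff)
```

Because `scale` and `shift` depend only on constants, they can be computed once ahead of time, leaving a single multiply and add per element at run time.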