简化推理
SimplifyInference
(见 tvm/src/relay/transforms/simplify_inference.cc)
// Constructor of InferenceSimplifier (excerpt). Each member is initialized with
// the operator handle returned by Op::Get for the operator name given as a
// string: nn.batch_norm, nn.dropout, nn.instance_norm, nn.layer_norm,
// nn.group_norm and nn.l2_normalize. The constructor body itself is empty.
// NOTE(review): presumably these handles are what the pass compares call
// operators against when rewriting -- confirm against the full source file.
InferenceSimplifier()
: batch_norm_op_(Op::Get("nn.batch_norm")),
dropout_op_(Op::Get("nn.dropout")),
instance_norm_op_(Op::Get("nn.instance_norm")),
layer_norm_op_(Op::Get("nn.layer_norm")),
group_norm_op_(Op::Get("nn.group_norm")),
l2_norm_op_(Op::Get("nn.l2_normalize")) {}
from tvm.ir import IRModule, structural_equal
from tvm import relay as rly
from tvm.relay.transform import SimplifyInference, InferType
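定义简单的 batch-norm(可以参考:batch-norm):
$$
\operatorname{BN}(x) = \gamma \odot \frac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}} + \beta
$$
其中 $\hat{\mu}$ 和 $\hat{\sigma}^2$ 分别是均值和方差(推理时对应下文代码中的 `moving_mean` 和 `moving_var`),$\gamma$ 和 $\beta$ 分别是缩放参数和平移参数,$\epsilon$ 是保证数值稳定的小常数(对应 `eps`)。
应用标准化后,生成的小批量的平均值为 0,方差为 1。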
# Build a small Relay expression made of `nstep` batch_norm + dropout stages,
# wrap it in an IRModule and run type inference on it. The SimplifyInference
# pass object is constructed here but not yet applied in this section.
dim = 4  # rank of the input tensor; every axis has extent 10
nstep = 1  # number of stacked stages (values noted by the author: 1, 3)
axis = 1  # axis argument passed to batch_norm (values noted by the author: 0, 1)
dtype = "float16"
eps = 0.01  # epsilon argument passed to batch_norm

ttype1 = rly.TensorType((10,) * dim, dtype)  # type of the input `x`
ttype2 = rly.TensorType((10,), dtype)  # 1-D type shared by the four parameters below
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
moving_var = rly.var("moving_var", ttype2)
moving_mean = rly.var("moving_mean", ttype2)

# y1 accumulates the graph built in the loop; y2 starts from the same input
# and is left untouched in this section.
y1, y2 = x, x
for _step in range(nstep):
    # batch_norm is unpacked into three values; only the first one is kept.
    y1, _, _ = rly.nn.batch_norm(
        y1 + rly.const(1, dtype),
        gamma,
        beta,
        moving_mean,
        moving_var,
        epsilon=eps,
        axis=axis,
    )
    # NOTE(review): indentation was lost in the original listing; dropout is
    # assumed to belong to the loop body (one dropout per stage) -- confirm.
    y1 = rly.nn.dropout(y1)

mod = IRModule.from_expr(y1)
simplify = SimplifyInference()
mod = InferType()(mod)