Simple backward dual consumers
%cd ..
import set_env
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import create_workload
from tvm.relay.build_module import bind_params_by_name
def initializer(_, param):
    # Zero-initialize in place; rebinding the local name `param` would be a no-op.
    param[:] = np.zeros(param.shape)

def _get_positive_scale(size):
    # The scale must be positive so that relu commutes with the scaling.
    return np.random.uniform(0.5, 1, size=size).astype("float32")

def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body
The axis scale is consumed by two consumers: the scaled and rectified output y0 feeds two parallel conv2d branches, and each branch output is multiplied by the same out_scale, so the backward fold must push every scale into the corresponding convolution weight:
def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
    args = [x, conv_weight, out_bias]
    y0 = relay.nn.conv2d(
        x,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y0 = relay.multiply(y0, out_scale)
    y0 = relay.nn.relu(y0)
    y1 = relay.nn.conv2d(
        y0,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y1 = relay.multiply(y1, out_scale)
    y1 = relay.nn.relu(y1)
    y2 = relay.nn.conv2d(
        y0,
        conv_weight,
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y2 = relay.multiply(y2, out_scale)
    y2 = relay.nn.relu(y2)
    y = relay.add(y1, y2)
    return relay.Function(args, y)
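The rewrite the pass performs is valid because conv2d is linear in its weight along the output-channel axis, and relu commutes with multiplication by a positive scale, so relu(conv2d(x, W) * s) equals relu(conv2d(x, W * s)). Below is a minimal numpy sketch of that identity (illustrative only, not part of the test; conv2d_nchw and the *_np names are hypothetical):

def conv2d_nchw(x, w, pad=1):
    # Direct stride-1 NCHW convolution with symmetric padding (for illustration).
    n, ci, h, wd = x.shape
    co, _, kh, kw = w.shape
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    out = np.zeros((n, co, h, wd), dtype=x.dtype)
    for i in range(h):
        for j in range(wd):
            patch = xp[:, :, i:i + kh, j:j + kw]  # (n, ci, kh, kw)
            out[:, :, i, j] = np.einsum("nchw,ochw->no", patch, w)
    return out

x_np = np.random.uniform(size=(1, 4, 5, 5)).astype("float32")
w_np = np.random.uniform(size=(4, 4, 3, 3)).astype("float32")
s_np = _get_positive_scale((4, 1, 1))  # positive per-output-channel scale
lhs = np.maximum(conv2d_nchw(x_np, w_np) * s_np, 0)  # relu(conv2d(x, W) * s)
rhs = np.maximum(conv2d_nchw(x_np, w_np * s_np.reshape(4, 1, 1, 1)), 0)  # relu(conv2d(x, W * s))
np.testing.assert_allclose(lhs, rhs, rtol=1e-5)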
def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
    # use a fixed order of args so alpha equal check can pass
    args = [x, conv_weight, out_bias]

    def fold_conv_weight():
        squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
        if blocking:
            return relay.multiply(
                conv_weight,
                relay.reshape(
                    squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])
                ),
            )
        else:
            return relay.multiply(
                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
            )

    y0 = relay.nn.conv2d(
        x,
        fold_conv_weight(),
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y0 = relay.nn.relu(y0)
    y1 = relay.nn.conv2d(
        y0,
        fold_conv_weight(),
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y1 = relay.nn.relu(y1)
    y2 = relay.nn.conv2d(
        y0,
        fold_conv_weight(),
        channels=channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
        kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
    )
    y2 = relay.nn.relu(y2)
    y = relay.add(y1, y2)
    return relay.Function(args, y)
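In the blocked case the kernel layout OIHW1i2o splits the output channels of the weight into an outer axis 0 and an inner axis 5, which is why fold_conv_weight squeezes the scale from (channels // b, 1, 1, b) down to (channels // b, b) and reshapes it to (channels // b, 1, 1, 1, 1, b) before the multiply. A quick numpy shape sketch (hypothetical values, with channels=4 and b=2 as in the blocked check below):

b_scale = _get_positive_scale((2, 1, 1, 2))  # (channels // b, 1, 1, b)
b_weight = np.random.uniform(size=(2, 4, 3, 3, 1, 2)).astype("float32")  # OIHW1i2o
squeezed = b_scale.squeeze(axis=(1, 2))  # (2, 2): outer and inner output-channel axes
folded = b_weight * squeezed.reshape(2, 1, 1, 1, 1, 2)  # broadcasts onto axes 0 and 5
assert folded.shape == b_weight.shape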
def check(shape, in_channels, channels, blocking):
    x = relay.var("x", shape=shape)
    weight = relay.var("weight")
    if blocking:
        out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
        out_scale = relay.const(
            _get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))
        )
    else:
        out_bias = relay.var("out_bias", shape=(channels,))
        out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
    y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
    y1 = run_opt_pass(y1, transform.InferType())
    print("Before FoldScaleAxis:")
    tvm.IRModule.from_expr(y1).show()
    type_dict = {x.name_hint: x.checked_type for x in y1.params}
    weight = relay.var("weight", type_dict["weight"])
    y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
    print("After FoldScaleAxis:")
    tvm.IRModule.from_expr(y1_folded).show()
    y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
    y1_expected = run_opt_pass(y1_expected, transform.InferType())
    tvm.ir.assert_structural_equal(y1_folded, y1_expected)

check((2, 4, 10, 10), 4, 4, None)
check((2, 2, 10, 10, 2), 4, 4, (2, 2))
Before FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 10), float32] /* ty=Tensor[(2, 4, 10, 10), float32] */, %weight: Tensor[(4, 4, 3, 3), float32] /* ty=Tensor[(4, 4, 3, 3), float32] */, %out_bias: Tensor[(4), float32] /* ty=Tensor[(4), float32] */) -> Tensor[(2, 4, 10, 10), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %1 = multiply(%0, meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %3 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %4 = multiply(%3, meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %5 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %7 = nn.relu(%4) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %8 = nn.relu(%6) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  add(%7, %8) /* ty=Tensor[(2, 4, 10, 10), float32] */
}
After FoldScaleAxis:
def @main(%x: Tensor[(2, 4, 10, 10), float32] /* ty=Tensor[(2, 4, 10, 10), float32] */, %weight: Tensor[(4, 4, 3, 3), float32] /* ty=Tensor[(4, 4, 3, 3), float32] */, %out_bias: Tensor[(4), float32] /* ty=Tensor[(4), float32] */) -> Tensor[(2, 4, 10, 10), float32] {
  %0 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */, axis=[1, 2]) /* ty=Tensor[(4), float32] */;
  %1 = expand_dims(%0, axis=1, num_newaxis=3) /* ty=Tensor[(4, 1, 1, 1), float32] */;
  %2 = multiply(%weight, %1) /* ty=Tensor[(4, 4, 3, 3), float32] */;
  %3 = nn.conv2d(%x, %2, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %4 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */, axis=[1, 2]) /* ty=Tensor[(4), float32] */;
  %5 = expand_dims(%4, axis=1, num_newaxis=3) /* ty=Tensor[(4, 1, 1, 1), float32] */;
  %6 = nn.relu(%3) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %7 = multiply(%weight, %5) /* ty=Tensor[(4, 4, 3, 3), float32] */;
  %8 = nn.conv2d(%6, %7, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %9 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(4, 1, 1), float32] */, axis=[1, 2]) /* ty=Tensor[(4), float32] */;
  %10 = expand_dims(%9, axis=1, num_newaxis=3) /* ty=Tensor[(4, 1, 1, 1), float32] */;
  %11 = multiply(%weight, %10) /* ty=Tensor[(4, 4, 3, 3), float32] */;
  %12 = nn.conv2d(%6, %11, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3]) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %13 = nn.relu(%8) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  %14 = nn.relu(%12) /* ty=Tensor[(2, 4, 10, 10), float32] */;
  add(%13, %14) /* ty=Tensor[(2, 4, 10, 10), float32] */
}
Before FoldScaleAxis:
def @main(%x: Tensor[(2, 2, 10, 10, 2), float32] /* ty=Tensor[(2, 2, 10, 10, 2), float32] */, %weight: Tensor[(2, 4, 3, 3, 1, 2), float32] /* ty=Tensor[(2, 4, 3, 3, 1, 2), float32] */, %out_bias: Tensor[(2, 1, 1, 2), float32] /* ty=Tensor[(2, 1, 1, 2), float32] */) -> Tensor[(2, 2, 10, 10, 2), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %1 = multiply(%0, meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %3 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %4 = multiply(%3, meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %5 = nn.conv2d(%2, %weight, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %7 = nn.relu(%4) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %8 = nn.relu(%6) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  add(%7, %8) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */
}
After FoldScaleAxis:
def @main(%x: Tensor[(2, 2, 10, 10, 2), float32] /* ty=Tensor[(2, 2, 10, 10, 2), float32] */, %weight: Tensor[(2, 4, 3, 3, 1, 2), float32] /* ty=Tensor[(2, 4, 3, 3, 1, 2), float32] */, %out_bias: Tensor[(2, 1, 1, 2), float32] /* ty=Tensor[(2, 1, 1, 2), float32] */) -> Tensor[(2, 2, 10, 10, 2), float32] {
  %0 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */, axis=[1, 2]) /* ty=Tensor[(2, 2), float32] */;
  %1 = reshape(%0, newshape=[2, 1, 1, 1, 1, 2]) /* ty=Tensor[(2, 1, 1, 1, 1, 2), float32] */;
  %2 = multiply(%weight, %1) /* ty=Tensor[(2, 4, 3, 3, 1, 2), float32] */;
  %3 = nn.conv2d(%x, %2, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %4 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */, axis=[1, 2]) /* ty=Tensor[(2, 2), float32] */;
  %5 = reshape(%4, newshape=[2, 1, 1, 1, 1, 2]) /* ty=Tensor[(2, 1, 1, 1, 1, 2), float32] */;
  %6 = nn.relu(%3) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %7 = multiply(%weight, %5) /* ty=Tensor[(2, 4, 3, 3, 1, 2), float32] */;
  %8 = nn.conv2d(%6, %7, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %9 = squeeze(meta[relay.Constant][0] /* ty=Tensor[(2, 1, 1, 2), float32] */, axis=[1, 2]) /* ty=Tensor[(2, 2), float32] */;
  %10 = reshape(%9, newshape=[2, 1, 1, 1, 1, 2]) /* ty=Tensor[(2, 1, 1, 1, 1, 2), float32] */;
  %11 = multiply(%weight, %10) /* ty=Tensor[(2, 4, 3, 3, 1, 2), float32] */;
  %12 = nn.conv2d(%6, %11, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW2c", kernel_layout="OIHW1i2o") /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %13 = nn.relu(%8) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  %14 = nn.relu(%12) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */;
  add(%13, %14) /* ty=Tensor[(2, 2, 10, 10, 2), float32] */
}
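The structural check above compares graphs; one can additionally confirm that the fold preserves numerics by executing the unblocked graph before and after the pass. A hedged sketch (not part of the original test; the evaluate helper and the llvm target are assumptions):

def evaluate(func, *inputs):
    # Build a relay.Function with the graph executor and run it on numpy inputs.
    mod = tvm.IRModule.from_expr(func)
    return (
        relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
        .evaluate()(*inputs)
        .numpy()
    )

x = relay.var("x", shape=(2, 4, 10, 10))
weight = relay.var("weight", shape=(4, 4, 3, 3))
out_bias = relay.var("out_bias", shape=(4,))
out_scale = relay.const(_get_positive_scale((4, 1, 1)))
orig = run_opt_pass(before(x, weight, out_bias, out_scale, 4, 4, None), transform.InferType())
folded = run_opt_pass(orig, transform.BackwardFoldScaleAxis())
x_np = np.random.uniform(size=(2, 4, 10, 10)).astype("float32")
w_np = np.random.uniform(size=(4, 4, 3, 3)).astype("float32")
b_np = np.zeros((4,), dtype="float32")  # out_bias is an argument but unused in the body
np.testing.assert_allclose(evaluate(orig, x_np, w_np, b_np),
                           evaluate(folded, x_np, w_np, b_np), rtol=1e-5, atol=1e-5)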