Relax RemoveRedundantReshape: 消除冗余的 reshape
参考源码：python/tvm/relax/transform/remove_redundant_reshape.py
%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm.testing
from tvm import relax
from tvm.relax.transform import DeadCodeElimination
from tvm.relax.transform import RemoveRedundantReshape
from tvm.script import ir as I, relax as R
def _run_pass_compare_output(Before, Expected):
    """Transform ``Before`` and check it structurally matches ``Expected``.

    The module is passed through ``RemoveRedundantReshape`` and then
    ``DeadCodeElimination``, in that order. The result is compared against
    ``Expected`` with ``tvm.ir.assert_structural_equal``, which fails if the
    two IRModules differ.

    Args:
        Before: the IRModule to transform.
        Expected: the IRModule the transformed result must equal.
    """
    mod = Before
    # Apply the passes one after another, feeding each output into the next.
    for transform in (RemoveRedundantReshape(), DeadCodeElimination()):
        mod = transform(mod)
    tvm.ir.assert_structural_equal(Expected, mod)
# Case 1 input: a chain of three reshapes that all target (1, 1001).
# Only the first one changes the shape of `x` (1, 1001, 1, 1) -> (1, 1001);
# the other two reshape an already-(1, 1001) tensor to (1, 1001) again.
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
            # `lv` is already (1, 1001): this reshape keeps the same shape.
            lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, R.shape([1, 1001]))
            # Same-shape reshape again, this time on `lv1`.
            gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv1, R.shape([1, 1001]))
            R.output(gv)
        return gv
# Case 1 expected result after RemoveRedundantReshape + DeadCodeElimination:
# the three-reshape chain collapses into a single reshape applied directly to `x`.
@I.ir_module
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
            R.output(gv)
        return gv
# Check case 1: a chain of same-target reshapes.
_run_pass_compare_output(Before=Before, Expected=Expected)
# Case 2 input: two reshapes with different targets.
# `x` (1, 1001, 1, 1) is first reshaped to (1, 1001, 1), and that
# intermediate `lv` is used only as the input of the second reshape to (1, 1001).
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            lv: R.Tensor((1, 1001, 1), dtype="float16") = R.reshape(x, R.shape([1, 1001, 1]))
            lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, R.shape([1, 1001]))
            R.output(lv1)
        return lv1
# Case 2 expected result after RemoveRedundantReshape + DeadCodeElimination:
# the intermediate (1, 1001, 1) reshape is gone and `x` is reshaped
# straight to the final (1, 1001) shape.
@I.ir_module
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
            R.output(lv1)
        return lv1
# Check case 2: back-to-back reshapes with different target shapes.
_run_pass_compare_output(Before=Before, Expected=Expected)
# Case 3 input: a single reshape of `x` to its own shape (1, 1001, 1, 1).
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
        with R.dataflow():
            # Target shape equals the input shape of `x`.
            lv: R.Tensor((1, 1001, 1, 1), dtype="float16") = R.reshape(
                x, R.shape([1, 1001, 1, 1])
            )
            R.output(lv)
        return lv
# Case 3 expected result after RemoveRedundantReshape + DeadCodeElimination:
# the same-shape reshape is removed entirely, so the function has no
# dataflow block and simply returns its input `x`.
@I.ir_module
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
        return x
# Check case 3: a reshape to the tensor's own shape.
_run_pass_compare_output(Before=Before, Expected=Expected)