Relax RemoveRedundantReshape: 消除冗余的 reshape

Relax RemoveRedundantReshape: 消除冗余的 reshape

参考:python/tvm/relax/transform/remove_redundant_reshape.py

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm.testing
from tvm import relax
from tvm.relax.transform import DeadCodeElimination
from tvm.relax.transform import RemoveRedundantReshape
from tvm.script import ir as I, relax as R

def _run_pass_compare_output(Before, Expected):
    """Transform ``Before`` and assert it is structurally equal to ``Expected``.

    ``RemoveRedundantReshape`` is applied first and ``DeadCodeElimination``
    second; each pass object is constructed immediately before it is applied.
    The final module is then checked against ``Expected`` with
    ``tvm.ir.assert_structural_equal``.

    Parameters
    ----------
    Before
        The module handed to the first pass.
    Expected
        The module the transformed result is compared against.
    """
    transformed = Before
    for make_pass in (RemoveRedundantReshape, DeadCodeElimination):
        transformed = make_pass()(transformed)
    tvm.ir.assert_structural_equal(Expected, transformed)
# Case 1 input: three chained reshapes that all request the target shape [1, 1001].
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            # Reshape the rank-4 parameter to (1, 1001).
            lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
            # The next two reshapes request the same shape their operand is annotated with.
            lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, R.shape([1, 1001]))
            gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv1, R.shape([1, 1001]))
            R.output(gv)
        return gv

# Case 1 reference: a single reshape straight from the parameter to (1, 1001).
@I.ir_module
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            # Only one reshape remains, and it reads directly from x.
            gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
            R.output(gv)
        return gv

# Check the chained-reshape case: transform Before and compare it with Expected.
_run_pass_compare_output(Before, Expected)
# Case 2 input: the parameter is reshaped to an intermediate rank-3 shape and
# then reshaped again to the final (1, 1001).
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            # Intermediate value of shape (1, 1001, 1); its only use is the next reshape.
            lv: R.Tensor((1, 1001, 1), dtype="float16") = R.reshape(x, R.shape([1, 1001, 1]))
            lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, R.shape([1, 1001]))
            R.output(lv1)
        return lv1

# Case 2 reference: the intermediate (1, 1001, 1) reshape is gone and the
# remaining reshape takes the parameter x as its operand.
@I.ir_module
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001), dtype="float16"):
        with R.dataflow():
            # Single reshape from the rank-4 parameter to (1, 1001).
            lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
            R.output(lv1)
        return lv1

# Check the intermediate-shape case: transform Before and compare it with Expected.
_run_pass_compare_output(Before, Expected)
# Case 3 input: a single reshape whose target shape equals the shape the
# parameter is already annotated with.
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
        with R.dataflow():
            # Target shape [1, 1001, 1, 1] matches the annotation on x.
            lv: R.Tensor((1, 1001, 1, 1), dtype="float16") = R.reshape(
                x, R.shape([1, 1001, 1, 1])
            )
            R.output(lv)
        return lv

# Case 3 reference: no reshape and no dataflow block; the function returns
# its parameter unchanged.
@I.ir_module
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 1001, 1, 1), dtype="float16")
    ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
        return x

# Check the same-shape case: transform Before and compare it with Expected.
_run_pass_compare_output(Before, Expected)