# AppendLoss

import tvm.testing
from tvm import TVMError
from tvm.ir.base import assert_structural_equal
from tvm.script import relax as R, ir as I
from tvm.relax.training import AppendLoss

## 简单测试

# Backbone module for the simple example: one Relax function `main` that adds
# two (3, 3) float32 tensors inside a dataflow block and returns the result.
@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            # `+` on the two tensors; the printed module below shows it as R.add(x, y).
            gv0 = x + y
            # Mark gv0 as the output of this dataflow block.
            R.output(gv0)
        return gv0

# Loss function to be appended to the backbone: takes one (3, 3) float32 tensor
# and returns R.sum of it.
@R.function
def loss(arg1: R.Tensor((3, 3), "float32")):
    with R.dataflow():
        # Printed below as R.sum(..., axis=None, keepdims=False) with a scalar result type.
        gv0 = R.sum(arg1)
        R.output(gv0)
    return gv0
# Build the AppendLoss pass for the "main" function, apply it to Before,
# and print the transformed module.
After = AppendLoss(func_name="main", loss_function=loss)(Before)
After.show()
`After.show()` 的输出:
# from tvm.script import ir as I
# from tvm.script import relax as R

# Module printed by After.show() for the simple example: it contains the
# original `main` plus a new function `main_loss`.
@I.ir_module
class Module:
    # New function: takes the same parameters as `main` but returns a scalar.
    @R.function
    def main_loss(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            # Same binding as in `main`.
            gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            # The R.sum from `loss`, applied to the backbone result gv0. It is
            # bound as gv0_1 here, whereas `loss` itself names this value gv0.
            gv0_1: R.Tensor((), dtype="float32") = R.sum(gv0, axis=None, keepdims=False)
            R.output(gv0_1)
        return gv0_1

    # Same computation as Before.main, shown with explicit annotations.
    @R.function
    def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
        with R.dataflow():
            gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            R.output(gv0)
        return gv0

## 测试多主干

测试带多个主干(backbone)输出的场景

# Backbone module with two outputs: `main` reduces each (3, 3) float32 input
# with R.sum and returns both results as a tuple.
@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            gv0 = R.sum(x)
            gv1 = R.sum(y)
            # Both sums are outputs of the dataflow block.
            R.output(gv0, gv1)
        return gv0, gv1

# Loss function over two scalar float32 tensors: returns their sum.
@R.function
def loss(arg1: R.Tensor((), "float32"), arg2: R.Tensor((), "float32")):
    with R.dataflow():
        gv0 = R.add(arg1, arg2)
        R.output(gv0)
    return gv0

# Append `loss` to "main"; both backbone outputs are passed to the loss,
# so num_backbone_outputs is 2. Then print the transformed module.
After = AppendLoss(func_name="main", loss_function=loss, num_backbone_outputs=2)(Before)
After.show()
`After.show()` 的输出:
# from tvm.script import ir as I
# from tvm.script import relax as R

# Module printed by After.show() for the two-output example: it contains the
# original `main` plus a new function `main_loss`.
@I.ir_module
class Module:
    # New function: takes the same parameters as `main` but returns one scalar.
    @R.function
    def main_loss(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            # Same two bindings as in `main`.
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            # The R.add from `loss`, with gv0 and gv1 in place of arg1 and arg2.
            # It is bound as gv0_1 here, whereas `loss` itself names this value gv0.
            gv0_1: R.Tensor((), dtype="float32") = R.add(gv0, gv1)
            R.output(gv0_1)
        return gv0_1

    # Same computation as Before.main, shown with explicit annotations.
    @R.function
    def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
        with R.dataflow():
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            R.output(gv0, gv1)
        return (gv0, gv1)

## 测试附加参数

# Backbone module for the extra-parameter example. `main` returns three values:
#   gv0: sum of the input tensor (scalar)
#   gv1: the input added to itself (3x3 matrix)
#   gv2: the original input tensor (3x3 matrix)
# The @I.ir_module decorator turns the class into an IRModule, which is what the
# AppendLoss pass is applied to (the other examples in this file use it as well).
@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            gv0 = R.sum(x)
            gv1 = R.add(x, x)  # matrix added to itself
            gv2 = x            # keep the original input
            R.output(gv0, gv1, gv2)
        return gv0, gv1, gv2

# The loss function takes three parameters:
# - arg1: first backbone output (scalar); not used in the body
# - arg2: second backbone output (3x3 matrix)
# - arg3: extra parameter (3x3 matrix)
# Each parameter carries an explicit annotation because the Relax TVMScript
# parser requires one on every function parameter.
@R.function
def loss(
    arg1: R.Tensor((), "float32"),
    arg2: R.Tensor((3, 3), "float32"),
    arg3: R.Tensor((3, 3), "float32"),
):
    with R.dataflow():
        gv0 = R.add(arg2, arg3)  # matrix addition using the extra parameter
        gv1 = R.sum(gv0)         # reduce to the scalar loss value
        R.output(gv1)
    return gv1
# Append `loss` to "main" with the first two backbone outputs passed to the
# loss, then print the transformed module.
After = AppendLoss(func_name="main", loss_function=loss, num_backbone_outputs=2)(Before)
After.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

# Module printed by After.show() for the extra-parameter example: it contains
# the original `main` plus a new function `main_loss`.
@I.ir_module
class Module:
    # New function: takes the backbone parameter x followed by arg3, and returns
    # a tuple of a scalar and a (3, 3) tensor.
    @R.function
    def main_loss(x: R.Tensor((3, 3), dtype="float32"), arg3: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
        with R.dataflow():
            # Same three bindings as in `main`. gv0 is not used by the
            # bindings that follow.
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x)
            gv2: R.Tensor((3, 3), dtype="float32") = x
            # The two bindings from `loss`, with gv1 in place of arg2; they are
            # bound as gv0_1 / gv1_1 here, whereas `loss` names them gv0 / gv1.
            gv0_1: R.Tensor((3, 3), dtype="float32") = R.add(gv1, arg3)
            gv1_1: R.Tensor((), dtype="float32") = R.sum(gv0_1, axis=None, keepdims=False)
            R.output(gv2, gv1_1)
        # The scalar gv1_1 comes first, followed by gv2.
        return (gv1_1, gv2)

    # Same computation as Before.main, shown with explicit annotations.
    @R.function
    def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
        with R.dataflow():
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x)
            gv2: R.Tensor((3, 3), dtype="float32") = x
            R.output(gv0, gv1, gv2)
        return (gv0, gv1, gv2)