# AppendLoss
import tvm.testing
from tvm import TVMError
from tvm.ir.base import assert_structural_equal
from tvm.script import relax as R, ir as I
from tvm.relax.training import AppendLoss
## Simple test
@I.ir_module
class Before:
    # Backbone: element-wise addition of the two (3, 3) inputs.
    @R.function
    def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            gv0 = x + y
            R.output(gv0)
        return gv0

    # Loss: reduce the backbone output to a scalar with a full sum.
    @R.function
    def loss(arg1: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            gv0 = R.sum(arg1)
            R.output(gv0)
        return gv0
# Append ``loss`` after ``main``; the fused function is named "main_loss"
# and the original ``main`` is kept unchanged in the resulting module.
After = AppendLoss("main", loss)(Before)
After.show()
Output:
# Expected output of ``After.show()``:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main_loss(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            gv0_1: R.Tensor((), dtype="float32") = R.sum(gv0, axis=None, keepdims=False)
            R.output(gv0_1)
        return gv0_1

    @R.function
    def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
        with R.dataflow():
            gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            R.output(gv0)
        return gv0
## Multiple backbone outputs

Test the scenario where the backbone produces multiple outputs.
@I.ir_module
class Before:
    # Backbone with two outputs: each input reduced to a scalar sum.
    @R.function
    def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            gv0 = R.sum(x)
            gv1 = R.sum(y)
            R.output(gv0, gv1)
        return gv0, gv1

    # Loss: combine both backbone outputs into a single scalar.
    @R.function
    def loss(arg1: R.Tensor((), "float32"), arg2: R.Tensor((), "float32")):
        with R.dataflow():
            gv0 = R.add(arg1, arg2)
            R.output(gv0)
        return gv0
# The third argument (2) matches the number of backbone outputs that are
# fed into ``loss`` — here both of ``main``'s outputs feed ``loss``'s two
# parameters.
After = AppendLoss("main", loss, 2)(Before)
After.show()
Output:
# Expected output of ``After.show()``:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main_loss(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            gv0_1: R.Tensor((), dtype="float32") = R.add(gv0, gv1)
            R.output(gv0_1)
        return gv0_1

    @R.function
    def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
        with R.dataflow():
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            R.output(gv0, gv1)
        return (gv0, gv1)
## Extra parameters
# NOTE(review): the ``@I.ir_module`` decorator was missing here; the sibling
# examples above wrap ``Before`` as an IRModule, and ``AppendLoss(...)(Before)``
# below is applied to a module, so it is restored.
@I.ir_module
class Before:
    # The backbone returns three outputs:
    #   gv0: sum of the input tensor (scalar)
    #   gv1: the input added to itself (3x3 matrix)
    #   gv2: the original input tensor (3x3 matrix)
    @R.function
    def main(x: R.Tensor((3, 3), "float32")):
        with R.dataflow():
            gv0 = R.sum(x)
            gv1 = R.add(x, x)  # matrix self-addition
            gv2 = x  # keep the original input
            R.output(gv0, gv1, gv2)
        return gv0, gv1, gv2

    # The loss function takes three parameters:
    #   - arg1: first backbone output (scalar)
    #   - arg2: second backbone output (matrix)
    #   - arg3: extra parameter (matrix), not produced by the backbone
    # NOTE(review): the parameters were unannotated; the types below are
    # reconstructed from the expected output shown after this cell — confirm.
    @R.function
    def loss(
        arg1: R.Tensor((), "float32"),
        arg2: R.Tensor((3, 3), "float32"),
        arg3: R.Tensor((3, 3), "float32"),
    ):
        with R.dataflow():
            gv0 = R.add(arg2, arg3)  # matrix add using the extra parameter
            gv1 = R.sum(gv0)  # reduce to the loss value
            R.output(gv1)
        return gv1
# Only 2 backbone outputs feed the loss; ``loss``'s third parameter (arg3)
# becomes an additional parameter of the fused "main_loss" function.
After = AppendLoss("main", loss, 2)(Before)
After.show()
# Expected output of ``After.show()``:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main_loss(x: R.Tensor((3, 3), dtype="float32"), arg3: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
        with R.dataflow():
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x)
            gv2: R.Tensor((3, 3), dtype="float32") = x
            gv0_1: R.Tensor((3, 3), dtype="float32") = R.add(gv1, arg3)
            gv1_1: R.Tensor((), dtype="float32") = R.sum(gv0_1, axis=None, keepdims=False)
            R.output(gv2, gv1_1)
        return (gv1_1, gv2)

    @R.function
    def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
        with R.dataflow():
            gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x)
            gv2: R.Tensor((3, 3), dtype="float32") = x
            R.output(gv0, gv1, gv2)
        return (gv0, gv1, gv2)