SetupTrainer

SetupTrainer transforms a backbone IRModule (a forward function plus the param_num and state_num module attributes) into a complete training module: it appends the loss computation to the forward pass, differentiates the result to obtain a gradient function, and generates an optimizer function that updates the parameters from those gradients.
import tvm
from tvm import relax
from tvm.relax.training import SetupTrainer
from tvm.relax.training.optimizer import SGD, MomentumSGD
from tvm.relax.training.loss import MSELoss
from tvm.script import ir as I, relax as R
Test a simple model

The backbone below takes one input x and one parameter y; SetupTrainer pairs it with a summed MSELoss and plain SGD. The resulting module keeps the original backbone and adds three functions: backbone_loss, backbone_loss_adjoint, and optimizer.
@I.ir_module
class Backbone:
    I.module_attrs({"param_num": 1, "state_num": 0})

    @R.function
    def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64")):
        with R.dataflow():
            x1 = x + y
            R.output(x1)
        return x1

sinfo = relax.TensorStructInfo((2, 2), "float64")
setup_trainer = SetupTrainer(MSELoss(reduction="sum"), SGD(0.1), [sinfo, sinfo], legalize=False)
train_mod = setup_trainer(Backbone)
train_mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"input_num": 1, "optim_state": [metadata["runtime.NDArray"][0]], "param_num": 1, "state_num": 0})

    @R.function
    def backbone(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((2, 2), dtype="float64"):
        with R.dataflow():
            x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
            R.output(x1)
        return x1

    @R.function
    def backbone_loss(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((), dtype="float64"):
        with R.dataflow():
            x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
            lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
            lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
            gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
            R.output(gv)
        return gv

    @R.function
    def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((), dtype="float64"), R.Tuple(R.Tensor((2, 2), dtype="float64"))):
        with R.dataflow():
            x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
            lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
            lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
            gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
            gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64")
            lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2]))
            lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
            lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
            lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1)
            x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1
            y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint
            y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint
            R.output(gv, y_adjoint_out)
        return (gv, (y_adjoint_out,))

    @R.function
    def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"))):
        with R.dataflow():
            num_steps: R.Tensor((), dtype="int64") = optim_states[0]
            num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
            y: R.Tensor((2, 2), dtype="float64") = params[0]
            y_grad: R.Tensor((2, 2), dtype="float64") = gradients[0]
            lv: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_grad)
            y_new: R.Tensor((2, 2), dtype="float64") = R.subtract(y, lv)
            params_new: R.Tuple(R.Tensor((2, 2), dtype="float64")) = (y_new,)
            optim_states_new: R.Tuple(R.Tensor((), dtype="int64")) = (num_steps_new,)
            R.output(params_new, optim_states_new)
        return (params_new, optim_states_new)

# Metadata omitted. Use show_meta=True in script() method to show it.
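The generated functions can be compiled and driven directly with the Relax VM. Below is a minimal sketch, not part of the original: since the trainer was built with legalize=False, it first applies LegalizeOps before building; the random inputs, the variable names, and the CPU target are our own choices, and the zero-initialized step counter simply matches the optim_states signature printed above.

import numpy as np
import tvm
from tvm import relax
from tvm.runtime.container import tuple_object

# Lower the high-level Relax ops first (we passed legalize=False above).
mod = relax.transform.LegalizeOps()(train_mod)
vm = relax.VirtualMachine(relax.build(mod, "llvm"), tvm.cpu())

x = tvm.nd.array(np.random.rand(2, 2).astype("float64"))        # input
y = tvm.nd.array(np.random.rand(2, 2).astype("float64"))        # the single parameter
targets = tvm.nd.array(np.random.rand(2, 2).astype("float64"))  # regression targets

# Forward + backward in one call: returns the loss and the gradient tuple.
loss, grads = vm["backbone_loss_adjoint"](x, y, targets)

# Plain SGD keeps only a step counter as state, per the signature printed above.
optim_states = tuple_object([tvm.nd.array(np.zeros((), "int64"))])
new_params, new_states = vm["optimizer"](tuple_object([y]), grads, optim_states)

One step of the printed optimizer computes y_new = y - 0.1 * y_grad and increments the step counter.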
Test states

Besides parameters, a backbone can declare states via the state_num attribute: tensors that the backbone itself updates on every call (a running counter, for instance) and that are returned alongside the loss rather than being touched by the optimizer. Here z is such a state; the backbone produces the new state z1 = z + 1 each step. This test also swaps in MomentumSGD, whose optimizer state carries a velocity buffer in addition to the step counter.
@I.ir_module
class Backbone:
    I.module_attrs({"param_num": 1, "state_num": 1})

    @R.function
    def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64"), z: R.Tensor((2, 2), "float64")):
        with R.dataflow():
            x1 = x + y
            z1 = z + R.const(1, "float64")
            R.output(x1, z1)
        return x1, z1

sinfo = relax.TensorStructInfo((2, 2), "float64")
setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"), MomentumSGD(0.1, 0.1), [sinfo, sinfo], legalize=False
)
train_mod = setup_trainer(Backbone)
train_mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"input_num": 1, "optim_state": [metadata["runtime.NDArray"][0], metadata["runtime.NDArray"][1]], "param_num": 1, "state_num": 1})

    @R.function
    def backbone(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((2, 2), dtype="float64"), R.Tensor((2, 2), dtype="float64")):
        with R.dataflow():
            x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
            z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1.0, "float64"))
            R.output(x1, z1)
        return (x1, z1)

    @R.function
    def backbone_loss(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((), dtype="float64"), R.Tensor((2, 2), dtype="float64")):
        with R.dataflow():
            x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
            z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1.0, "float64"))
            lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
            lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
            gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
            R.output(z1, gv)
        return (gv, z1)

    @R.function
    def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tuple(R.Tensor((), dtype="float64"), R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((2, 2), dtype="float64"))):
        with R.dataflow():
            x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
            z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1.0, "float64"))
            lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
            lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
            gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
            gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64")
            lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2]))
            lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
            lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
            lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1)
            x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1
            y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint
            y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint
            R.output(z1, gv, y_adjoint_out)
        return ((gv, z1), (y_adjoint_out,))

    @R.function
    def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))):
        with R.dataflow():
            num_steps: R.Tensor((), dtype="int64") = optim_states[0]
            num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
            y: R.Tensor((2, 2), dtype="float64") = params[0]
            y_grad: R.Tensor((2, 2), dtype="float64") = gradients[0]
            y_v: R.Tensor((2, 2), dtype="float64") = optim_states[1]
            lv: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_v)
            y_v_new: R.Tensor((2, 2), dtype="float64") = R.add(lv, y_grad)
            lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_v_new)
            y_new: R.Tensor((2, 2), dtype="float64") = R.subtract(y, lv1)
            params_new: R.Tuple(R.Tensor((2, 2), dtype="float64")) = (y_new,)
            optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64")) = (num_steps_new, y_v_new)
            R.output(params_new, optim_states_new)
        return (params_new, optim_states_new)

# Metadata omitted. Use show_meta=True in script() method to show it.
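Compared with plain SGD, MomentumSGD threads a velocity buffer per parameter through optim_states (the printed update is v_new = 0.1 * v + grad followed by y_new = y - 0.1 * v_new), and the module state now travels with the loss: backbone_loss_adjoint returns ((loss, z1), gradients), so the caller must feed the updated state back in on the next step. A minimal sketch of the resulting training loop, again with our own variable names, random data, and step count:

import numpy as np
import tvm
from tvm import relax
from tvm.runtime.container import tuple_object

mod = relax.transform.LegalizeOps()(train_mod)
vm = relax.VirtualMachine(relax.build(mod, "llvm"), tvm.cpu())

y = tvm.nd.array(np.random.rand(2, 2).astype("float64"))  # parameter
z = tvm.nd.array(np.zeros((2, 2), "float64"))             # model state
# MomentumSGD state: a step counter plus one velocity buffer per parameter.
optim_states = tuple_object(
    [tvm.nd.array(np.zeros((), "int64")), tvm.nd.array(np.zeros((2, 2), "float64"))]
)

for _ in range(3):
    x = tvm.nd.array(np.random.rand(2, 2).astype("float64"))
    targets = tvm.nd.array(np.random.rand(2, 2).astype("float64"))
    # Returns ((loss, new_state), gradients); only the parameter y gets a gradient.
    (loss, z_new), grads = vm["backbone_loss_adjoint"](x, y, z, targets)
    params, optim_states = vm["optimizer"](tuple_object([y]), grads, optim_states)
    y, z = params[0], z_new  # thread the updated parameter and state into the next step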