# {class}`~tvm.relax.training.SetupTrainer`

In [1]:
import tvm
import tvm.testing

from tvm import relax, TVMError
from tvm.ir.base import assert_structural_equal
from tvm.relax.training import SetupTrainer
from tvm.relax.training.optimizer import SGD, MomentumSGD
from tvm.relax.training.loss import MSELoss
from tvm.script import ir as I, relax as R

## 测试简单模型

In [2]:
# Simple backbone used to exercise SetupTrainer: a single element-wise add, no states.
@I.ir_module
class Backbone:
    # Module attributes presumably consumed by SetupTrainer: "param_num" / "state_num"
    # give the number of trailing `backbone` arguments that are trainable parameters
    # and model states respectively.
    # NOTE(review): assumes SetupTrainer's argument order is (inputs..., params..., states...),
    # which would make x the data input and y the single parameter — confirm in its docs.
    I.module_attrs({"param_num": 1, "state_num": 0})
    @R.function
    def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64")):
        # Forward pass: element-wise sum of two (2, 2) float64 tensors.
        with R.dataflow():
            x1 = x + y
            R.output(x1)
        return x1

In [3]:
# Struct info shared by both loss arguments: a (2, 2) float64 tensor.
sinfo = relax.TensorStructInfo((2, 2), "float64")
# Build the trainer transformation: summed MSE loss, plain SGD with step 0.1,
# and legalization disabled so the printed module keeps high-level Relax ops.
setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    SGD(0.1),
    [sinfo, sinfo],
    legalize=False,
)
# Apply it to the backbone module and print the resulting training module.
train_mod = setup_trainer(Backbone)
train_mod.show()

## 测试带状态（state）的模型

In [4]:
@I.ir_module
class Backbone:
    I.module_attrs({"param_num": 1, "state_num": 1})
    @R.function
    def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64"), z: R.Tensor((2, 2), "float64")):
        with R.dataflow():
            x1 = x + y
            z1 = z + R.const(1, "float64")
            R.output(x1, z1)
        return x1, z1

In [5]:
# Struct info shared by both loss arguments: a (2, 2) float64 tensor.
sinfo = relax.TensorStructInfo((2, 2), "float64")
# Same setup as the stateless case, but with momentum SGD
# (presumably learning rate 0.1 and momentum 0.1 — confirm against MomentumSGD's signature).
setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    MomentumSGD(0.1, 0.1),
    [sinfo, sinfo],
    legalize=False,
)
# Transform the stateful backbone and print the generated training module.
train_mod = setup_trainer(Backbone)
train_mod.show()