# {mod}`~tvm.relax.training.optimizer`

In [1]:
import pytest

import tvm
import tvm.testing
from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.relax.training.optimizer import SGD, MomentumSGD, Adam
from tvm.script.parser import relax as R

## 测试优化器错误

In [2]:
# Probe inputs for SGD.init's validation. All are 3x3; they differ only in
# dtype or in what kind of Relax object they are.
x1 = relax.Var("x1", R.Tensor((3, 3), "float32"))
x2 = relax.Var("x2", R.Tensor((3, 3), "float64"))
x3 = relax.Var("x3", R.Tuple([R.Tensor((3, 3), "float32")]))  # tuple-typed Var, not a tensor Var
x4 = relax.Var("x4", R.Tensor((3, 3), "int64"))  # integer dtype
x5 = relax.Tuple([x1])  # a relax.Tuple expression, not a Var

# Accepted inputs: a single Var, a one-element list of Vars, and a float64 Var.
# With a float64 parameter the optimizer reports dtype "float64".
SGD(0.01).init(x1)
SGD(0.01).init([x1])
assert SGD(0.01).init([x2]).dtype == "float64"

# Rejected inputs: init() must raise ValueError for each of these.
invalid_params = [
    [x1, x1],  # the same Var listed twice
    [x1, x2],  # float32 and float64 parameters mixed in one list
    x3,
    x4,
    x5,
]
for params in invalid_params:
    with pytest.raises(ValueError):
        SGD(0.01).init(params)

# Calling get_function() before init() is reported as a RuntimeError.
with pytest.raises(
    RuntimeError,
    match="Please call init\\(\\) for the optimizer before calling get_function\\(\\)",
):
    SGD(0.01).get_function()

## SGD

In [3]:
# Baseline: plain SGD with a single argument (0.01) over one 2-D and one 1-D
# float32 parameter. The last line displays the function built by get_function().
x, y = (
    relax.Var("x", R.Tensor((3, 3), "float32")),
    relax.Var("y", R.Tensor((3,), "float32")),
)
sgd_optimizer = SGD(0.01).init([x, y])
sgd = sgd_optimizer.get_function()
sgd.show()

In [4]:
# SGD over the same two float32 parameters, now with a non-zero second
# hyperparameter. Names follow the MomentumSGD cells below (lr, wd).
# NOTE(review): the second positional argument of SGD is assumed to be the
# weight decay — confirm against the SGD API.
lr, wd = 0.01, 0.02

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
sgd = SGD(lr, wd).init([x, y]).get_function()
sgd.show()

## MomentumSGD

In [5]:
# MomentumSGD with only the first two hyperparameters given; the remaining ones
# (set explicitly in the next cells) keep their defaults.
# Names match the (lr, mom, damp, wd, nest) convention used below.
lr, mom = 0.01, 0.9

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
msgd = MomentumSGD(lr, mom).init([x, y]).get_function()
msgd.show()

In [6]:
# MomentumSGD with all five positional hyperparameters set explicitly,
# using nest=False.
lr = 0.01
mom = 0.9
damp = 0.85
wd = 0.02
nest = False

params = [
    relax.Var("x", R.Tensor((3, 3), "float32")),
    relax.Var("y", R.Tensor((3,), "float32")),
]
x, y = params
momentum_optimizer = MomentumSGD(lr, mom, damp, wd, nest).init(params)
msgd = momentum_optimizer.get_function()
msgd.show()

In [7]:
# Same hyperparameters as the previous cell except nest=True
# (presumably the Nesterov-momentum flag).
lr = 0.01
mom = 0.9
damp = 0.85
wd = 0.02
nest = True

params = [
    relax.Var("x", R.Tensor((3, 3), "float32")),
    relax.Var("y", R.Tensor((3,), "float32")),
]
x, y = params
momentum_optimizer = MomentumSGD(lr, mom, damp, wd, nest).init(params)
msgd = momentum_optimizer.get_function()
msgd.show()

## Adam

In [8]:
# Adam with only the first argument (0.01) given; the other hyperparameters,
# set explicitly in the following cells, keep their defaults here.
param_vars = [
    relax.Var("x", R.Tensor((3, 3), "float32")),
    relax.Var("y", R.Tensor((3,), "float32")),
]
x, y = param_vars
adam_optimizer = Adam(0.01).init(param_vars)
adam = adam_optimizer.get_function()
adam.show()

In [9]:
# Adam with every hyperparameter passed explicitly, over float32 parameters.
# NOTE(review): positional order assumed to be (lr, betas, eps, weight_decay)
# — confirm against the Adam API.
lr, betas, eps, wd = 0.01, (0.8, 0.85), 1e-7, 0.1

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
adam = Adam(lr, betas, eps, wd).init([x, y]).get_function()
adam.show()

In [10]:
# Same Adam hyperparameters as the float32 case, but with float64 parameters,
# to show the generated function for a float64 optimizer.
# NOTE(review): positional order assumed to be (lr, betas, eps, weight_decay)
# — confirm against the Adam API.
lr, betas, eps, wd = 0.01, (0.8, 0.85), 1e-7, 0.1

x = relax.Var("x", R.Tensor((3, 3), "float64"))
y = relax.Var("y", R.Tensor((3,), "float64"))
adam = Adam(lr, betas, eps, wd).init([x, y]).get_function()
adam.show()