optimizer#

import tvm
import tvm.testing
from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.relax.training.optimizer import SGD, MomentumSGD, Adam
from tvm.script.parser import relax as R

测试优化器错误#

import pytest
x1 = relax.Var("x1", R.Tensor((3, 3), "float32"))
x2 = relax.Var("x2", R.Tensor((3, 3), "float64"))
x3 = relax.Var("x3", R.Tuple([R.Tensor((3, 3), "float32")]))
x4 = relax.Var("x4", R.Tensor((3, 3), "int64"))
x5 = relax.Tuple([x1])

# fine cases
SGD(0.01).init(x1)
SGD(0.01).init([x1])
assert SGD(0.01).init([x2]).dtype == "float64"

with pytest.raises(ValueError):
    SGD(0.01).init([x1, x1])
with pytest.raises(ValueError):
    SGD(0.01).init([x1, x2])
with pytest.raises(ValueError):
    SGD(0.01).init(x3)
with pytest.raises(ValueError):
    SGD(0.01).init(x4)
with pytest.raises(ValueError):
    SGD(0.01).init(x5)
with pytest.raises(
    RuntimeError,
    match="Please call init\\(\\) for the optimizer before calling get_function\\(\\)",
):
    SGD(0.01).get_function()

SGD#

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
sgd = SGD(0.01).init([x, y]).get_function()
sgd.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def SGD(params: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), optim_states: R.Tuple(R.Tensor((), dtype="int64"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), R.Tuple(R.Tensor((), dtype="int64"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        x: R.Tensor((3, 3), dtype="float32") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float32") = gradients[0]
        lv: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), x_grad)
        x_new: R.Tensor((3, 3), dtype="float32") = R.subtract(x, lv)
        y: R.Tensor((3,), dtype="float32") = params[1]
        y_grad: R.Tensor((3,), dtype="float32") = gradients[1]
        lv1: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), y_grad)
        y_new: R.Tensor((3,), dtype="float32") = R.subtract(y, lv1)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64")) = (num_steps_new,)
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)
x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
sgd = SGD(0.01, 0.02).init([x, y]).get_function()
sgd.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def SGD(params: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), optim_states: R.Tuple(R.Tensor((), dtype="int64"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), R.Tuple(R.Tensor((), dtype="int64"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        x: R.Tensor((3, 3), dtype="float32") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float32") = gradients[0]
        lv: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.019999999552965164, "float32"), x)
        x_grad_new: R.Tensor((3, 3), dtype="float32") = R.add(lv, x_grad)
        lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), x_grad_new)
        x_new: R.Tensor((3, 3), dtype="float32") = R.subtract(x, lv1)
        y: R.Tensor((3,), dtype="float32") = params[1]
        y_grad: R.Tensor((3,), dtype="float32") = gradients[1]
        lv2: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.019999999552965164, "float32"), y)
        y_grad_new: R.Tensor((3,), dtype="float32") = R.add(lv2, y_grad)
        lv3: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), y_grad_new)
        y_new: R.Tensor((3,), dtype="float32") = R.subtract(y, lv3)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64")) = (num_steps_new,)
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)

MomentumSGD#

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
msgd = MomentumSGD(0.01, 0.9).init([x, y]).get_function()
msgd.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def MomentumSGD(params: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        x: R.Tensor((3, 3), dtype="float32") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float32") = gradients[0]
        x_v: R.Tensor((3, 3), dtype="float32") = optim_states[1]
        lv: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), x_v)
        x_v_new: R.Tensor((3, 3), dtype="float32") = R.add(lv, x_grad)
        lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), x_v_new)
        x_new: R.Tensor((3, 3), dtype="float32") = R.subtract(x, lv1)
        y: R.Tensor((3,), dtype="float32") = params[1]
        y_grad: R.Tensor((3,), dtype="float32") = gradients[1]
        y_v: R.Tensor((3,), dtype="float32") = optim_states[2]
        lv2: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), y_v)
        y_v_new: R.Tensor((3,), dtype="float32") = R.add(lv2, y_grad)
        lv3: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), y_v_new)
        y_new: R.Tensor((3,), dtype="float32") = R.subtract(y, lv3)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = num_steps_new, x_v_new, y_v_new
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)
lr, mom, damp, wd, nest = 0.01, 0.9, 0.85, 0.02, False

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
msgd = MomentumSGD(lr, mom, damp, wd, nest).init([x, y]).get_function()
msgd.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def MomentumSGD(params: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        x: R.Tensor((3, 3), dtype="float32") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float32") = gradients[0]
        x_v: R.Tensor((3, 3), dtype="float32") = optim_states[1]
        lv: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.019999999552965164, "float32"), x)
        x_grad_new: R.Tensor((3, 3), dtype="float32") = R.add(lv, x_grad)
        lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), x_v)
        lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.15000000596046448, "float32"), x_grad_new)
        x_v_new: R.Tensor((3, 3), dtype="float32") = R.add(lv1, lv2)
        lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), x_v_new)
        x_new: R.Tensor((3, 3), dtype="float32") = R.subtract(x, lv3)
        y: R.Tensor((3,), dtype="float32") = params[1]
        y_grad: R.Tensor((3,), dtype="float32") = gradients[1]
        y_v: R.Tensor((3,), dtype="float32") = optim_states[2]
        lv4: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.019999999552965164, "float32"), y)
        y_grad_new: R.Tensor((3,), dtype="float32") = R.add(lv4, y_grad)
        lv5: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), y_v)
        lv6: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.15000000596046448, "float32"), y_grad_new)
        y_v_new: R.Tensor((3,), dtype="float32") = R.add(lv5, lv6)
        lv7: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), y_v_new)
        y_new: R.Tensor((3,), dtype="float32") = R.subtract(y, lv7)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = num_steps_new, x_v_new, y_v_new
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)
lr, mom, damp, wd, nest = 0.01, 0.9, 0.85, 0.02, True

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
msgd = MomentumSGD(lr, mom, damp, wd, nest).init([x, y]).get_function()
msgd.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def MomentumSGD(params: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        x: R.Tensor((3, 3), dtype="float32") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float32") = gradients[0]
        x_v: R.Tensor((3, 3), dtype="float32") = optim_states[1]
        lv: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.019999999552965164, "float32"), x)
        x_grad_new: R.Tensor((3, 3), dtype="float32") = R.add(lv, x_grad)
        lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), x_v)
        lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.15000000596046448, "float32"), x_grad_new)
        x_v_new: R.Tensor((3, 3), dtype="float32") = R.add(lv1, lv2)
        lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), x_v_new)
        x_g_nest: R.Tensor((3, 3), dtype="float32") = R.add(x_grad_new, lv3)
        lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), x_g_nest)
        x_new: R.Tensor((3, 3), dtype="float32") = R.subtract(x, lv4)
        y: R.Tensor((3,), dtype="float32") = params[1]
        y_grad: R.Tensor((3,), dtype="float32") = gradients[1]
        y_v: R.Tensor((3,), dtype="float32") = optim_states[2]
        lv5: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.019999999552965164, "float32"), y)
        y_grad_new: R.Tensor((3,), dtype="float32") = R.add(lv5, y_grad)
        lv6: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), y_v)
        lv7: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.15000000596046448, "float32"), y_grad_new)
        y_v_new: R.Tensor((3,), dtype="float32") = R.add(lv6, lv7)
        lv8: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), y_v_new)
        y_g_nest: R.Tensor((3,), dtype="float32") = R.add(y_grad_new, lv8)
        lv9: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), y_g_nest)
        y_new: R.Tensor((3,), dtype="float32") = R.subtract(y, lv9)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = num_steps_new, x_v_new, y_v_new
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)

Adam#

x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
adam = Adam(0.01).init([x, y]).get_function()
adam.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def Adam(params: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        lv: R.Tensor((), dtype="float32") = optim_states[1]
        beta1_prod: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(0.89999997615814209, "float32"))
        lv1: R.Tensor((), dtype="float32") = optim_states[2]
        beta2_prod: R.Tensor((), dtype="float32") = R.multiply(lv1, R.const(0.99900001287460327, "float32"))
        x: R.Tensor((3, 3), dtype="float32") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float32") = gradients[0]
        x_m: R.Tensor((3, 3), dtype="float32") = optim_states[3]
        x_v: R.Tensor((3, 3), dtype="float32") = optim_states[5]
        lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), x_m)
        lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.10000000149011612, "float32"), x_grad)
        x_m_new: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv3)
        lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.99900001287460327, "float32"), x_v)
        lv5: R.Tensor((3, 3), dtype="float32") = R.multiply(x_grad, x_grad)
        lv6: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0010000000474974513, "float32"), lv5)
        x_v_new: R.Tensor((3, 3), dtype="float32") = R.add(lv4, lv6)
        lv7: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta1_prod)
        x_m_hat: R.Tensor((3, 3), dtype="float32") = R.divide(x_m_new, lv7)
        lv8: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta2_prod)
        x_v_hat: R.Tensor((3, 3), dtype="float32") = R.divide(x_v_new, lv8)
        lv9: R.Tensor((3, 3), dtype="float32") = R.sqrt(x_v_hat)
        lv10: R.Tensor((3, 3), dtype="float32") = R.add(lv9, R.const(9.9999999392252903e-09, "float32"))
        lv11: R.Tensor((3, 3), dtype="float32") = R.divide(x_m_hat, lv10)
        lv12: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), lv11)
        x_new: R.Tensor((3, 3), dtype="float32") = R.subtract(x, lv12)
        y: R.Tensor((3,), dtype="float32") = params[1]
        y_grad: R.Tensor((3,), dtype="float32") = gradients[1]
        y_m: R.Tensor((3,), dtype="float32") = optim_states[4]
        y_v: R.Tensor((3,), dtype="float32") = optim_states[6]
        lv13: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.89999997615814209, "float32"), y_m)
        lv14: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.10000000149011612, "float32"), y_grad)
        y_m_new: R.Tensor((3,), dtype="float32") = R.add(lv13, lv14)
        lv15: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.99900001287460327, "float32"), y_v)
        lv16: R.Tensor((3,), dtype="float32") = R.multiply(y_grad, y_grad)
        lv17: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0010000000474974513, "float32"), lv16)
        y_v_new: R.Tensor((3,), dtype="float32") = R.add(lv15, lv17)
        lv18: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta1_prod)
        y_m_hat: R.Tensor((3,), dtype="float32") = R.divide(y_m_new, lv18)
        lv19: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta2_prod)
        y_v_hat: R.Tensor((3,), dtype="float32") = R.divide(y_v_new, lv19)
        lv20: R.Tensor((3,), dtype="float32") = R.sqrt(y_v_hat)
        lv21: R.Tensor((3,), dtype="float32") = R.add(lv20, R.const(9.9999999392252903e-09, "float32"))
        lv22: R.Tensor((3,), dtype="float32") = R.divide(y_m_hat, lv21)
        lv23: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), lv22)
        y_new: R.Tensor((3,), dtype="float32") = R.subtract(y, lv23)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = num_steps_new, beta1_prod, beta2_prod, x_m_new, y_m_new, x_v_new, y_v_new
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)
x = relax.Var("x", R.Tensor((3, 3), "float32"))
y = relax.Var("y", R.Tensor((3,), "float32"))
adam = Adam(0.01, (0.8, 0.85), 1e-7, 0.1).init([x, y]).get_function()
adam.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def Adam(params: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        lv: R.Tensor((), dtype="float32") = optim_states[1]
        beta1_prod: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(0.80000001192092896, "float32"))
        lv1: R.Tensor((), dtype="float32") = optim_states[2]
        beta2_prod: R.Tensor((), dtype="float32") = R.multiply(lv1, R.const(0.85000002384185791, "float32"))
        x: R.Tensor((3, 3), dtype="float32") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float32") = gradients[0]
        x_m: R.Tensor((3, 3), dtype="float32") = optim_states[3]
        x_v: R.Tensor((3, 3), dtype="float32") = optim_states[5]
        lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.10000000149011612, "float32"), x)
        x_grad_new: R.Tensor((3, 3), dtype="float32") = R.add(lv2, x_grad)
        lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.80000001192092896, "float32"), x_m)
        lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.20000000298023224, "float32"), x_grad_new)
        x_m_new: R.Tensor((3, 3), dtype="float32") = R.add(lv3, lv4)
        lv5: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.85000002384185791, "float32"), x_v)
        lv6: R.Tensor((3, 3), dtype="float32") = R.multiply(x_grad_new, x_grad_new)
        lv7: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.15000000596046448, "float32"), lv6)
        x_v_new: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv7)
        lv8: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta1_prod)
        x_m_hat: R.Tensor((3, 3), dtype="float32") = R.divide(x_m_new, lv8)
        lv9: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta2_prod)
        x_v_hat: R.Tensor((3, 3), dtype="float32") = R.divide(x_v_new, lv9)
        lv10: R.Tensor((3, 3), dtype="float32") = R.sqrt(x_v_hat)
        lv11: R.Tensor((3, 3), dtype="float32") = R.add(lv10, R.const(1.0000000116860974e-07, "float32"))
        lv12: R.Tensor((3, 3), dtype="float32") = R.divide(x_m_hat, lv11)
        lv13: R.Tensor((3, 3), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), lv12)
        x_new: R.Tensor((3, 3), dtype="float32") = R.subtract(x, lv13)
        y: R.Tensor((3,), dtype="float32") = params[1]
        y_grad: R.Tensor((3,), dtype="float32") = gradients[1]
        y_m: R.Tensor((3,), dtype="float32") = optim_states[4]
        y_v: R.Tensor((3,), dtype="float32") = optim_states[6]
        lv14: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.10000000149011612, "float32"), y)
        y_grad_new: R.Tensor((3,), dtype="float32") = R.add(lv14, y_grad)
        lv15: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.80000001192092896, "float32"), y_m)
        lv16: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.20000000298023224, "float32"), y_grad_new)
        y_m_new: R.Tensor((3,), dtype="float32") = R.add(lv15, lv16)
        lv17: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.85000002384185791, "float32"), y_v)
        lv18: R.Tensor((3,), dtype="float32") = R.multiply(y_grad_new, y_grad_new)
        lv19: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.15000000596046448, "float32"), lv18)
        y_v_new: R.Tensor((3,), dtype="float32") = R.add(lv17, lv19)
        lv20: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta1_prod)
        y_m_hat: R.Tensor((3,), dtype="float32") = R.divide(y_m_new, lv20)
        lv21: R.Tensor((), dtype="float32") = R.subtract(R.const(1.0, "float32"), beta2_prod)
        y_v_hat: R.Tensor((3,), dtype="float32") = R.divide(y_v_new, lv21)
        lv22: R.Tensor((3,), dtype="float32") = R.sqrt(y_v_hat)
        lv23: R.Tensor((3,), dtype="float32") = R.add(lv22, R.const(1.0000000116860974e-07, "float32"))
        lv24: R.Tensor((3,), dtype="float32") = R.divide(y_m_hat, lv23)
        lv25: R.Tensor((3,), dtype="float32") = R.multiply(R.const(0.0099999997764825821, "float32"), lv24)
        y_new: R.Tensor((3,), dtype="float32") = R.subtract(y, lv25)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3,), dtype="float32")) = num_steps_new, beta1_prod, beta2_prod, x_m_new, y_m_new, x_v_new, y_v_new
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)
x = relax.Var("x", R.Tensor((3, 3), "float64"))
y = relax.Var("y", R.Tensor((3,), "float64"))
adam = Adam(0.01, (0.8, 0.85), 1e-7, 0.1).init([x, y]).get_function()
adam.show()
Hide code cell output
# from tvm.script import relax as R

@R.function
def Adam(params: R.Tuple(R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64")), gradients: R.Tuple(R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float64"), R.Tensor((), dtype="float64"), R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64"), R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64"))) -> R.Tuple(R.Tuple(R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float64"), R.Tensor((), dtype="float64"), R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64"), R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64"))):
    with R.dataflow():
        num_steps: R.Tensor((), dtype="int64") = optim_states[0]
        num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
        lv: R.Tensor((), dtype="float64") = optim_states[1]
        beta1_prod: R.Tensor((), dtype="float64") = R.multiply(lv, R.const(0.80000000000000004, "float64"))
        lv1: R.Tensor((), dtype="float64") = optim_states[2]
        beta2_prod: R.Tensor((), dtype="float64") = R.multiply(lv1, R.const(0.84999999999999998, "float64"))
        x: R.Tensor((3, 3), dtype="float64") = params[0]
        x_grad: R.Tensor((3, 3), dtype="float64") = gradients[0]
        x_m: R.Tensor((3, 3), dtype="float64") = optim_states[3]
        x_v: R.Tensor((3, 3), dtype="float64") = optim_states[5]
        lv2: R.Tensor((3, 3), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), x)
        x_grad_new: R.Tensor((3, 3), dtype="float64") = R.add(lv2, x_grad)
        lv3: R.Tensor((3, 3), dtype="float64") = R.multiply(R.const(0.80000000000000004, "float64"), x_m)
        lv4: R.Tensor((3, 3), dtype="float64") = R.multiply(R.const(0.20000000000000001, "float64"), x_grad_new)
        x_m_new: R.Tensor((3, 3), dtype="float64") = R.add(lv3, lv4)
        lv5: R.Tensor((3, 3), dtype="float64") = R.multiply(R.const(0.84999999999999998, "float64"), x_v)
        lv6: R.Tensor((3, 3), dtype="float64") = R.multiply(x_grad_new, x_grad_new)
        lv7: R.Tensor((3, 3), dtype="float64") = R.multiply(R.const(0.14999999999999999, "float64"), lv6)
        x_v_new: R.Tensor((3, 3), dtype="float64") = R.add(lv5, lv7)
        lv8: R.Tensor((), dtype="float64") = R.subtract(R.const(1.0, "float64"), beta1_prod)
        x_m_hat: R.Tensor((3, 3), dtype="float64") = R.divide(x_m_new, lv8)
        lv9: R.Tensor((), dtype="float64") = R.subtract(R.const(1.0, "float64"), beta2_prod)
        x_v_hat: R.Tensor((3, 3), dtype="float64") = R.divide(x_v_new, lv9)
        lv10: R.Tensor((3, 3), dtype="float64") = R.sqrt(x_v_hat)
        lv11: R.Tensor((3, 3), dtype="float64") = R.add(lv10, R.const(9.9999999999999995e-08, "float64"))
        lv12: R.Tensor((3, 3), dtype="float64") = R.divide(x_m_hat, lv11)
        lv13: R.Tensor((3, 3), dtype="float64") = R.multiply(R.const(0.01, "float64"), lv12)
        x_new: R.Tensor((3, 3), dtype="float64") = R.subtract(x, lv13)
        y: R.Tensor((3,), dtype="float64") = params[1]
        y_grad: R.Tensor((3,), dtype="float64") = gradients[1]
        y_m: R.Tensor((3,), dtype="float64") = optim_states[4]
        y_v: R.Tensor((3,), dtype="float64") = optim_states[6]
        lv14: R.Tensor((3,), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y)
        y_grad_new: R.Tensor((3,), dtype="float64") = R.add(lv14, y_grad)
        lv15: R.Tensor((3,), dtype="float64") = R.multiply(R.const(0.80000000000000004, "float64"), y_m)
        lv16: R.Tensor((3,), dtype="float64") = R.multiply(R.const(0.20000000000000001, "float64"), y_grad_new)
        y_m_new: R.Tensor((3,), dtype="float64") = R.add(lv15, lv16)
        lv17: R.Tensor((3,), dtype="float64") = R.multiply(R.const(0.84999999999999998, "float64"), y_v)
        lv18: R.Tensor((3,), dtype="float64") = R.multiply(y_grad_new, y_grad_new)
        lv19: R.Tensor((3,), dtype="float64") = R.multiply(R.const(0.14999999999999999, "float64"), lv18)
        y_v_new: R.Tensor((3,), dtype="float64") = R.add(lv17, lv19)
        lv20: R.Tensor((), dtype="float64") = R.subtract(R.const(1.0, "float64"), beta1_prod)
        y_m_hat: R.Tensor((3,), dtype="float64") = R.divide(y_m_new, lv20)
        lv21: R.Tensor((), dtype="float64") = R.subtract(R.const(1.0, "float64"), beta2_prod)
        y_v_hat: R.Tensor((3,), dtype="float64") = R.divide(y_v_new, lv21)
        lv22: R.Tensor((3,), dtype="float64") = R.sqrt(y_v_hat)
        lv23: R.Tensor((3,), dtype="float64") = R.add(lv22, R.const(9.9999999999999995e-08, "float64"))
        lv24: R.Tensor((3,), dtype="float64") = R.divide(y_m_hat, lv23)
        lv25: R.Tensor((3,), dtype="float64") = R.multiply(R.const(0.01, "float64"), lv24)
        y_new: R.Tensor((3,), dtype="float64") = R.subtract(y, lv25)
        params_new: R.Tuple(R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64")) = x_new, y_new
        optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((), dtype="float64"), R.Tensor((), dtype="float64"), R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64"), R.Tensor((3, 3), dtype="float64"), R.Tensor((3,), dtype="float64")) = num_steps_new, beta1_prod, beta2_prod, x_m_new, y_m_new, x_v_new, y_v_new
        R.output(params_new, optim_states_new)
    return (params_new, optim_states_new)