# Hyperparameter cases consumed by test_momentum_sgd. Each tuple supplies, in
# order, (lr, momentum, dampening, weight_decay, nesterov):
#   1. dampening and weight decay both 0, non-Nesterov update
#   2. non-zero dampening and weight decay, non-Nesterov update
#   3. same values as case 2, with the Nesterov branch enabled
lr, momentum, dampening, weight_decay, nesterov = tvm.testing.parameters(
(0.01, 0.9, 0, 0, False),
(0.01, 0.9, 0.85, 0.02, False),
(0.01, 0.9, 0.85, 0.02, True),
)
@tvm.testing.parametrize_targets("llvm")
def test_momentum_sgd(target, dev, lr, momentum, dampening, weight_decay, nesterov):
    """Hand a NumPy reference update for MomentumSGD to ``_test_optimizer``."""

    def np_func(param_tuple, grad_tuple, state_tuple):
        """Apply one reference momentum-SGD step.

        ``state_tuple[0]`` is the step counter and ``state_tuple[k + 1]`` is the
        velocity paired with ``param_tuple[k]``. Returns the updated parameters
        and the updated state, the latter in that same layout.
        """
        grad_scale = 1 - dampening
        updated_params = []
        updated_state = [state_tuple[0] + 1]

        for idx, weight in enumerate(param_tuple):
            # Weight decay is folded into the gradient before the velocity update.
            decayed_grad = weight * weight_decay + grad_tuple[idx]
            new_velocity = momentum * state_tuple[idx + 1] + decayed_grad * grad_scale
            # The Nesterov variant steps along the gradient plus the
            # momentum-scaled velocity instead of the velocity alone.
            step = decayed_grad + momentum * new_velocity if nesterov else new_velocity
            updated_params.append(weight - step * lr)
            updated_state.append(new_velocity)

        return updated_params, updated_state

    _test_optimizer(
        target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov
    )
# Hyperparameter cases consumed by test_adam. Each tuple supplies, in order,
# (lr, betas, eps, weight_decay), where betas is a (betas[0], betas[1]) pair:
#   1. weight decay of 0
#   2. different betas and eps together with a non-zero weight decay
# NOTE(review): `lr` and `weight_decay` are also bound by the MomentumSGD
# parameter block earlier in this file; presumably the tvm.testing fixture
# machinery resolves each test to the intended values - confirm.
lr, betas, eps, weight_decay = tvm.testing.parameters(
(0.01, (0.9, 0.999), 1e-08, 0),
(0.01, (0.8, 0.85), 1e-07, 0.1),
)
@tvm.testing.parametrize_targets("llvm")
def test_adam(target, dev, lr, betas, eps, weight_decay):
    """Hand a NumPy reference update for Adam to ``_test_optimizer``."""

    def np_func(param_tuple, grad_tuple, state_tuple):
        """Apply one reference Adam step.

        State layout as read here: index 0 is the step counter, indices 1 and 2
        are multiplied by ``betas[0]`` and ``betas[1]`` on every step, then come
        the first moments (one per parameter) followed by the second moments.
        Returns the updated parameters and a state list of the same length.
        """
        beta1 = betas[0]
        beta2 = betas[1]
        first_moment_base = 3
        second_moment_base = 3 + len(param_tuple)

        step = state_tuple[0] + 1
        updated_state = [None] * len(state_tuple)  # type: ignore
        updated_state[0] = step
        updated_state[1] = state_tuple[1] * beta1
        updated_state[2] = state_tuple[2] * beta2

        updated_params = []
        for idx, weight in enumerate(param_tuple):
            m_slot = first_moment_base + idx
            v_slot = second_moment_base + idx

            # Weight decay is folded into the gradient before the moment updates.
            grad_wd = grad_tuple[idx] + weight_decay * weight
            first_moment = beta1 * state_tuple[m_slot] + (1 - beta1) * grad_wd
            second_moment = beta2 * state_tuple[v_slot] + (1 - beta2) * grad_wd * grad_wd

            # Bias-correct both moments using the already-incremented step count.
            first_hat = first_moment / (1 - beta1 ** step)
            second_hat = second_moment / (1 - beta2 ** step)
            denom = np.sqrt(second_hat) + eps

            updated_params.append(weight - lr * first_hat / denom)
            updated_state[m_slot] = first_moment
            updated_state[v_slot] = second_moment

        return updated_params, updated_state

    _test_optimizer(target, dev, np_func, Adam, lr, betas, eps, weight_decay)