Test Optimizers

This section numerically validates TVM Relax's training optimizers (SGD, MomentumSGD, and Adam): each optimizer's update step is compiled with TVM and its outputs are compared against a plain NumPy reference implementation on random inputs.

First, define some helpers: compiling an IRModule into a Relax VM, converting nested data between NumPy and TVM containers, and comparing nested results.
from typing import Callable

import numpy as np
import tvm
from tvm import relax
from tvm import IRModule
from tvm.relax.training.optimizer import Adam, SGD, MomentumSGD
from tvm.script.parser import relax as R
from tvm.runtime.relax_vm import VirtualMachine
from tvm.testing import assert_allclose

def _legalize_and_build(mod: IRModule, target, dev):
    # Legalization happens inside tvm.compile; wrap the built module in a Relax VM.
    ex = tvm.compile(mod, target)
    vm = VirtualMachine(ex, dev)
    return vm


def _numpy_to_tvm(data):
    # Recursively convert (possibly nested) NumPy data into TVM NDArrays.
    if isinstance(data, (list, tuple)):
        return [_numpy_to_tvm(_data) for _data in data]
    return tvm.nd.array(data)


def _tvm_to_numpy(data):
    # Recursively convert (possibly nested) TVM containers back into NumPy arrays.
    if isinstance(data, (list, tuple, tvm.ir.Array)):
        return [_tvm_to_numpy(_data) for _data in data]
    return data.numpy()


def _assert_allclose_nested(data1, data2):
    # Compare two nested structures of arrays element-wise.
    if isinstance(data1, (list, tuple)):
        assert isinstance(data2, (list, tuple))
        assert len(data1) == len(data2)
        for x, y in zip(data1, data2):
            _assert_allclose_nested(x, y)
    else:
        assert_allclose(data1, data2)


def _assert_run_result_same(tvm_func: Callable, np_func: Callable, np_inputs: list):
    # Run the compiled TVM function and the NumPy reference on the same inputs
    # and check that their (nested) results match.
    result = _tvm_to_numpy(tvm_func(*[_numpy_to_tvm(i) for i in np_inputs]))
    expected = np_func(*np_inputs)
    _assert_allclose_nested(result, expected)
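
The core helper `_test_optimizer` below builds an optimizer over two dummy parameters, compiles the update step returned by `opt.get_function()`, a pure Relax function mapping `(params, gradients, optimizer_states)` to `(updated_params, updated_states)`, and checks it against a NumPy reference on random inputs. To see what such an update function looks like, it can be printed directly (a minimal sketch using the same `SGD(lr, weight_decay)` constructor as the tests below):

x = relax.Var("x", R.Tensor((3, 3), "float32"))
opt = SGD(0.01, 0.02).init([x])
# Print the TVMScript of the generated update step.
print(opt.get_function())
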
def _test_optimizer(target, dev, np_func, opt_type, *args, **kwargs):
    # Create two tensor variables to be optimized (a 3x3 matrix and a length-3 vector).
    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3,), "float32"))

    # Instantiate the optimizer (e.g. SGD or Adam); hyperparameters are forwarded via *args/**kwargs.
    opt = opt_type(*args, **kwargs).init([x, y])

    # Compile the optimizer's update step into an executable module.
    mod = IRModule.from_expr(opt.get_function().with_attr("global_symbol", "main"))
    tvm_func = _legalize_and_build(mod, target, dev)["main"]

    # Generate random test data.
    param_arr = [np.random.rand(3, 3).astype(np.float32),  # initial parameter values
                 np.random.rand(3).astype(np.float32)]
    grad_arr = [np.random.rand(3, 3).astype(np.float32),   # gradient values
                np.random.rand(3).astype(np.float32)]
    state_arr = _tvm_to_numpy(opt.state)  # convert the optimizer state to NumPy

    # Compare the outputs of the TVM implementation and the NumPy reference.
    _assert_run_result_same(tvm_func, np_func, [param_arr, grad_arr, state_arr])

# Each entry is (lr, weight_decay).
args = (
    (0.01, 0),
    (0.01, 0.02),
)
target = "llvm"
dev = tvm.device(target, 0)

Test the SGD Optimizer
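
The NumPy reference below implements the plain SGD step, with L2 weight decay folded into the gradient; the only optimizer state is the step counter:

$$w_t = w_{t-1} - \eta \left( g_t + \lambda\, w_{t-1} \right)$$

where $\eta$ is `lr` and $\lambda$ is `weight_decay`.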

def np_func(param_tuple, grad_tuple, state_tuple):
    # `lr` and `weight_decay` are read from the globals set by the loop below.
    num_steps = state_tuple[0]  # the step counter is the only SGD state
    param_tuple_new, state_tuple_new = [], []
    state_tuple_new.append(num_steps + 1)
    for i in range(len(param_tuple)):
        param = param_tuple[i]
        grad = grad_tuple[i]
        param_tuple_new.append(param - lr * (grad + weight_decay * param))
    return param_tuple_new, state_tuple_new

for lr, weight_decay in args:
    _test_optimizer(target, dev, np_func, SGD, lr, weight_decay)
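
Since the update step is a pure function from `(params, grads, state)` to `(new_params, new_state)`, several optimizer steps can be taken by threading the returned state back in. Below is a minimal sketch reusing the helpers above; the exact container types returned by the Relax VM may differ across TVM versions:

x = relax.Var("x", R.Tensor((3,), "float32"))
opt = SGD(0.1, 0).init([x])
mod = IRModule.from_expr(opt.get_function().with_attr("global_symbol", "main"))
step = _legalize_and_build(mod, target, dev)["main"]

params = _numpy_to_tvm([np.zeros(3, dtype=np.float32)])
grads = _numpy_to_tvm([np.ones(3, dtype=np.float32)])
state = opt.state
for _ in range(3):
    params, state = step(params, grads, state)
# With lr=0.1 and unit gradients, each step subtracts 0.1: roughly [-0.3, -0.3, -0.3].
print(_tvm_to_numpy(list(params)))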

Test the MomentumSGD Optimizer

# Each entry is (lr, momentum, dampening, weight_decay, nesterov).
args = (
    (0.01, 0.9, 0, 0, False),
    (0.01, 0.9, 0.85, 0.02, False),
    (0.01, 0.9, 0.85, 0.02, True),
)
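
The reference follows the PyTorch-style momentum update with momentum $\mu$, dampening $\tau$, and an optional Nesterov correction; besides the step counter, the state keeps one velocity buffer per parameter:

$$v_t = \mu\, v_{t-1} + (1 - \tau)\,\tilde g_t, \qquad
w_t = \begin{cases} w_{t-1} - \eta\,(\tilde g_t + \mu\, v_t) & \text{(Nesterov)} \\ w_{t-1} - \eta\, v_t & \text{(otherwise)} \end{cases}$$

where $\tilde g_t = g_t + \lambda\, w_{t-1}$ is the gradient with weight decay applied.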


def np_func(param_tuple, grad_tuple, state_tuple):
    # `lr`, `momentum`, `dampening`, `weight_decay`, and `nesterov` are read
    # from the globals set by the loop below.
    num_steps = state_tuple[0]
    param_tuple_new, state_tuple_new = [], []
    state_tuple_new.append(num_steps + 1)

    for i in range(len(param_tuple)):
        param = param_tuple[i]
        grad = grad_tuple[i]
        velocity = state_tuple[i + 1]  # one velocity buffer per parameter
        grad = param * weight_decay + grad
        velocity = momentum * velocity + grad * (1 - dampening)
        if nesterov:
            param = param - (grad + momentum * velocity) * lr
        else:
            param = param - velocity * lr
        param_tuple_new.append(param)
        state_tuple_new.append(velocity)

    return param_tuple_new, state_tuple_new


for lr, momentum, dampening, weight_decay, nesterov in args:
    _test_optimizer(
        target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov
    )


Test the Adam Optimizer

# Each entry is (lr, betas, eps, weight_decay).
args = (
    (0.01, (0.9, 0.999), 1e-08, 0),
    (0.01, (0.8, 0.85), 1e-07, 0.1),
)
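
The reference implements Adam with the weight-decay term added to the gradient and bias-corrected moment estimates:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,\tilde g_t, \quad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,\tilde g_t^2, \quad
w_t = w_{t-1} - \eta\,\frac{m_t/(1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon}$$

with $\tilde g_t = g_t + \lambda\, w_{t-1}$. The flattened state layout is: index 0 holds the step counter, indices 1 and 2 hold the running products $\beta_1^t$ and $\beta_2^t$, followed by the first moments of all parameters and then the second moments, which is why the reference indexes `state_tuple[i + 3]` and `state_tuple[i + 3 + len(param_tuple)]`.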


def np_func(param_tuple, grad_tuple, state_tuple):
    # `lr`, `betas`, `eps`, and `weight_decay` are read from the globals set by the loop below.
    num_steps = state_tuple[0]
    num_steps_new = num_steps + 1

    param_tuple_new = []
    state_tuple_new = [None] * len(state_tuple)
    state_tuple_new[0] = num_steps_new
    # Running products beta_1^t and beta_2^t, kept in the state for bias correction.
    state_tuple_new[1] = state_tuple[1] * betas[0]
    state_tuple_new[2] = state_tuple[2] * betas[1]

    for i in range(len(param_tuple)):
        param = param_tuple[i]
        grad = grad_tuple[i]
        m = state_tuple[i + 3]                     # first moment of parameter i
        v = state_tuple[i + 3 + len(param_tuple)]  # second moment of parameter i
        grad = grad + weight_decay * param
        m = betas[0] * m + (1 - betas[0]) * grad
        v = betas[1] * v + (1 - betas[1]) * grad * grad
        m_hat = m / (1 - betas[0] ** num_steps_new)
        v_hat = v / (1 - betas[1] ** num_steps_new)
        param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
        param_tuple_new.append(param)
        state_tuple_new[i + 3] = m
        state_tuple_new[i + 3 + len(param_tuple)] = v

    return param_tuple_new, state_tuple_new


for lr, betas, eps, weight_decay in args:
    _test_optimizer(target, dev, np_func, Adam, lr, betas, eps, weight_decay)