Numerical Trainer#

import pytest
import numpy as np

import tvm
from tvm import relax, TVMError
from tvm.relax.training import SetupTrainer, Trainer
from tvm.relax.training.optimizer import SGD, Adam
from tvm.relax.training.loss import MSELoss
from tvm.script import ir as I, relax as R


def _get_backbone():
    @I.ir_module
    class MLP:
        I.module_attrs({"param_num": 2, "state_num": 0})

        @R.function
        def backbone(
            x: R.Tensor((1, 10), "float32"),
            w0: R.Tensor((10, 5), "float32"),
            b0: R.Tensor((5,), "float32"),
        ):
            with R.dataflow():
                lv0 = R.matmul(x, w0)
                lv1 = R.add(lv0, b0)
                out = R.nn.relu(lv1)
                R.output(out)
            return out

    return MLP


def _make_dataset():
    N = 100
    return [[np.ones((1, 10)).astype(np.float32), np.array([[0, 0, 1, 0, 0]], np.float32)]] * N

Test the Backbone Network#

backbone = _get_backbone()
pred_sinfo = relax.TensorStructInfo((1, 5), "float32")
target_sinfo = relax.TensorStructInfo((1, 5), "float32")

setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    Adam(0.01),
    [pred_sinfo, target_sinfo],
)

target = "llvm"
dev = tvm.device(target, 0)
train_mod = setup_trainer(backbone)
ex = tvm.compile(train_mod, target)
vm = relax.VirtualMachine(ex, dev, profile=True)

trainer = Trainer(train_mod, vm, dev, False)
trainer.zero_init_params()
trainer.xaiver_uniform_init_params()

dataset = _make_dataset()
trainer.predict(dataset[0][0])
trainer.update(dataset[0][0], dataset[0][1])
trainer.profile_adjoint(dataset[0][0], dataset[0][1])
Name                          Duration (us)  Percent  Device  Count                                      Argument Shapes  
vm.builtin.reshape                     2.99     1.27    cpu0      1                                       float32[1, 10]  
broadcast_to                           2.43     1.03    cpu0      1                             float32[], float32[1, 5]  
multiply                               2.28     0.97    cpu0      3          float32[1, 5], float32[1, 5], float32[1, 5]  
where                                  2.19     0.93    cpu0      1  bool[1, 5], float32[], float32[1, 5], float32[1, 5]  
vm.builtin.check_tensor_info           2.07     0.88    cpu0      1                                       float32[1, 10]  
vm.builtin.match_shape                 1.30     0.55    cpu0      1                                       float32[1, 10]  
matmul                                 1.20     0.51    cpu0      1        float32[1, 10], float32[10, 5], float32[1, 5]  
vm.builtin.check_tensor_info           1.16     0.49    cpu0      1                                       float32[10, 5]  
matmul1                                1.10     0.47    cpu0      1        float32[10, 1], float32[1, 5], float32[10, 5]  
vm.builtin.make_tuple                  1.02     0.43    cpu0      1                           float32[10, 5], float32[5]  
vm.builtin.make_tuple                  0.97     0.41    cpu0      1                                            float32[]  
vm.builtin.match_shape                 0.95     0.40    cpu0      1                                       float32[10, 5]  
vm.builtin.match_shape                 0.92     0.39    cpu0      1                                        float32[1, 5]  
add                                    0.90     0.39    cpu0      1             float32[1, 5], float32[5], float32[1, 5]  
vm.builtin.match_shape                 0.89     0.38    cpu0      1                                           float32[5]  
vm.builtin.check_tensor_info           0.89     0.38    cpu0      1                                        float32[1, 5]  
collapse_sum                           0.86     0.37    cpu0      1                            float32[1, 5], float32[5]  
less                                   0.79     0.33    cpu0      1                            float32[1, 5], bool[1, 5]  
add1                                   0.76     0.32    cpu0      1          float32[1, 5], float32[1, 5], float32[1, 5]  
subtract                               0.74     0.32    cpu0      1          float32[1, 5], float32[1, 5], float32[1, 5]  
relu                                   0.70     0.30    cpu0      1                         float32[1, 5], float32[1, 5]  
vm.builtin.check_tensor_info           0.69     0.30    cpu0      1                                           float32[5]  
sum                                    0.67     0.29    cpu0      1                             float32[1, 5], float32[]  
ones                                   0.56     0.24    cpu0      1                                            float32[]  
----------                                                                                                                
Sum                                   29.02    12.37             26                                                       
Total                                234.51             cpu0      1                                                       

Configuration
-------------
Number of threads: 24
Executor: VM
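
As a side note, the "Configuration / Executor: VM" footer above suggests that profile_adjoint returns a tvm.runtime.profiling.Report. Assuming that is the case, the report can also be exported programmatically rather than just printed; a minimal sketch:

report = trainer.profile_adjoint(dataset[0][0], dataset[0][1])
print(report.csv())         # per-call metrics as CSV text (assumes a profiling.Report)
with open("adjoint_profile.json", "w") as f:
    f.write(report.json())  # machine-readable dump of the same report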

Test Numerical Consistency#

backbone = _get_backbone()
pred_sinfo = relax.TensorStructInfo((1, 5), "float32")
target_sinfo = relax.TensorStructInfo((1, 5), "float32")

setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    SGD(0.01),
    [pred_sinfo, target_sinfo],
)

train_mod = setup_trainer(backbone)
ex = tvm.compile(train_mod, target)
vm = relax.VirtualMachine(ex, dev)

trainer = Trainer(train_mod, vm, dev, False)
trainer.zero_init_params()

dataset = _make_dataset()
for _ in range(2):
    for x, label in dataset:
        loss = trainer.update(x, label)
np.testing.assert_allclose(loss.numpy(), 3.1974423e-14)

result = trainer.predict(dataset[0][0])
result_expected = np.array([[0, 0, 0.9999998, 0, 0]], np.float32)
np.testing.assert_allclose(result.numpy(), result_expected)
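
The expected values above can be cross-checked with a hand-written NumPy reference of the same training loop. The sketch below assumes that MSELoss(reduction="sum") is the plain sum of squared errors, that SGD(0.01) performs w -= lr * grad with no momentum or weight decay, and that the relu gradient passes through at exactly zero (relevant here because of the zero initialization); exact agreement is limited by float32 rounding, hence the tolerances.

# NumPy reference of the same two epochs of forward / backward / SGD updates (sketch; see assumptions above)
lr = 0.01
w0 = np.zeros((10, 5), np.float32)
b0 = np.zeros((5,), np.float32)
for _ in range(2):
    for x, label in dataset:
        lv1 = x @ w0 + b0
        out = np.maximum(lv1, 0.0)     # relu forward
        d_out = 2.0 * (out - label)    # d sum((out - label)**2) / d out
        d_lv1 = d_out * (lv1 >= 0)     # relu backward, gradient kept at lv1 == 0
        w0 -= lr * (x.T @ d_lv1)       # matmul gradient w.r.t. w0
        b0 -= lr * d_lv1.sum(axis=0)   # add gradient collapsed to the bias shape
ref = np.maximum(dataset[0][0] @ w0 + b0, 0.0)
np.testing.assert_allclose(result.numpy(), ref, rtol=1e-5, atol=1e-6)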

Load Exported Parameters#

backbone = _get_backbone()
pred_sinfo = relax.TensorStructInfo((1, 5), "float32")
target_sinfo = relax.TensorStructInfo((1, 5), "float32")

setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    SGD(0.01),
    [pred_sinfo, target_sinfo],
)

train_mod = setup_trainer(backbone)
ex = tvm.compile(train_mod, target)
vm = relax.VirtualMachine(ex, dev)

trainer = Trainer(train_mod, vm, dev, False)
trainer.xaiver_uniform_init_params()

dataset = _make_dataset()
for x, label in dataset:
    trainer.update(x, label)

param_dict = trainer.export_params()
assert "w0" in param_dict
assert "b0" in param_dict

trainer1 = Trainer(train_mod, vm, dev, False)
trainer1.load_params(param_dict)

x_sample = dataset[np.random.randint(len(dataset))][0]
np.testing.assert_allclose(
    trainer.predict(x_sample).numpy(), trainer1.predict(x_sample).numpy()
)
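
The exported dictionary can also be persisted to disk and restored later. A minimal sketch, assuming the values in param_dict are tvm.nd.NDArray; the file name and the round-trip through NumPy are illustrative, not part of the Trainer API:

np.savez("mlp_params.npz", **{k: v.numpy() for k, v in param_dict.items()})

loaded = np.load("mlp_params.npz")
trainer2 = Trainer(train_mod, vm, dev, False)
trainer2.load_params({k: tvm.nd.array(loaded[k], dev) for k in loaded.files})
np.testing.assert_allclose(
    trainer.predict(x_sample).numpy(), trainer2.predict(x_sample).numpy()
)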

Test Setup Errors#

backbone = _get_backbone()
pred_sinfo = relax.TensorStructInfo((1, 5), "float32")
target_sinfo = relax.TensorStructInfo((1, 5), "float32")

setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    SGD(0.01),
    [pred_sinfo, target_sinfo],
)

train_mod = setup_trainer(backbone)
ex = tvm.compile(train_mod, target)
vm = relax.VirtualMachine(ex, dev)

trainer = Trainer(train_mod, vm, dev, False)

dataset = _make_dataset()
# parameters have not been initialized yet, so both calls should raise
with pytest.raises(TVMError):
    trainer.predict(dataset[0][0])
with pytest.raises(TVMError):
    trainer.update(dataset[0][0], dataset[0][1])
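
Once the parameters are initialized, the same calls succeed; for example (a sketch reusing the APIs shown above):

trainer.xaiver_uniform_init_params()
trainer.predict(dataset[0][0])
trainer.update(dataset[0][0], dataset[0][1])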