Training a Model with the Relax Training APIs#

There is growing community interest in training models with TVM. As TVM's next-generation graph-level intermediate representation (IR), Relax also needs to support model training.

A complete training workflow has been built on Relax, consisting of:

  • an automatic differentiation tool based on source-code transformation

  • an optimizer abstraction and implementations of common optimizers

  • a loss function abstraction and common loss functions

  • an easy-to-use Trainer API that ties these components together

These training APIs serve a variety of needs:

  • training models from scratch, using TVM's compilation strengths to speed up training

  • fine-tuning models on device with TVM

  • deploying the training process to the various devices TVM supports, such as FPGAs and Raspberry Pis

This tutorial demonstrates how to use the training APIs to:

  1. train a model from scratch with the high-level Trainer API

  2. train with the lower-level automatic differentiation, optimizer, and loss function APIs

  3. examine how the automatic differentiation system is implemented at the source level

We will train an MLP model on the Fashion MNIST dataset; the same approach applies to most common models.

Preparation#

First, import the necessary dependencies and load the dataset.

import numpy as np
import tvm
from tvm.relax.training.loss import CrossEntropyLoss
from tvm.relax.training.setup_trainer import SetupTrainer
from tvm.relax.training.trainer import Trainer
from tvm import relax
from tvm.script import ir as I, relax as R
from tvm.relax.transform import LegalizeOps
from tvm.relax.training.optimizer import SGD

batch_size = 64

We will train the model on the Fashion-MNIST dataset. The following code uses torchvision (PyTorch's computer-vision library) to download and preprocess the data.

Note that PyTorch is used only for data loading; data read from the PyTorch DataLoader is converted to NumPy arrays during training.

import torch
import torch.utils.data
import torchvision
import torchvision.transforms as Tr
import torch.nn.functional as Func

train_data = torchvision.datasets.FashionMNIST(
    root=".temp",
    train=True,
    download=True,
    transform=Tr.Compose([Tr.ToTensor(), Tr.Lambda(torch.flatten)]),
    target_transform=lambda x:Func.one_hot(torch.tensor(x), 10).float()
)
test_data = torchvision.datasets.FashionMNIST(
    root=".temp",
    train=False,
    download=True,
    transform=Tr.Compose([Tr.ToTensor(), Tr.Lambda(torch.flatten)]),
    target_transform=lambda x:Func.one_hot(torch.tensor(x), 10).float()
)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

Take a sample from the data loader: each sample in Fashion MNIST is a 28×28 grayscale image belonging to one of 10 clothing categories.

import matplotlib.pyplot as plt

img, label = next(iter(train_loader))
img = img[0].reshape(1, 28, 28).numpy()
plt.figure()
plt.imshow(img[0])
plt.colorbar()
plt.grid(False)
plt.show()
print("Class:", class_names[label[0].argmax()])
Class: Ankle boot

Model Definition#

We will classify the images with a three-layer multilayer perceptron (MLP). First, define its backbone:

@tvm.script.ir_module
class MLP:
    I.module_attrs({"param_num": 6, "state_num": 0})
    @R.function
    def backbone(
        x: R.Tensor((batch_size, 784), "float32"),
        w0: R.Tensor((784, 128), "float32"),
        b0: R.Tensor((128,), "float32"),
        w1: R.Tensor((128, 128), "float32"),
        b1: R.Tensor((128,), "float32"),
        w2: R.Tensor((128, 10), "float32"),
        b2: R.Tensor((10,), "float32"),
    ) -> R.Tensor((batch_size, 10), "float32"):
        with R.dataflow():
            lv0 = R.matmul(x, w0)
            lv1 = R.add(lv0, b0)
            lv2 = R.nn.relu(lv1)
            lv3 = R.matmul(lv2, w1)
            lv4 = R.add(lv3, b1)
            lv5 = R.nn.relu(lv4)
            lv6 = R.matmul(lv5, w2)
            out = R.add(lv6, b2)
            R.output(out)
        return out
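To make the computation concrete, here is a plain NumPy sketch (not TVM code; backbone_numpy is a name invented for illustration) of what the backbone function above computes: three dense layers with ReLU activations in between.

```python
import numpy as np

batch_size = 64

def relu(x):
    return np.maximum(x, 0.0)

def backbone_numpy(x, w0, b0, w1, b1, w2, b2):
    # matmul -> add -> relu twice, then a final matmul -> add for the logits
    lv2 = relu(x @ w0 + b0)
    lv5 = relu(lv2 @ w1 + b1)
    return lv5 @ w2 + b2

rng = np.random.default_rng(0)
x = rng.standard_normal((batch_size, 784)).astype("float32")
shapes = [(784, 128), (128,), (128, 128), (128,), (128, 10), (10,)]
params = [0.01 * rng.standard_normal(s).astype("float32") for s in shapes]
out = backbone_numpy(x, *params)
print(out.shape)  # (64, 10)
```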

Method 1: Using the Trainer API#

Trainer Structure#

The simpler way to train a given model is to use the Trainer API, which provides the core interfaces for updating parameters and running model inference.

To build a trainer, first create the optimizer and the loss function. At this stage we only need to specify hyperparameters such as the learning rate and the reduction method; the model parameters are not required yet.

loss = CrossEntropyLoss(reduction="sum")
opt = SGD(0.01, weight_decay=0.01)
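For reference, SGD with L2 weight decay, as configured above, updates each parameter roughly as follows. This is a minimal NumPy sketch of the standard formulation, not the actual Relax implementation:

```python
import numpy as np

lr, weight_decay = 0.01, 0.01

def sgd_step(param, grad):
    # p <- p - lr * (g + weight_decay * p)
    return param - lr * (grad + weight_decay * param)

p = np.array([1.0, -2.0], dtype="float32")
g = np.array([0.5, 0.5], dtype="float32")
p_new = sgd_step(p, g)
print(p_new)
```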

Next, build the SetupTrainer. This is a helper class for Trainer and is essentially a pass that transforms the backbone module into a complete, well-formed trainer module.

The transformed module contains the following methods:

  • predict: the model prediction method (provided by the input module)

  • loss: computes the specified loss between the predictions and the ground-truth labels

  • loss_adjoint: computes the loss value and the adjoints (gradients) of the parameters

  • update_params: takes the parameters, the parameter gradients, and the optimizer states as input, and returns the updated parameters together with the new optimizer states. This method carries a function attribute named optim_state, the initial state of the specified optimizer.

Building a SetupTrainer requires:

  1. the loss function

  2. the optimizer

  3. the struct_info (structure information) of the model outputs and the labels

out_sinfo = relax.TensorStructInfo((batch_size, 10), "float32")
label_sinfo = relax.TensorStructInfo((batch_size, 10), "float32")

setup_trainer = SetupTrainer(loss, opt, [out_sinfo, label_sinfo])

Finally, introduce the Trainer. Trainer is a runtime component: it builds and runs the module that SetupTrainer has produced from the backbone, and internally maintains the runtime values of the parameters.

Building a Trainer requires:

  1. the trainer module produced by SetupTrainer

  2. the virtual machine built from that module

  3. the device to run on

The first param_num arguments of the backbone function (6 here, as declared in the module attributes) are treated as the model parameters; these are the values the optimizer updates during training.

target = "llvm"
dev = tvm.device(target, 0)
train_mod = setup_trainer(MLP)
ex = tvm.compile(train_mod, target)
vm = relax.VirtualMachine(ex, dev, profile=True)
trainer = Trainer(train_mod, vm, dev)

Training#

Once the trainer is built, we can run a standard training loop on it. We will initialize the parameters randomly and train for 5 epochs.

Trainer provides the xaiver_uniform_init_params method (note: the method name misspells "Xavier" in the API itself), which initializes all parameters from a Xavier uniform distribution. To customize parameter initialization, call instead:

  • trainer.load_params(extern_param_dict: Dict[str, Union[np.ndarray, NDArray]]) to load given parameters

  • trainer.export_params() -> Dict[str, NDArray] to export the current parameters

Parameter updates are performed by the update_params method, which internally:

  1. runs the forward pass to obtain the model output and the loss value

  2. computes the gradients of the parameters

  3. updates the parameters according to the optimizer algorithm

  4. returns the current loss value to the caller

The predict method is designed for inference: it takes a batch of input features and returns the predictions, i.e. the output of the backbone network.
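The four steps above can be sketched end to end in NumPy on a tiny softmax-regression model. Everything here, including the name update_params_sketch, is illustration code under standard formulas, not the Trainer implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 8, 4, 3
W = 0.1 * rng.standard_normal((d, k)).astype("float32")
x = rng.standard_normal((n, d)).astype("float32")
y = np.eye(k, dtype="float32")[rng.integers(0, k, n)]  # one-hot labels
lr = 0.01

def update_params_sketch(W, x, y):
    # 1. forward pass: logits and the sum-reduced cross-entropy loss
    z = x @ W
    z = z - z.max(axis=1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    loss = -(y * log_softmax).sum()
    # 2. gradient: for sum-reduced cross entropy, dL/dlogits = softmax - y
    grad_W = x.T @ (np.exp(log_softmax) - y)
    # 3. parameter update (plain SGD here)
    W = W - lr * grad_W
    # 4. return the loss to the caller
    return W, loss

W, loss0 = update_params_sketch(W, x, y)
W, loss1 = update_params_sketch(W, x, y)
print(loss0 > loss1)  # the loss decreases step over step
```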

trainer.xaiver_uniform_init_params()


epochs = 5
log_interval = 200


for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        loss = trainer.update_params(data.numpy(), target.numpy())

        if batch_idx % log_interval == 0 or batch_idx == len(train_loader):
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.numpy():.2f}")

    total, correct = 0, 0
    for data, target in test_loader:
        predict = trainer.predict(data.numpy()) # batch_size * 10
        total += len(data)
        correct += np.sum(predict.numpy().argmax(1) == target.numpy().argmax(1))

    print(f"Train Epoch: {epoch} Accuracy on test dataset: {100.0 * correct / total:.2f}%")

Why separate Trainer and SetupTrainer?#

This design separates compile-time from run-time responsibilities:

  1. Compile-time components (SetupTrainer and everything before it):

    • build the complete computation graph (IRModule)

    • perform all static analysis and optimization

    • produce deployable, device-independent computation logic

  2. Run-time component (Trainer):

    • receives the IRModule produced at compile time

    • manages dynamic updates of the model parameters

    • maintains transient state during training

This separation lets TVM:

  • compile and optimize the computation graph on a server

  • deploy the optimized IRModule to edge devices

  • run only the necessary parameter updates on resource-constrained devices

This is the key design decision behind TVM's "compile once, run anywhere" capability.

Method 2: Using the Low-Level Training APIs#

We can also build and run the IRModule directly through the low-level training APIs, which consist of:

  • the loss function library

  • the optimizer library

  • the automatic differentiation pass

Loss Functions#

TVM provides a rich set of loss functions in the tvm.relax.training.loss module, including:

  • CrossEntropyLoss

  • L1Loss

  • MSELoss (mean squared error), among others

You can also define custom loss functions by subclassing tvm.relax.training.loss.Loss.

A loss class is instantiated with only its hyperparameters. Its __call__() method takes the struct_info (structure information) of the model outputs and the labels, and generates the corresponding Relax loss function:

func = CrossEntropyLoss(reduction="sum")(out_sinfo, label_sinfo)
print(func)
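Conceptually, CrossEntropyLoss with reduction="sum" composes log_softmax with nll_loss; for one-hot targets that amounts to the following (an illustrative NumPy sketch, not the generated Relax function):

```python
import numpy as np

def cross_entropy_sum(predictions, targets):
    # numerically stable log_softmax over the class axis
    z = predictions - predictions.max(axis=1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    # nll_loss with one-hot targets and "sum" reduction
    return -(targets * log_softmax).sum()

pred = np.array([[2.0, 1.0, 0.0]], dtype="float32")
target = np.array([[1.0, 0.0, 0.0]], dtype="float32")
print(cross_entropy_sum(pred, target))  # ≈ 0.4076
```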

Due to the requirements of the automatic differentiation pass, the backbone function and the loss function must be fused into one function. The relax.training.utils.append_loss utility performs this fusion:

MLP["loss"] = relax.training.utils.append_loss(MLP["backbone"], func)
MLP.show()

Computing Gradients#

To optimize the model parameters, we need their gradients. TVM provides the automatic differentiation pass relax.transform.Gradient to compute them.

The automatic differentiation (AD) system is the core of the training workflow and is implemented by source-code transformation. The current version imposes the following restrictions on the input function:

  1. Single dataflow block: the function must contain exactly one dataflow block

  2. Operator support: only basic Relax operators, such as arithmetic and tuple operations, are supported
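To illustrate what source-code transformation means here, the sketch below hand-writes, under standard reverse-mode rules, the kind of adjoint code such a pass would generate for a block computing out = matmul(x, w) + b with loss = sum(out). This is illustrative NumPy, not the pass's actual output:

```python
import numpy as np

def forward_and_adjoint(x, w, b):
    # forward bindings, kept as in the original dataflow block
    lv = x @ w
    out = lv + b
    loss = out.sum()
    # reverse pass: one adjoint binding per forward binding, emitted in reverse order
    out_adj = np.ones_like(out)   # d loss / d out for a sum reduction
    b_adj = out_adj.sum(axis=0)   # "+ b" broadcasts, so its adjoint sums over the batch axis
    lv_adj = out_adj
    w_adj = x.T @ lv_adj          # adjoint rules for matmul
    x_adj = lv_adj @ w.T
    return loss, (x_adj, w_adj, b_adj)

x = np.ones((2, 3), dtype="float32")
w = np.ones((3, 4), dtype="float32")
b = np.zeros(4, dtype="float32")
loss, (x_adj, w_adj, b_adj) = forward_and_adjoint(x, w, b)
print(loss)  # 24.0
```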

Gradient takes three inputs:

  • the GlobalVar of the target function

  • the parameter variables that require gradients

  • the input IRModule

It returns a new IRModule containing the gradient computation.

params = MLP["loss"].params[:6]

MLP = relax.transform.Gradient(
    MLP.get_global_var("loss"),
    require_grads=params
)(MLP)
MLP.show()

Optimizers#

TVM provides several classic optimizers in the relax.training.optimizer module, including:

  • plain SGD

  • SGD with momentum

  • Adam

You can also implement custom optimizers by subclassing relax.training.optimizer.Optimizer.

An optimizer instance is created with only its hyperparameters (such as the learning rate). It is then initialized with init(), which takes:

  • a single Relax variable, or

  • a list of Relax variables (the variable nodes in the computation graph)

This initializes the optimizer states. After initialization, the optimizer can be used in two ways:

  1. call get_function() to obtain the corresponding Relax optimizer function

  2. attach that function to an existing IRModule

opt = SGD(0.1).init(params)
MLP["SGD"] = opt.get_function()
print(MLP["SGD"])

Training#

With the IRModule complete, we can train the model. We need to:

  1. legalize the IRModule

  2. compile it into an executable module

  3. prepare the input data:

# Legalize and build the module
lowered_mod = LegalizeOps()(MLP)
ex = tvm.compile(lowered_mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

from tvm.runtime.container import tuple_object

def _get_shape_as_int_list(var):
    return [int(val) for val in var.struct_info.shape]

params_list = [tvm.nd.array(np.ones(_get_shape_as_int_list(i), "float32")) for i in params]
param_input_tuple = tuple_object(params_list)

x_input, y_input = next(iter(train_loader))
x_input = tvm.nd.array(x_input.numpy())
y_input = tvm.nd.array(y_input.numpy())

# The input should be (*param_input_tuple, x_input, y_input).
# At TVM runtime, arguments must be TVM NDArrays or TVM runtime ADT objects.

This demo runs only a single training step; multi-step training follows the same pattern.

How the core components interact:

  1. The adjoint function (generated by the automatic differentiation pass):

    • input: the backbone inputs + the ground-truth labels

    • output: the loss value + a tuple of parameter gradients

  2. The optimizer function (built by the optimizer class):

    • input: a tuple of parameters + a tuple of gradients + a tuple of optimizer states

    • output: the updated parameter tuple + the new optimizer state tuple

The optimizer state is available as opt.state. It holds the bookkeeping of the optimization process, such as:

  • the number of training steps taken (steps)

  • momentum buffers

  • adaptive learning-rate statistics (e.g. the first- and second-moment estimates in Adam)
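The purely functional optimizer contract described above, (params, grads, states) -> (new params, new states), can be sketched for SGD with momentum. This is illustrative Python with invented names such as sgd_momentum, not the Relax optimizer function:

```python
import numpy as np

lr, momentum = 0.01, 0.9

def sgd_momentum(params, grads, state):
    # state = (num_steps, velocity_0, velocity_1, ...) -- purely functional:
    # the caller feeds the returned state back in on the next step
    num_steps, velocities = state[0], state[1:]
    new_velocities = [momentum * v + g for v, g in zip(velocities, grads)]
    new_params = [p - lr * v for p, v in zip(params, new_velocities)]
    return new_params, (num_steps + 1, *new_velocities)

params = [np.zeros(3, dtype="float32")]
grads = [np.ones(3, dtype="float32")]
state = (0, np.zeros(3, dtype="float32"))  # step count and one zero velocity buffer

params, state = sgd_momentum(params, grads, state)
params, state = sgd_momentum(params, grads, state)
print(state[0])   # 2 steps taken
print(params[0])  # moved by -lr*1, then by -lr*1.9
```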

# forward and find the gradient
loss, param_grad_tuple = vm["loss_adjoint"](*param_input_tuple, x_input, y_input)
# update parameters
param_input_tuple, opt.state = vm["SGD"](param_input_tuple, param_grad_tuple, opt.state)

Print the results:

print(loss.numpy())
print(len(param_input_tuple), len(param_grad_tuple))
print(param_input_tuple[0])
print(param_grad_tuple[0])