Training Models with the Relax Training APIs#
Interest in using TVM for model training has been growing in the community. As TVM's next-generation graph-level intermediate representation (IR), Relax also needs to support training models.
A complete training workflow has been built on Relax, consisting of:
an automatic differentiation tool based on source-code transformation
an optimizer abstraction and implementations of common optimizers
a loss-function abstraction and common loss functions
an easy-to-use Trainer API that ties these components together
These training APIs serve a variety of needs:
training models from scratch, using TVM's compilation strengths to speed up training
fine-tuning models on-device with TVM
deploying the training process to the many devices TVM supports (such as FPGAs and Raspberry Pi)
This tutorial demonstrates how to use the training APIs to:
train a model from scratch with the high-level Trainer API
train with the lower-level automatic differentiation, optimizer, and loss-function APIs
dig into the source-level implementation of the automatic differentiation system
We will train an MLP model on the Fashion MNIST dataset; the same approach applies to most common models.
Preparation#
First, import the necessary dependencies and load the dataset.
import numpy as np
import tvm
from tvm.relax.training.loss import CrossEntropyLoss
from tvm.relax.training.setup_trainer import SetupTrainer
from tvm.relax.training.trainer import Trainer
from tvm import relax
from tvm.script import ir as I, relax as R
from tvm.relax.transform import LegalizeOps
from tvm.relax.training.optimizer import SGD
batch_size = 64
We will train the model on the Fashion-MNIST dataset. The code below uses torchvision (PyTorch's computer-vision library) to download and preprocess the data.
Note that PyTorch is used only for data loading; batches loaded from the PyTorch DataLoader are converted to NumPy arrays during training.
import torch
import torch.utils.data
import torchvision
import torchvision.transforms as Tr
import torch.nn.functional as Func
train_data = torchvision.datasets.FashionMNIST(
root=".temp",
train=True,
download=True,
transform=Tr.Compose([Tr.ToTensor(), Tr.Lambda(torch.flatten)]),
target_transform=lambda x:Func.one_hot(torch.tensor(x), 10).float()
)
test_data = torchvision.datasets.FashionMNIST(
root=".temp",
train=False,
download=True,
transform=Tr.Compose([Tr.ToTensor(), Tr.Lambda(torch.flatten)]),
target_transform=lambda x:Func.one_hot(torch.tensor(x), 10).float()
)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
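The target_transform above turns each integer class label into a one-hot float vector. As a quick illustration, here is a plain NumPy sketch of the same conversion (the helper name is ours, independent of torch):

```python
import numpy as np

# NumPy equivalent of the target_transform above: class index -> one-hot float vector
def one_hot(idx, num_classes=10):
    vec = np.zeros(num_classes, dtype="float32")
    vec[idx] = 1.0
    return vec

print(one_hot(9))  # index 9 ("Ankle boot") -> a vector with a single 1.0 at position 9
```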
Let's take one sample from the data loader. Each sample in the Fashion MNIST dataset is a 28×28 grayscale image paired with one of 10 class labels:
import matplotlib.pyplot as plt
img, label = next(iter(train_loader))
img = img[0].reshape(1, 28, 28).numpy()
plt.figure()
plt.imshow(img[0])
plt.colorbar()
plt.grid(False)
plt.show()
print("Class:", class_names[label.argmax()])

Class: Ankle boot
Model Definition#
We will use a three-layer perceptron (MLP) for image classification. First, define the backbone of the network:
@tvm.script.ir_module
class MLP:
    I.module_attrs({"param_num": 6, "state_num": 0})

    @R.function
    def backbone(
        x: R.Tensor((batch_size, 784), "float32"),
        w0: R.Tensor((784, 128), "float32"),
        b0: R.Tensor((128,), "float32"),
        w1: R.Tensor((128, 128), "float32"),
        b1: R.Tensor((128,), "float32"),
        w2: R.Tensor((128, 10), "float32"),
        b2: R.Tensor((10,), "float32"),
    ) -> R.Tensor((batch_size, 10), "float32"):
        with R.dataflow():
            lv0 = R.matmul(x, w0)
            lv1 = R.add(lv0, b0)
            lv2 = R.nn.relu(lv1)
            lv3 = R.matmul(lv2, w1)
            lv4 = R.add(lv3, b1)
            lv5 = R.nn.relu(lv4)
            lv6 = R.matmul(lv5, w2)
            out = R.add(lv6, b2)
            R.output(out)
        return out
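For reference, here is what the backbone computes, written as plain NumPy (a sketch with randomly generated inputs, not part of the Relax module):

```python
import numpy as np

def mlp_forward(x, w0, b0, w1, b1, w2, b2):
    """NumPy mirror of the Relax backbone: two ReLU layers plus a linear head."""
    h0 = np.maximum(x @ w0 + b0, 0.0)   # lv0..lv2
    h1 = np.maximum(h0 @ w1 + b1, 0.0)  # lv3..lv5
    return h1 @ w2 + b2                 # lv6 and out

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 784)).astype("float32")
shapes = [(784, 128), (128,), (128, 128), (128,), (128, 10), (10,)]
params = [rng.standard_normal(s).astype("float32") * 0.01 for s in shapes]
out = mlp_forward(x, *params)
print(out.shape)  # (64, 10): one score per class for each sample in the batch
```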
Method 1: Using the Trainer API#
Trainer Structure#
The simpler way to train a given model is the Trainer API, which provides the core interfaces for parameter updates and model inference.
To build a trainer, first create the optimizer and the loss function. At this stage we only need to specify the hyperparameters (such as the learning rate and the reduction method); no model parameters are required yet.
loss = CrossEntropyLoss(reduction="sum")
opt = SGD(0.01, weight_decay=0.01)
Next, construct the SetupTrainer. It is a helper class for Trainer and is essentially a transformation pass that turns the backbone module into a complete, normalized trainer module.
The transformed module contains the following methods:
predict: the model's prediction method (provided by the input module)
loss: computes the specified loss between the predictions and the ground-truth labels
loss_adjoint: computes the loss value and the adjoints (gradients) of the parameters
update_params: takes the parameters, the parameter gradients, and the optimizer state as input, and returns the updated parameters and the new optimizer state. It carries a function attribute named optim_state holding the initial state of the chosen optimizer.
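Conceptually, update_params performs one step of the loop below. This NumPy sketch uses a toy linear model with a mean-squared-error loss; all names are illustrative, not the Relax API:

```python
import numpy as np

def update_params_sketch(w, x, y, lr=0.1):
    pred = x @ w                               # 1. forward pass...
    loss = ((pred - y) ** 2).mean()            #    ...and the loss value
    grad = 2.0 * x.T @ (pred - y) / y.size     # 2. gradient of the loss w.r.t. w
    w = w - lr * grad                          # 3. optimizer update (plain SGD here)
    return w, loss                             # 4. hand the loss back to the caller

rng = np.random.default_rng(0)
x, y = rng.standard_normal((8, 3)), rng.standard_normal((8, 1))
w = np.zeros((3, 1))
losses = []
for _ in range(200):
    w, loss = update_params_sketch(w, x, y)
    losses.append(loss)
print(losses[0] > losses[-1])  # True: the loss decreases over the steps
```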
Constructing a SetupTrainer requires:
the loss function
the optimizer
the struct_info (structural information) of the model output and of the labels
out_sinfo = relax.TensorStructInfo((batch_size, 10), "float32")
label_sinfo = relax.TensorStructInfo((batch_size, 10), "int64")
setup_trainer = SetupTrainer(loss, opt, [out_sinfo, label_sinfo])
Finally, bring in the Trainer. Trainer is the runtime component: once SetupTrainer has configured the structure of the backbone module, Trainer builds and runs the module, while maintaining the runtime values of the parameters internally.
Constructing a Trainer requires:
the backbone module
the number of parameters
the SetupTrainer instance
The leading arguments of the backbone function are the model inputs; the trailing param_num arguments are the parameters to be trained.
target = "llvm"
dev = tvm.device(target, 0)
train_mod = setup_trainer(MLP)
ex = tvm.compile(train_mod, target)
vm = relax.VirtualMachine(ex, dev, profile=True)
---------------------------------------------------------------------------
TVMError Traceback (most recent call last)
Cell In[7], line 3
1 target = "llvm"
2 dev = tvm.device(target, 0)
----> 3 train_mod = setup_trainer(MLP)
4 ex = tvm.compile(train_mod, target)
5 vm = relax.VirtualMachine(ex, dev, profile=True)
File /media/pc/data/lxw/ai/tvm/python/tvm/ir/transform.py:238, in Pass.__call__(self, mod)
224 def __call__(self, mod):
225 """Execute the pass. Note that for sequential pass, the dependency among
226 different passes will be resolved in the backend.
227
(...)
236 The updated module after applying this pass.
237 """
--> 238 return _ffi_transform_api.RunPass(self, mod)
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:339, in tvm._ffi._cy3.core.PackedFuncBase.__call__()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:270, in tvm._ffi._cy3.core.FuncCall()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:259, in tvm._ffi._cy3.core.FuncCall3()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/base.pxi:185, in tvm._ffi._cy3.core.CHECK_CALL()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:468, in raise_last_ffi_error()
462 # The exception PyObject may contain a large amount of state,
463 # including all stack frames that may be inspected in a later
464 # PDB post-mortem. Therefore, we must make sure to remove the
465 # underlying PyObject* from the C++ side after we retrieve it.
466 _LIB.TVMDropLastPythonError()
--> 468 raise py_err
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:56, in tvm._ffi._cy3.core.tvm_callback()
File /media/pc/data/lxw/ai/tvm/python/tvm/ir/transform.py:307, in _wrap_class_module_pass.<locals>.PyModulePass.__init__.<locals>._pass_func(mod, ctx)
306 def _pass_func(mod, ctx):
--> 307 return inst.transform_module(mod, ctx)
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/training/setup_trainer.py:179, in SetupTrainer.transform_module(self, mod, ctx)
174 """Transform the backbone module into a trainer module."""
175 self._check_well_formed(mod)
177 mod = AppendLoss(
178 self.BACKBONE_FUNC,
--> 179 self._loss(*self._loss_args), # type: ignore
180 self._loss.num_backbone_outputs,
181 self.BACKBONE_LOSS_FUNC,
182 )(mod)
184 # Decompose batch_norm operator, which behaves differently in inference and training stages
185 mod = DecomposeOpsForInference(self.BACKBONE_FUNC)(mod)
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/training/loss.py:287, in CrossEntropyLoss.__call__(self, predictions, targets, weights)
285 with bb.dataflow():
286 logits = bb.emit(log_softmax(predictions))
--> 287 loss = bb.emit_output(
288 nll_loss(logits, targets, weights, self._reduction, self.ignore_index)
289 )
290 bb.emit_func_output(loss)
292 return bb.get()[self._loss_name]
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/block_builder.py:585, in BlockBuilder.emit_output(self, output, name_hint)
569 """Emit output for the current dataflow block or function.
570
571 Parameters
(...)
582 The return variable which gets bound to the output.
583 """
584 output = self._normalize_python_tuple(output)
--> 585 return _ffi_api.BlockBuilderEmitOutput(self, output, name_hint)
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:339, in tvm._ffi._cy3.core.PackedFuncBase.__call__()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:270, in tvm._ffi._cy3.core.FuncCall()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:259, in tvm._ffi._cy3.core.FuncCall3()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/base.pxi:185, in tvm._ffi._cy3.core.CHECK_CALL()
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:1085, in operator()()
1083 TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput")
1084 .set_body_typed([](BlockBuilder builder, const Expr& output, String name_hint) {
-> 1085 return builder->EmitOutput(output, name_hint);
1086 });
1087
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:273, in tvm::relax::BlockBuilderImpl::EmitOutput(tvm::RelaxExpr, tvm::runtime::String)()
271 ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block.";
272
--> 273 return Emit(output, false, name_hint);
274 }
275
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:395, in tvm::relax::BlockBuilderImpl::Emit(tvm::RelaxExpr, bool, tvm::runtime::String)()
393 */
394 Var Emit(Expr expr, bool is_dataflow, String name_hint) {
--> 395 expr = this->Normalize(expr);
396
397 Var var = CreateVar(is_dataflow, name_hint);
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:533, in tvm::relax::Normalizer::Normalize(tvm::RelaxExpr const&)()
531
532 Expr Normalize(const Expr& expr) final {
--> 533 Expr normalized = this->VisitExpr(expr);
534 // Invariant:
535 // After Normalize: an Expr always have
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:616, in tvm::relax::Normalizer::VisitExpr(tvm::RelaxExpr const&)()
614 }
615 }
--> 616 return ExprFunctor::VisitExpr(expr);
617 }
618
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:664, in tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)()
662
663 if (!call->struct_info_.defined()) {
--> 664 auto inferred_sinfo = InferStructInfo(call);
665 UpdateStructInfo(call, inferred_sinfo);
666 }
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:847, in tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)()
845 ICHECK(op_map_infer_struct_info_.count(op))
846 << " Cannot find the FInferStructInfo attribute registered to op: " << op->name;
--> 847 return op_map_infer_struct_info_[op](call, GetRef<BlockBuilder>(this));
848 } else {
849 // derive using function parameters
File /media/pc/data/lxw/ai/tvm/src/relax/op/nn/nn.cc:798, in tvm::relax::InferStructInfoNLLLoss(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)()
796 int K_tgt = tgt_sinfo->ndim <= 1 ? 0 : tgt_sinfo->ndim - 1;
797 if (K != kUnknownNDim && K != K_tgt) {
--> 798 ctx->ReportFatal(Diagnostic::Error(call)
799 << "NLLLoss expects number of dimensions K inferred from different "
800 "arguments to be equal. However, K from predictions is "
File /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:157, in tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)()
155 // continue use the builder after an error is thrown to avoid state building up.
156 // in an interactive environment.
--> 157 LOG(FATAL) << diagnostic->message;
158 }
159
TVMError: Traceback (most recent call last):
8: operator()
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:1085
7: tvm::relax::BlockBuilderImpl::EmitOutput(tvm::RelaxExpr, tvm::runtime::String)
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:273
6: tvm::relax::BlockBuilderImpl::Emit(tvm::RelaxExpr, bool, tvm::runtime::String)
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:395
5: tvm::relax::Normalizer::Normalize(tvm::RelaxExpr const&)
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:533
4: tvm::relax::Normalizer::VisitExpr(tvm::RelaxExpr const&)
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:616
3: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:664
2: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:847
1: tvm::relax::InferStructInfoNLLLoss(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
at /media/pc/data/lxw/ai/tvm/src/relax/op/nn/nn.cc:798
0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
at /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:157
File "/media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc", line 157
TVMError: NLLLoss expects number of dimensions K inferred from different arguments to be equal. However, K from predictions is 0 while K from targets is 1
trainer = Trainer(MLP, 6, setup_trainer)
# build the IRModule in the trainer
trainer.build(target="llvm", device=tvm.cpu(0))
---------------------------------------------------------------------------
InternalError Traceback (most recent call last)
Cell In[21], line 1
----> 1 trainer = Trainer(Backbone, 6, setup_trainer)
2 # build the IRModule in the trainer
3 trainer.build(target="llvm", device=tvm.cpu(0))
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/training/trainer.py:87, in Trainer.__init__(self, train_mod, vm, device, zero_init_param_state)
84 self.vm = vm
85 self.device = device
---> 87 self._optim_state = [d.copyto(device) for d in train_mod.attrs["optim_state"]]
89 self._input_num = int(train_mod.attrs["input_num"])
90 self._param_num = int(train_mod.attrs["param_num"])
File /media/pc/data/lxw/ai/tvm/python/tvm/ir/attrs.py:115, in DictAttrs.__getitem__(self, k)
114 def __getitem__(self, k):
--> 115 return self._dict().__getitem__(k)
File /media/pc/data/lxw/ai/tvm/python/tvm/ir/container.py:62, in Map.__getitem__(self, k)
61 def __getitem__(self, k):
---> 62 return _ffi_api.MapGetItem(self, k)
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:339, in tvm._ffi._cy3.core.PackedFuncBase.__call__()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:270, in tvm._ffi._cy3.core.FuncCall()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:259, in tvm._ffi._cy3.core.FuncCall3()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/base.pxi:185, in tvm._ffi._cy3.core.CHECK_CALL()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:468, in raise_last_ffi_error()
462 # The exception PyObject may contain a large amount of state,
463 # including all stack frames that may be inspected in a later
464 # PDB post-mortem. Therefore, we must make sure to remove the
465 # underlying PyObject* from the C++ side after we retrieve it.
466 _LIB.TVMDropLastPythonError()
--> 468 raise py_err
File /media/pc/data/lxw/ai/tvm/src/runtime/container.cc:104, in operator()()
102 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
103 Object* ptr = static_cast<Object*>(args[0].value().v_handle);
--> 104 ICHECK(ptr->IsInstance<MapNode>());
105
106 auto* n = static_cast<const MapNode*>(ptr);
InternalError: Traceback (most recent call last):
0: operator()
at /media/pc/data/lxw/ai/tvm/src/runtime/container.cc:104
File "/media/pc/data/lxw/ai/tvm/src/runtime/container.cc", line 109
InternalError: Check failed: (it != n->end()) is false: cannot find the corresponding key in the Map
Training Loop#
Once the trainer is built, a standard training loop can run on top of it. We will initialize the parameters randomly and train for 5 epochs.
Trainer provides the xaiver_uniform_init_params method (note: "xaiver" is a misspelling of "Xavier"), which initializes all parameters with a Xavier uniform distribution. For custom parameter initialization, use:
trainer.load_params(extern_param_dict: Dict[str, Union[np.ndarray, NDArray]]) to load preset parameters
trainer.export_params() -> Dict[str, NDArray] to export the current parameters
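Xavier uniform initialization draws each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). A NumPy sketch of the scheme (the helper name is ours, not the Trainer API):

```python
import numpy as np

def xavier_uniform(shape, rng):
    # a = sqrt(6 / (fan_in + fan_out)), then sample uniformly from [-a, a]
    fan_in, fan_out = shape[0], shape[-1]
    a = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-a, a, size=shape).astype("float32")

rng = np.random.default_rng(0)
w = xavier_uniform((784, 128), rng)
bound = np.sqrt(6.0 / (784 + 128))
print(w.shape, bool(w.min() >= -bound and w.max() <= bound))  # (784, 128) True
```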
The update_params method performs the parameter update; internally it:
runs the forward pass to obtain the model output and the loss value
computes the gradients of the parameters
updates the parameters according to the optimizer's algorithm
returns the current loss value to the caller
The predict method is designed for inference: it takes a batch of input features and returns the predictions (the output of the backbone).
trainer.xaiver_uniform_init_params()
epochs = 5
log_interval = 200

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        loss = trainer.update_params(data.numpy(), target.numpy())
        if batch_idx % log_interval == 0 or batch_idx == len(train_loader):
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                  f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.numpy():.2f}")
    total, correct = 0, 0
    for data, target in test_loader:
        predict = trainer.predict(data.numpy())  # batch_size * 10
        total += len(data)
        correct += np.sum(predict.numpy().argmax(1) == target.numpy().argmax(1))
    print(f"Train Epoch: {epoch} Accuracy on test dataset: {100.0 * correct / total:.2f}%")
Why Separate Trainer from SetupTrainer?#
The design separates compile-time responsibilities from run-time ones:
Compile-time components (SetupTrainer and everything before it):
build the complete computation graph (IRModule)
perform all static analysis and optimization
produce deployable, device-agnostic computation logic
Run-time component (Trainer):
receives the IRModule produced at compile time
manages dynamic updates of the model parameters
maintains transient state during training
This separation lets TVM:
compile and optimize the computation graph on a server
deploy the optimized IRModule to edge devices
run only the necessary parameter updates on resource-constrained devices
This is a key design decision behind TVM's "compile once, run anywhere" approach.
Method 2: Using the Low-Level Training APIs#
We can also build and run the IRModule directly through the low-level training APIs. They mainly include:
the loss function library
the optimizer library
the automatic differentiation pass
Loss Functions#
TVM provides a rich set of loss functions in the tvm.relax.training.loss module, including:
CrossEntropyLoss (cross-entropy loss)
L1Loss (L1 loss)
MSELoss (mean squared error loss)
and others.
You can also define custom loss functions by subclassing tvm.relax.training.loss.Loss.
Instantiating a loss class only requires the hyperparameters. Its __call__() method takes the struct_info (structural information) of the model output and of the labels, and generates the corresponding Relax loss function:
func = CrossEntropyLoss(reduction="sum")(out_sinfo, label_sinfo)
print(func)
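For one-hot float targets (as produced by our data pipeline), CrossEntropyLoss with reduction="sum" amounts to -sum(targets * log_softmax(predictions)). A NumPy sketch of that computation:

```python
import numpy as np

def cross_entropy_sum(pred, onehot):
    # numerically stable log-softmax, then the 'sum'-reduced cross entropy
    z = pred - pred.max(axis=1, keepdims=True)
    log_sm = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -(onehot * log_sm).sum()

pred = np.array([[2.0, 0.5, 0.1], [0.2, 3.0, 0.3]])
onehot = np.eye(3)[[0, 1]]          # true classes: 0 and 1
loss = cross_entropy_sum(pred, onehot)
print(loss > 0)  # True; the loss shrinks as predictions match the targets better
```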
The automatic differentiation pass requires the backbone function and the loss function to be fused into one. The relax.training.utils.append_loss utility performs this fusion:
MLP["loss"] = relax.training.utils.append_loss(MLP["backbone"], func)
MLP.show()
Gradient Pass#
To optimize the model parameters, we need their gradients. TVM provides the automatic differentiation pass relax.transform.Gradient for this purpose.
The automatic differentiation (AD) system is the core of the training workflow and is implemented via source-code transformation. The current version places the following restrictions on the input function:
single dataflow block: the function must contain exactly one dataflow block
operator coverage: only basic Relax operators, such as arithmetic and tuple operations, are supported
Gradient takes three key arguments:
the GlobalVar of the target function
the variables whose gradients are required
the input IRModule
It returns a new IRModule containing the gradient computation.
params = MLP["loss"].params[:6]
MLP = relax.transform.Gradient(
    MLP.get_global_var("loss"),
    require_grads=params,
)(MLP)
MLP.show()
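What Gradient produces can be checked by hand on a tiny example. The sketch below derives the adjoint of w for loss = sum(relu(x @ w)) via the chain rule and verifies it against finite differences (pure NumPy, unrelated to the Relax pass itself):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
w = rng.standard_normal((3, 2))

def loss_fn(w):
    return np.maximum(x @ w, 0.0).sum()

# chain rule: dL/d(xw) is 1 where relu is active, so dL/dw = x^T @ mask
mask = (x @ w > 0).astype(float)
grad_w = x.T @ mask

# central finite differences as an independent check
eps = 1e-6
num = np.zeros_like(w)
for i in range(w.shape[0]):
    for j in range(w.shape[1]):
        d = np.zeros_like(w)
        d[i, j] = eps
        num[i, j] = (loss_fn(w + d) - loss_fn(w - d)) / (2 * eps)
print(np.allclose(grad_w, num, atol=1e-4))  # True
```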
Optimizers#
TVM provides several classic optimizers in the relax.training.optimizer module, including:
plain SGD
SGD with momentum
Adam
You can also implement custom optimizers by subclassing relax.training.optimizer.Optimizer.
Creating an optimizer instance only requires the hyperparameters (such as the learning rate). The init() method then takes either:
a single Relax variable, or
a list of Relax variables (the variable nodes in the computation graph)
and initializes the optimizer state. After initialization, the optimizer is used in two steps:
call get_function() to obtain the corresponding Relax optimization function
attach that function to the computation flow of an existing IRModule
opt = SGD(0.1).init(params)
MLP["SGD"] = opt.get_function()
print(MLP["SGD"])
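The update rule that the generated SGD function encodes is simple; a weight-decay term, when used, folds into the gradient. A NumPy sketch (our own helper, not the Relax function):

```python
import numpy as np

def sgd_update(param, grad, lr=0.1, weight_decay=0.0):
    # p <- p - lr * (g + weight_decay * p)
    return param - lr * (grad + weight_decay * param)

p = np.array([1.0, -2.0])
g = np.array([0.5, 0.5])
print(sgd_update(p, g, lr=0.1))  # [ 0.95 -2.05]
```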
Training Loop#
With the IRModule built, we can start training the model. We need to:
legalize the IRModule
compile it into an executable module
prepare the required inputs:
# Legalize and build the module
from tvm.runtime.container import tuple_object

lowered_mod = LegalizeOps()(MLP)
ex = tvm.compile(lowered_mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

def _get_shape_as_int_list(var):
    return [int(val) for val in var.struct_info.shape]

params_list = [tvm.nd.array(np.ones(_get_shape_as_int_list(i), "float32")) for i in params]
param_input_tuple = tuple_object(params_list)

x_input, y_input = next(iter(train_loader))
x_input = tvm.nd.array(x_input.numpy())
y_input = tvm.nd.array(y_input.numpy())
# The input should be (*param_input_tuple, x_input, y_input)
# At runtime, TVM arguments must be TVM NDArrays or TVM runtime ADT objects.
This demo runs only a single training step; multi-step training follows the same pattern.
How the core components interact:
The adjoint function (generated by the automatic differentiation pass):
input: the backbone's inputs plus the ground-truth labels
output: the loss value plus the tuple of parameter gradients
The optimizer function (built by the optimizer class):
input: the parameter tuple, the gradient tuple, and the optimizer-state tuple
output: the updated parameter tuple and the new optimizer-state tuple
opt.state exposes the optimizer-state object, which carries the key bookkeeping of the optimization process:
the number of training steps taken so far
momentum buffers
adaptive learning-rate statistics (such as the first- and second-moment estimates in Adam)
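The (params, grads, state) -> (params, state) shape of the optimizer function can be sketched in NumPy. Here the state is a (step count, momentum buffer) pair, as in momentum SGD; the names are illustrative, not the Relax ADT layout:

```python
import numpy as np

def momentum_sgd(param, grad, state, lr=0.1, mu=0.9):
    steps, v = state
    v = mu * v + grad                  # update the momentum buffer
    return param - lr * v, (steps + 1, v)

p = np.zeros(2)
state = (0, np.zeros(2))               # initial state: zero steps, zero momentum
g = np.ones(2)
p, state = momentum_sgd(p, g, state)
p, state = momentum_sgd(p, g, state)
print(state[0], p)  # 2 [-0.29 -0.29]
```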
# forward and find the gradient
loss, param_grad_tuple = vm["loss_adjoint"](*param_input_tuple, x_input, y_input)
# update parameters
param_input_tuple, opt.state = vm["SGD"](param_input_tuple, param_grad_tuple, opt.state)
Print the results:
print(loss.numpy())
print(len(param_input_tuple), len(param_grad_tuple))
print(param_input_tuple[0])
print(param_grad_tuple[0])