ML 模型的计算图抽象

ML 模型的计算图抽象#

计算图抽象(graph abstraction)是机器学习(ML)中一项关键技术,它被编译器用来表示和推理 ML 模型的结构和数据流。通过将模型抽象成计算图表示,编译器可以执行各种优化以提高性能和效率。本教程将涵盖计算图抽象的基础知识、其关键元素 Relax IR,以及它如何在 ML 编译器中启用优化。

什么是计算图抽象?#

计算图抽象是将机器学习模型表示为有向图的过程,其中节点代表计算算子(例如矩阵乘法、卷积),而边则代表这些算子之间数据流的流动。这种抽象使编译器能够分析模型不同部分之间的依赖关系和相互联系。

from tvm.script import relax as R

@R.function
def main(
    x: R.Tensor((1, 784), dtype="float32"),
    weight: R.Tensor((784, 256), dtype="float32"),
    bias: R.Tensor((256,), dtype="float32"),
) -> R.Tensor((1, 256), dtype="float32"):
    with R.dataflow():
        lv0 = R.matmul(x, weight)
        lv1 = R.add(lv0, bias)
        gv = R.nn.relu(lv1)
        R.output(gv)
    return gv

Relax的关键特性#

Relax,Apache TVM 的 Unity 策略中使用的计算图表示法,通过几个关键特性实现了机器学习模型的端到端优化。

  • First-class 符号形状:Relax 采用符号形状来表示张量的维度,这允许在张量算子和函数调用中全局跟踪动态形状关系。

  • 多层次抽象:Relax 支持跨级别的抽象,从高层的神经网络层到低层的张量算子,使优化能够跨越模型内的不同层次结构。

  • 可组合变换:Relax 提供了框架,用于实现可组合的“变换”,这些变换可以有选择地应用于不同的模型组件。这包括 partial lowering 和 partial specialization 的功能,提供了灵活的定制和优化选项。

这些特性共同使 Relax 能够在 Apache TVM 生态系统内提供一种强大且灵活的机器学习模型优化方法。