IRModule#

The core abstraction of Apache TVM Unity is the IRModule. An IRModule encompasses the entire ML model, including the computational graph, the tensor programs, and potential calls to external libraries.

import os
os.environ['PATH'] += ':/usr/local/cuda/bin' # make sure nvcc can be found
import numpy as np
import tvm
from tvm import relax

Create an IRModule#

An IRModule can be initialized in several ways.

Import from Existing Models#

The most common way to initialize an IRModule is to import from an existing model. Apache TVM Unity supports importing from a range of frameworks, such as PyTorch and ONNX.

import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program
# Create a dummy model
class TorchModel(nn.Module):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


# Give an example argument to torch.export
example_args = (torch.randn(1, 784, dtype=torch.float32),)

# Convert the model to IRModule
with torch.no_grad():
    exported_program = export(TorchModel().eval(), example_args)
    mod_from_torch = from_exported_program(
        exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
    )

# Separate the weight values (kept as explicit inputs via keep_params_as_input=True)
# from the IRModule into a dict of parameters
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=[1, 0])
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(p_fc1_bias, lv1)
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=[1, 0])
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(p_fc2_bias, lv5)
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

Write with Relax NN Module#

Apache TVM Unity also provides a set of PyTorch-like APIs to help users write an IRModule directly.

from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


mod_from_relax, params_from_relax = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod_from_relax.show()

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv
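
The same export interface can also describe symbolic shapes. A minimal sketch, assuming nn.spec.Tensor accepts string dimension names as symbolic sizes (the "batch" dimension here is hypothetical):

# Export with a symbolic batch dimension instead of a fixed batch size of 1
dyn_mod, dyn_params = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("batch", 784), "float32")}}
)
dyn_mod.show()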

Create via TVMScript#

TVMScript is a Python-based DSL for IRModules. We can directly output an IRModule in TVMScript syntax, or parse TVMScript to obtain an IRModule.

from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            R.output(gv)
        return gv


mod_from_script = TVMScriptModule
mod_from_script.show()

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv
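
The reverse direction also works: an IRModule can be dumped as TVMScript source text and parsed back. A minimal round-trip sketch, assuming IRModule.script() and tvm.script.from_source are available as the printing and parsing entry points:

# Dump the module as TVMScript source text ...
script_text = mod_from_script.script()
# ... and parse it back into an equivalent IRModule
reparsed = tvm.script.from_source(script_text)
reparsed.show()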

Attributes of an IRModule#

An IRModule is a collection of functions, indexed by GlobalVars.

mod = mod_from_torch
print(mod.get_global_vars())
[I.GlobalVar("main")]

We can access the functions in an IRModule by indexing with GlobalVars or with their names.

# index by global var name
print(mod["main"])
# index by global var, and checking they are the same function
(gv,) = mod.get_global_vars()
assert mod[gv] == mod["main"]

# from tvm.script import relax as R

@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    R.func_attr({"num_input": 1})
    with R.dataflow():
        lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=[1, 0])
        lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
        lv2: R.Tensor((1, 256), dtype="float32") = R.add(p_fc1_bias, lv1)
        lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
        lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=[1, 0])
        lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
        lv6: R.Tensor((1, 10), dtype="float32") = R.add(p_fc2_bias, lv5)
        gv: R.Tensor((1, 10), dtype="float32") = lv6
        R.output(gv)
    return gv
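
Beyond looking up a single function, the whole collection can be iterated. A minimal sketch, assuming mod.functions behaves as a GlobalVar-to-function mapping:

# List every function in the module together with its kind (Relax or TensorIR)
for gv, func in mod.functions.items():
    print(gv.name_hint, type(func).__name__)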

Transformations on an IRModule#

Transformations are an important component of Apache TVM Unity. A transformation takes an IRModule as input and outputs another IRModule. We can apply a sequence of transformations to an IRModule to obtain a new IRModule. This is the common way to optimize a model.

For details on each transformation, please refer to the transformation API reference.

First, let us apply the LegalizeOps transformation to the IRModule. This transformation converts the Relax module into a mixed stage, with both Relax and TensorIR functions in the same module. Meanwhile, the Relax operators are converted to call_tir.

mod = mod_from_torch
mod = relax.transform.LegalizeOps()(mod)
mod.show()

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(p_fc1_bias: T.Buffer((T.int64(256),), "float32"), lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_bias[v_ax1], lv1[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = p_fc1_bias[v_ax1] + lv1[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def add1(p_fc2_bias: T.Buffer((T.int64(10),), "float32"), lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_bias[v_ax1], lv5[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = p_fc2_bias[v_ax1] + lv5[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def matmul(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv2[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv1 = R.call_tir(cls.matmul, (x, lv), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv2 = R.call_tir(cls.add, (p_fc1_bias, lv1), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            lv6 = R.call_tir(cls.add1, (p_fc2_bias, lv5), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

After the transformation, there are many more functions in the module. Let us print the global vars again.

print(mod.get_global_vars())
[I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1")]

Apache TVM Unity provides users with a set of default transformation pipelines to simplify the transformation process. We can then apply the default pipeline to the module. The default zero pipeline contains some very fundamental transformations, including the following passes (a hand-composed sketch of the same sequence follows the list):

  • LegalizeOps: This transform converts Relax operators into call_tir functions together with the corresponding TensorIR functions. After this transform, the IRModule contains both Relax functions and TensorIR functions.

  • AnnotateTIROpPattern: This transform annotates the pattern of each TensorIR function, preparing them for subsequent operator fusion.

  • FoldConstant: This pass performs constant folding, optimizing operations that involve constants.

  • FuseOps and FuseTIR: These two passes work together to fuse operators based on the patterns annotated in the previous step (AnnotateTIROpPattern). They transform both Relax functions and TensorIR functions.
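
For illustration, a roughly equivalent pipeline could be composed by hand with tvm.ir.transform.Sequential; this is only a sketch under the assumption that the passes above compose in this order, and the built-in pipeline used below remains the recommended path:

# Hand-composed approximation of the "zero" pipeline described above
seq = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),           # Relax ops -> call_tir + TensorIR functions
        relax.transform.AnnotateTIROpPattern(),  # annotate op patterns for fusion
        relax.transform.FoldConstant(),          # fold constant computations
        relax.transform.FuseOps(),               # fuse Relax-level operators by pattern
        relax.transform.FuseTIR(),               # fuse the corresponding TensorIR functions
    ]
)
# mod = seq(mod)  # would produce a module similar to relax.get_pipeline("zero")(mod)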

Note

Here we apply LegalizeOps twice in the flow. The second application is redundant but harmless.

Every pass can appear more than once in a flow, because we ensure that each pass can handle any legal IRModule as input. This design helps users construct their own pipelines.

mod = relax.get_pipeline("zero")(mod)
mod.show()

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_bias[v_ax1], matmul_intermediate[v_ax0, v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = p_fc2_bias[v_ax1] + matmul_intermediate[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_bias[v_ax1], matmul_intermediate[v_ax0, v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = p_fc1_bias[v_ax1] + matmul_intermediate[v_ax0, v_ax1]
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv_1 = R.call_tir(cls.fused_matmul_add_relu, (x, lv, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

Universal Deployment of the IRModule#

After the optimization, we can compile the model into a TVM runtime module. Notably, Apache TVM Unity provides the ability of universal deployment, which means the same IRModule can be deployed on different backends, including CPUs, GPUs, and other emerging backends.

Deploy on CPU#

We can deploy the IRModule on the CPU by specifying the target as llvm.

exec = tvm.compile(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec, dev)

raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.runtime.tensor(raw_data, dev)
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
[[-0.00638016  0.03848503  0.06300074 -0.19170322  0.0904716  -0.0460794
  -0.06462531 -0.03650295 -0.04207093  0.14024206]]

Deploy on GPU#

Besides the CPU backend, we can also deploy the IRModule on other backends, for example on a GPU. GPU programs require extra information, such as thread bindings and shared memory allocations, so further transformations are needed to generate them.

Here we use DLight to generate the GPU program.

from tvm import dlight as dl

with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)

Now we can compile the IRModule on the GPU, in the same way as we did on the CPU.

exec = tvm.compile(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(exec, dev)
# Need to allocate data and params on GPU device
data = tvm.runtime.tensor(raw_data, dev)
gpu_params = [tvm.runtime.tensor(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)

# Check the correctness of the results
assert np.allclose(cpu_out, gpu_out, atol=1e-3)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 1
----> 1 exec = tvm.compile(gpu_mod, target="cuda")
      2 dev = tvm.device("cuda", 0)
      3 vm = relax.VirtualMachine(exec, dev)

File D:\AI\client\hub\tvm\python\tvm\driver\build_module.py:104, in compile(mod, target, relax_pipeline, tir_pipeline)
    102 # TODO(tvm-team): combine two path into unified one
    103 if _contains_relax(mod):
--> 104     return tvm.relax.build(
    105         mod,
    106         target,
    107         relax_pipeline=relax_pipeline,
    108         tir_pipeline=tir_pipeline,
    109     )
    110 lib = tvm.tir.build(mod, target, pipeline=tir_pipeline)
    111 return Executable(lib)

File D:\AI\client\hub\tvm\python\tvm\relax\vm_build.py:263, in build(mod, target, params, relax_pipeline, tir_pipeline, exec_mode, system_lib)
    261 builder = relax.ExecBuilder()
    262 mod = _vmcodegen(builder, mod, exec_mode)
--> 263 return _vmlink(
    264     builder=builder,
    265     target=target,
    266     tir_mod=_filter_tir(mod),
    267     tir_pipeline=tir_pipeline,
    268     ext_libs=ext_libs,
    269     params=params,
    270     system_lib=system_lib,
    271 )

File D:\AI\client\hub\tvm\python\tvm\relax\vm_build.py:158, in _vmlink(builder, target, tir_mod, tir_pipeline, ext_libs, params, system_lib)
    156 if tir_mod is not None and len(tir_mod.get_global_vars()) > 0:
    157     tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib)
--> 158     lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
    159 for ext_mod in ext_libs:
    160     if _is_device_module(ext_mod):

File D:\AI\client\hub\tvm\python\tvm\tir\build.py:239, in build(mod, target, pipeline)
    233 device_mod_dict = {
    234     target: tvm.tir.pipeline.finalize_device_passes()(device_mod)
    235     for target, device_mod in device_mod_dict.items()
    236 }
    238 # Convert TIR IRModules to runtime Module by calling target.build
--> 239 return tir_to_runtime(host_mod, device_mod_dict, target_host)

File D:\AI\client\hub\tvm\python\tvm\tir\build.py:147, in tir_to_runtime(host_mod, device_mod_dict, target_host)
    145 for target, device_mod in device_mod_dict.items():
    146     if len(device_mod.functions) != 0:
--> 147         device_modules.append(codegen_build(device_mod, target))
    149 mhost = codegen_build(mhost_all, target_host)
    150 for dev_mod in device_modules:

File D:\AI\client\hub\tvm\python\tvm\tir\build.py:131, in codegen_build(mod, target)
    129 if bf is None:
    130     raise ValueError(f"{build_f_name} is not enabled")
--> 131 return bf(mod, target)

File python/tvm_ffi/cython/function.pxi:923, in tvm_ffi.core.Function.__call__()

File python/tvm_ffi/cython/function.pxi:1077, in tvm_ffi.core.tvm_ffi_callback()

File D:\AI\client\hub\tvm\python\tvm\contrib\nvcc.py:319, in tvm_callback_cuda_compile()
    316 @tvm_ffi.register_global_func
    317 def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
    318     """use nvcc to generate fatbin code for better optimization"""
--> 319     ptx = compile_cuda(code, target_format="fatbin")
    320     return ptx

File D:\AI\client\hub\tvm\python\tvm\contrib\nvcc.py:152, in compile_cuda()
    150     msg += "\nCompilation error:\n"
    151     msg += py_str(out)
--> 152     raise RuntimeError(msg)
    154 # start second stage of compilation
    155 if use_nvshmem:

RuntimeError: #include <cuda.h>

#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
#include <cstdint>
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
extern "C" __global__ void __launch_bounds__(1024) fused_matmul1_add1_kernel(float* __restrict__ lv3, float* __restrict__ lv4, float* __restrict__ matmul);
extern "C" __global__ void __launch_bounds__(1024) fused_matmul1_add1_kernel_1(float* __restrict__ T_add, float* __restrict__ matmul, float* __restrict__ p_fc2_bias);
extern "C" __global__ void __launch_bounds__(1024) fused_matmul_add_relu_kernel(float* __restrict__ lv, float* __restrict__ matmul, float* __restrict__ x);
extern "C" __global__ void __launch_bounds__(1024) fused_matmul_add_relu_kernel_1(float* __restrict__ compute, float* __restrict__ matmul, float* __restrict__ p_fc1_bias);
extern "C" __global__ void __launch_bounds__(1024) transpose_kernel(float* __restrict__ T_transpose, float* __restrict__ p_fc1_weight);
extern "C" __global__ void __launch_bounds__(1024) transpose1_kernel(float* __restrict__ T_transpose, float* __restrict__ p_fc2_weight);
extern "C" __global__ void __launch_bounds__(1024) fused_matmul1_add1_kernel(float* __restrict__ lv3, float* __restrict__ lv4, float* __restrict__ matmul) {
  if (((int)threadIdx.x) < 10) {
    matmul[((int)threadIdx.x)] = 0x0.0000000000000p+0f/*0.000000e+00*/;
  }
  for (int ax1 = 0; ax1 < 256; ++ax1) {
    if (((int)threadIdx.x) < 10) {
      matmul[((int)threadIdx.x)] = (matmul[((int)threadIdx.x)] + (lv3[ax1] * lv4[((ax1 * 10) + ((int)threadIdx.x))]));
    }
  }
}

extern "C" __global__ void __launch_bounds__(1024) fused_matmul1_add1_kernel_1(float* __restrict__ T_add, float* __restrict__ matmul, float* __restrict__ p_fc2_bias) {
  if (((int)threadIdx.x) < 10) {
    T_add[((int)threadIdx.x)] = (p_fc2_bias[((int)threadIdx.x)] + matmul[((int)threadIdx.x)]);
  }
}

extern "C" __global__ void __launch_bounds__(1024) fused_matmul_add_relu_kernel(float* __restrict__ lv, float* __restrict__ matmul, float* __restrict__ x) {
  if (((int)threadIdx.x) < 256) {
    matmul[((int)threadIdx.x)] = 0x0.0000000000000p+0f/*0.000000e+00*/;
  }
  for (int ax1 = 0; ax1 < 784; ++ax1) {
    if (((int)threadIdx.x) < 256) {
      matmul[((int)threadIdx.x)] = (matmul[((int)threadIdx.x)] + (x[ax1] * lv[((ax1 * 256) + ((int)threadIdx.x))]));
    }
  }
}

extern "C" __global__ void __launch_bounds__(1024) fused_matmul_add_relu_kernel_1(float* __restrict__ compute, float* __restrict__ matmul, float* __restrict__ p_fc1_bias) {
  if (((int)threadIdx.x) < 256) {
    compute[((int)threadIdx.x)] = max((p_fc1_bias[((int)threadIdx.x)] + matmul[((int)threadIdx.x)]), 0x0.0000000000000p+0f/*0.000000e+00*/);
  }
}

extern "C" __global__ void __launch_bounds__(1024) transpose_kernel(float* __restrict__ T_transpose, float* __restrict__ p_fc1_weight) {
  T_transpose[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] = p_fc1_weight[((((((int)threadIdx.x) & 255) * 784) + (((int)blockIdx.x) * 4)) + (((int)threadIdx.x) >> 8))];
}

extern "C" __global__ void __launch_bounds__(1024) transpose1_kernel(float* __restrict__ T_transpose, float* __restrict__ p_fc2_weight) {
  if (((((int)blockIdx.x) * 2) + (((int)threadIdx.x) >> 9)) < 5) {
    T_transpose[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] = p_fc2_weight[(((((((int)blockIdx.x) * 4) + ((int)threadIdx.x)) % 10) * 256) + (((((int)blockIdx.x) * 512) + (((int)threadIdx.x) >> 1)) / 5))];
  }
}


Compilation error:
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.0\include\crt/host_config.h(164): fatal error C1189: #error:  -- unsupported Microsoft Visual Studio version! Only the versions between 2019 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk.
tvm_kernels.cu

Deploy on Other Backends#

Apache TVM Unity also supports other backends, such as different kinds of GPUs (Metal, ROCm, Vulkan, and OpenCL), different kinds of CPUs (x86, ARM), and other emerging backends (e.g., WebAssembly). The deployment process is similar to that for the GPU backend.
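
As an illustration, the GPU flow above can be retargeted by changing only the target string. A minimal sketch, assuming a Vulkan-capable device and runtime are available (reusing mod, raw_data, and params_from_torch from above):

# Schedule and compile the same module for Vulkan instead of CUDA
with tvm.target.Target("vulkan"):
    vulkan_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)

exec = tvm.compile(vulkan_mod, target="vulkan")
dev = tvm.device("vulkan", 0)
vm = relax.VirtualMachine(exec, dev)
data = tvm.runtime.tensor(raw_data, dev)
vk_params = [tvm.runtime.tensor(p, dev) for p in params_from_torch["main"]]
print(vm["main"](data, *vk_params).numpy())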