IRModule#

Apache TVM Unity 的核心抽象,即 IRModule。IRModule 包含整个 ML 模型,包括 计算图、张量程序和对外部库的潜在调用。

创建 IRModule#

IRModules 可以通过多种方式进行初始化。

import set_env
import numpy as np
import tvm
from tvm import relax

从现有模型导入#

初始化 IRModule 的最常见方法是从现有模型导入。Apache TVM Unity 支持从一系列框架导入,如 PyTorch 和 ONNX。

import torch
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx
# Create a dummy model
class TorchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


# Give the input shape and data type
input_info = [((1, 784), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(TorchModel())
    mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, fc1_bias)
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

使用 Relax NN 模块编写#

Apache TVM Unity还提供了一系列类似 PyTorch 的 API,帮助用户直接编写 IRModule。

from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


mod_from_relax, params_from_relax = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod_from_relax.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv

通过 TVMScript 创建#

TVMScript 是一种基于 Python 的 DSL,用于 IRModule。我们可以直接以 TVMScript 语法输出 IRModule,或者解析 TVMScript 以获取 IRModule。

from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            R.output(gv)
        return gv


mod_from_script = TVMScriptModule
mod_from_script.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv

IRModule的属性#

IRModule 是一组函数的集合,通过 GlobalVars 索引。

mod = mod_from_torch
print(mod.get_global_vars())
[I.GlobalVar("main")]

我们可以通过使用 GlobalVars 或它们的名称来索引 IRModule 中的函数。

# index by global var name
print(mod["main"])
# index by global var, and checking they are the same function
(gv,) = mod.get_global_vars()
assert mod[gv] == mod["main"]
# from tvm.script import relax as R

@R.function
def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    R.func_attr({"num_input": 1})
    with R.dataflow():
        lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
        lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
        lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, fc1_bias)
        lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
        lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
        lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
        lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, fc2_bias)
        gv: R.Tensor((1, 10), dtype="float32") = lv6
        R.output(gv)
    return gv

IRModule 上的变换#

变换是 Apache TVM Unity 的重要组成部分。一个变换接受一个 IRModule 并输出另一个 IRModule。我们可以将一系列变换应用于一个 IRModule 以获得一个新的 IRModule。这是优化模型的常见方法。

有关每个变换的详细信息,请参阅变换 API 参考

我们首先对 IRModule 应用 LegalizeOps 变换。此变换将 Relax 模块转换为混合阶段,同一个模块内包 Relax 和 TensorIR 函数。同时,Relax 算子将被转换为 call_tir

mod = mod_from_torch
mod = relax.transform.LegalizeOps()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), fc1_bias: T.Buffer((T.int64(256),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv1[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv1[v_ax0, v_ax1] + fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv5[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv5[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(inp_0[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * lv[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv2[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv1 = R.call_tir(cls.matmul, (inp_0, lv), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv2 = R.call_tir(cls.add, (lv1, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            lv6 = R.call_tir(cls.add1, (lv5, fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

转变换后,模块内会有更多的函数。让我们再次打印全局变量。

print(mod.get_global_vars())
[I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1")]

Apache TVM Unity 为用户提供了一组默认的变换管道,以简化变换过程。然后我们可以将默认管道应用于模块。默认的零管道包含一些非常基础的变换,包括:

  • LegalizeOps:此转换将 Relax 算子转换为具有相应 TensorIR 函数的 call_tir 函数。在此转换之后,IRModule 将包含 Relax 函数和 TensorIR 函数。

  • AnnotateTIROpPattern:此转换注解 TensorIR 函数的模式,为后续的算子融合做准备。

  • FoldConstant:此传递执行常量折叠,优化涉及常量的运算。

  • FuseOpsFuseTIR:这两个传递基于上一步(AnnotateTIROpPattern)中注解的模式共同工作以融合算子。这些传递转换 Relax 函数和 TensorIR 函数。

备注

在这里,我们在流程中应用了两次 LegalizeOps。第二次是多余的,但无害。

每个传递都可以在流程中重复,因为我们确保传递可以处理所有合法的 IRModule 输入。这种设计可以帮助用户构建他们自己的管道。

mod = relax.get_pipeline("zero")(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(inp_0[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + inp_0[v_i0, v_k] * lv[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv_1 = R.call_tir(cls.fused_matmul_add_relu, (inp_0, lv, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

通用部署IRModule#

优化完成后,我们可以将模型编译为 TVM 运行时模块。值得注意的是,Apache TVM Unity 提供了通用部署的能力,这意味着我们可以在不同的后端(包括 CPU、GPU 和其他新兴后端)上部署相同的 IRModule。

在 CPU 上部署#

我们可以通过将目标指定为 llvm 来在 CPU 上部署 IRModule。

exec = relax.build(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec, dev)

raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
[[-0.16265766  0.26098928  0.12619261 -0.170795    0.01475293 -0.02576868
  -0.03328853  0.21777523  0.16681662 -0.08581798]]

在 GPU 上部署#

除了 CPU 后端,我们还可以在其他后端上部署 IRModule。例如,我们可以将 IRModule 部署在 GPU 上。GPU 需要包含额外信息的程序,如线程绑定和共享内存分配。我们需要进一步的转换来生成 GPU 程序。

我们使用 DLight 来生成 GPU 程序。

from tvm import dlight as dl

with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)

现在我们可以像在 CPU 上那样,在 GPU 上编译 IRModule。

exec = relax.build(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(exec, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)

# Check the correctness of the results
assert np.allclose(cpu_out, gpu_out, atol=1e-3)
[[-0.16265765  0.26098925  0.12619264 -0.17079495  0.01475285 -0.02576872
  -0.03328849  0.21777529  0.16681664 -0.08581802]]

在其他后端上部署#

Apache TVM Unity 还支持其他后端,如各种类型的 GPU(Metal、ROCm、Vulkan 和 OpenCL)、各种类型的 CPU(x86、ARM)以及其他新兴后端(例如 WebAssembly)。部署过程与 GPU 后端类似。