IRModule
The IRModule is the core abstraction of Apache TVM Unity. An IRModule contains an entire ML model, including the computational graph, tensor programs, and potential calls to external libraries.
Creating an IRModule
IRModules can be initialized in several ways.
import set_env
import numpy as np
import tvm
from tvm import relax
Import from Existing Models
The most common way to initialize an IRModule is to import from an existing model. Apache TVM Unity supports importing from a range of frameworks, such as PyTorch and ONNX.
import torch
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx

# Create a dummy model
class TorchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

# Give the input shape and data type
input_info = [((1, 784), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(TorchModel())
    mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)

# Print the IRModule
mod_from_torch.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, fc1_bias)
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv
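Importing from ONNX follows the same pattern. Below is a minimal sketch using the Relax ONNX frontend; the file name "model.onnx" and the input name "data" are placeholders, not files produced by this tutorial.
import onnx
from tvm.relax.frontend.onnx import from_onnx

# Load an ONNX model from disk ("model.onnx" is a hypothetical path)
onnx_model = onnx.load("model.onnx")
# Map each graph input name to its shape ("data" is a placeholder input name)
mod_from_onnx = from_onnx(onnx_model, shape_dict={"data": (1, 784)})
mod_from_onnx.show()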
Write with Relax NN Module
Apache TVM Unity also provides a set of PyTorch-like APIs to help users write the IRModule directly.
from tvm.relax.frontend import nn

class RelaxModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

mod_from_relax, params_from_relax = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod_from_relax.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv
Create via TVMScript
TVMScript is a Python-based DSL for IRModules. We can directly output an IRModule in TVMScript syntax, or parse TVMScript to obtain an IRModule.
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            R.output(gv)
        return gv

mod_from_script = TVMScriptModule
mod_from_script.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv
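The mapping between an IRModule and its TVMScript text goes both ways. The round-trip below is a small sketch we added, assuming tvm.script.from_source is available in your TVM build:
# Render the module as TVMScript source text, then parse it back
script_text = mod_from_script.script()
reparsed = tvm.script.from_source(script_text)
# The reparsed module should still be a well-formed Relax program
assert relax.analysis.well_formed(reparsed)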
Attributes of an IRModule
An IRModule is a collection of functions, indexed by GlobalVars.
mod = mod_from_torch
print(mod.get_global_vars())
[I.GlobalVar("main")]
We can access the functions in an IRModule by indexing with GlobalVars or with their names.
# index by global var name
print(mod["main"])
# index by global var, and checking they are the same function
(gv,) = mod.get_global_vars()
assert mod[gv] == mod["main"]
Show code cell output
# from tvm.script import relax as R

@R.function
def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    R.func_attr({"num_input": 1})
    with R.dataflow():
        lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
        lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
        lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, fc1_bias)
        lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
        lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
        lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
        lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, fc2_bias)
        gv: R.Tensor((1, 10), dtype="float32") = lv6
        R.output(gv)
    return gv
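Beyond indexing a single function, we can iterate over all functions in a module or update one in place. A quick sketch (the attribute name below is a made-up example, not part of the tutorial):
# Iterate over every function in the module; gv is a GlobalVar, func a BaseFunc
for gv, func in mod.functions.items():
    print(gv.name_hint, type(func))

# Functions can also be replaced in place; here we attach a hypothetical attribute
mod["main"] = mod["main"].with_attr("custom_attr", "example")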
Transformations on an IRModule
Transformations are an important part of Apache TVM Unity. A transformation takes an IRModule as input and outputs another IRModule. We can apply a sequence of transformations to an IRModule to obtain a new IRModule; this is the common way to optimize a model.
For details on each transformation, please refer to the Transformation API Reference.
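Since every transformation is just a function from IRModule to IRModule, writing one is straightforward. The sketch below (our own example, not from the tutorial) defines a do-nothing module pass with the tvm.transform.module_pass decorator; the pass name and opt_level are arbitrary:
@tvm.transform.module_pass(opt_level=0, name="IdentityPass")
def identity_pass(mod: tvm.IRModule, ctx: tvm.transform.PassContext) -> tvm.IRModule:
    # A real pass would rewrite functions here; we return the module unchanged
    return mod

mod_unchanged = identity_pass(mod_from_torch)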
We first apply the LegalizeOps transformation to the IRModule. This transformation lowers the Relax module into a mixed stage that contains both Relax and TensorIR functions within the same module, and the Relax operators are converted into call_tir.
mod = mod_from_torch
mod = relax.transform.LegalizeOps()(mod)
mod.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), fc1_bias: T.Buffer((T.int64(256),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv1[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv1[v_ax0, v_ax1] + fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv5[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv5[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(inp_0[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * lv[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv2[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv1 = R.call_tir(cls.matmul, (inp_0, lv), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv2 = R.call_tir(cls.add, (lv1, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            lv6 = R.call_tir(cls.add1, (lv5, fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv
After the transformation, there are more functions inside the module. Let's print the global variables again.
print(mod.get_global_vars())
[I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1")]
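The legalized module now mixes the two kinds of functions. A quick sanity check (a sketch we added) confirms this:
from tvm import tir

# "matmul" was generated by LegalizeOps as a TensorIR primitive function,
# while "main" remains a Relax function
assert isinstance(mod["matmul"], tir.PrimFunc)
assert isinstance(mod["main"], relax.Function)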
Apache TVM Unity provides a set of default transformation pipelines to simplify the process, so we can apply a default pipeline to the module. The default zero pipeline contains some very fundamental transformations, applied in order (a hand-built equivalent is sketched after the note below):
- LegalizeOps: This transform converts the Relax operators into call_tir functions with the corresponding TensorIR functions. After this transform, the IRModule contains both Relax functions and TensorIR functions.
- AnnotateTIROpPattern: This transform annotates the pattern of the TensorIR functions, preparing them for subsequent operator fusion.
- FoldConstant: This pass performs constant folding, optimizing operations involving constants.
- FuseOps and FuseTIR: These two passes work together to fuse operators based on the patterns annotated in the previous step (AnnotateTIROpPattern). They transform both Relax functions and TensorIR functions.
Note
Here we apply LegalizeOps twice in the flow. The second application is redundant but harmless.
Every pass can be duplicated in the flow, because we ensure each pass can handle all legal IRModule inputs. This design helps users construct their own pipelines.
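For reference, composing those passes by hand with tvm.transform.Sequential would look roughly like the sketch below; the actual zero pipeline may differ in detail.
# A hand-built approximation of the "zero" pipeline (a sketch, not the exact implementation)
zero_like = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        relax.transform.AnnotateTIROpPattern(),
        relax.transform.FoldConstant(),
        relax.transform.FuseOps(),
        relax.transform.FuseTIR(),
    ]
)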
mod = relax.get_pipeline("zero")(mod)
mod.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(inp_0[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + inp_0[v_i0, v_k] * lv[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv_1 = R.call_tir(cls.fused_matmul_add_relu, (inp_0, lv, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv
Universal Deployment of the IRModule
After the optimization, we can compile the model into a TVM runtime module. Notably, Apache TVM Unity provides the ability of universal deployment, which means we can deploy the same IRModule on different backends, including CPU, GPU, and other emerging backends.
Deploy on CPU
We can deploy the IRModule on the CPU by specifying the target as llvm.
exec = relax.build(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec, dev)
raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
[[-0.00346027 -0.15362605 0.00199257 0.10505666 -0.15537857 0.08472918
-0.08493918 0.00613638 0.03105983 0.02112349]]
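If we also want to know how fast the compiled function runs, the Relax VM's built-in time evaluator can benchmark it. A small sketch, assuming time_evaluator is available on your VM build; timings vary by machine:
# Benchmark the "main" function on CPU; number controls runs per measurement
timer = vm.time_evaluator("main", dev, number=10)
print(timer(data, *params_from_torch["main"]))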
Deploy on GPU
Besides the CPU backend, we can also deploy the IRModule on other backends, such as the GPU. A GPU requires programs that contain extra information, such as thread bindings and shared memory allocations, so further transformations are needed to generate a GPU program.
We use DLight to generate the GPU program.
from tvm import dlight as dl

with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)
[17:16:00] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:171: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
Now we can compile the IRModule on the GPU, similar to what we did on the CPU.
exec = relax.build(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(exec, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)
# Check the correctness of the results
assert np.allclose(cpu_out, gpu_out, atol=1e-3)
Show code cell output
[17:16:05] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:171: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
---------------------------------------------------------------------------
InternalError Traceback (most recent call last)
Cell In[14], line 1
----> 1 exec = relax.build(gpu_mod, target="cuda")
2 dev = tvm.device("cuda", 0)
3 vm = relax.VirtualMachine(exec, dev)
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/vm_build.py:353, in build(mod, target, params, pipeline, exec_mode, system_lib)
351 builder = relax.ExecBuilder()
352 mod = _vmcodegen(builder, mod, exec_mode)
--> 353 return _vmlink(
354 builder=builder,
355 target=target,
356 tir_mod=_filter_tir(mod),
357 ext_libs=ext_libs,
358 params=params,
359 system_lib=system_lib,
360 )
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/vm_build.py:249, in _vmlink(builder, target, tir_mod, ext_libs, params, system_lib)
247 tir_ext_libs = []
248 if tir_mod is not None and len(tir_mod.get_global_vars()) > 0:
--> 249 lib = tvm.build(
250 tir_mod,
251 target=target,
252 runtime=_autodetect_system_lib_req(target, system_lib),
253 )
254 for ext_mod in ext_libs:
255 if ext_mod.is_device_module:
File /media/pc/data/lxw/ai/tvm/python/tvm/driver/build_module.py:297, in build(inputs, args, target, target_host, runtime, name, binds)
293 target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
295 annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)
--> 297 rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
299 annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)
301 if not isinstance(target_host, Target):
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_ctypes/packed_func.py:245, in PackedFuncBase.__call__(self, *args)
233 ret_tcode = ctypes.c_int()
234 if (
235 _LIB.TVMFuncCall(
236 self.handle,
(...)
243 != 0
244 ):
--> 245 raise_last_ffi_error()
246 _ = temp_args
247 _ = args
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:481, in raise_last_ffi_error()
475 # The exception PyObject may contain a large amount of state,
476 # including all stack frames that may be inspected in a later
477 # PDB post-mortem. Therefore, we must make sure to remove the
478 # underlying PyObject* from the C++ side after we retrieve it.
479 _LIB.TVMDropLastPythonError()
--> 481 raise py_err
InternalError: Traceback (most recent call last):
2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::__mk_TVM24::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#1}>(tvm::__mk_TVM24::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
1: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
0: tvm::codegen::Build(tvm::IRModule, tvm::Target)
File "/media/pc/data/lxw/ai/tvm/src/target/codegen.cc", line 72
InternalError: Check failed: (bf != nullptr) is false: target.build.cuda is not enabled
Deploy on Other Backends
Apache TVM Unity also supports other backends, such as different kinds of GPUs (Metal, ROCm, Vulkan, and OpenCL), different kinds of CPUs (x86, ARM), and other emerging backends (e.g., WebAssembly). The deployment process is similar to that of the GPU backend.
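As an illustration, targeting Metal on an Apple device follows the same recipe. This is a sketch under the assumption of a Metal-enabled TVM build (USE_METAL=ON) running on macOS:
# Schedule and build for Metal instead of CUDA
with tvm.target.Target("metal"):
    metal_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)
exec = relax.build(metal_mod, target="metal")
vm = relax.VirtualMachine(exec, tvm.metal())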