# TIR VM 代码生成
# 导入必要的TVM库和模块
import tvm
import tvm.testing
from tvm import relax
from tvm.ir import assert_structural_equal # 用于比较IR结构是否相等
from tvm.script import relax as R # Relax脚本支持
from tvm.script import tir as T # TIR脚本支持
def get_tir_mod(mod):
    """Run Relax VM code generation in compiled mode and return the result.

    Parameters
    ----------
    mod : tvm.IRModule
        The Relax IRModule to lower.

    Returns
    -------
    tvm.IRModule
        The module produced by ``relax.vm_build._vmcodegen`` with
        ``exec_mode="compiled"`` (the TIR-level VM functions shown by
        ``.show()`` in the examples below).
    """
    exec_builder = relax.ExecBuilder()
    tir_mod = relax.vm_build._vmcodegen(exec_builder, mod, exec_mode="compiled")
    return tir_mod
## 测试 VM 编译模式下基本加法算子的代码生成
# Input Relax module: one function (declared pure=False) that calls the
# packed function "test.vm.add" on its argument.
@tvm.script.ir_module
class Before:
    @R.function(pure=False)
    def foo(x: R.Tensor):
        R.func_attr({"global_symbol": "foo"})  # set the global symbol name
        # Call the packed function "test.vm.add" with x as both operands.
        z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
        return z


# Lower to TIR via compiled-mode VM codegen and print the result.
after = get_tir_mod(Before)
after.show()
# Reference copy of the TIR printed by ``after.show()`` above.
# The printed form uses ``@I.ir_module`` (``from tvm.script import ir as I``),
# but ``I`` is never imported in this file, which would raise NameError here.
# Use ``tvm.script.ir_module`` instead, the same decorator the ``Before``
# modules use.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        # r[2] <- test.vm.add(r[0], r[0]); r[0] holds the argument x.
        T.anylist_setitem_call_packed(r, 2, "test.vm.add", T.anylist_getitem(r, 0), T.anylist_getitem(r, 0))
        # r[1] <- vm.builtin.copy(r[2]); copies the call result into r[1].
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2))
## 测试 VM 编译模式下调用 TIR 函数的代码生成
# Input module mixing a TIR PrimFunc with a Relax function that calls it.
@tvm.script.ir_module
class Before:
    # A TIR PrimFunc operating on a 4-element int64 buffer.
    @T.prim_func
    def shape_func(H: T.Buffer(T.int64(4), "int64")):
        T.func_attr({"global_symbol": "shape_func"})
        # Compute body: increment the first element of H in place.
        H[T.int64(0)] = H[T.int64(0)] + T.int64(1)

    @R.function(pure=False)
    def foo(x: R.Tensor([4], "int64")):
        R.func_attr({"global_symbol": "foo"})
        # Call the TIR function defined above on x; the result is bound to _
        # and unused, and x itself is returned.
        _ = Before.shape_func(x)
        return x


# Lower to TIR via compiled-mode VM codegen and print the result.
after = get_tir_mod(Before)
after.show()
# Reference copy of the TIR printed by ``after.show()`` above.
# The printed form uses ``@I.ir_module``, but ``I`` is never imported in this
# file (NameError); ``tvm.script.ir_module`` is the decorator the ``Before``
# modules already use.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        # The Relax call to the PrimFunc becomes a direct call_cpacked on r[0] (x).
        T.call_cpacked("shape_func", T.anylist_getitem(r, 0))
        # Return value: r[1] <- copy of r[0] (x itself is returned).
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 0))

    # The original PrimFunc is carried through unchanged.
    @T.prim_func
    def shape_func(H: T.Buffer((T.int64(4),), "int64")):
        H[T.int64(0)] = H[T.int64(0)] + T.int64(1)
## 测试 VM 编译模式下条件分支（if-else）的代码生成
# Input Relax module whose function branches on a scalar bool tensor.
@tvm.script.ir_module
class Before:
    @R.function(pure=False)
    def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor:
        R.func_attr({"global_symbol": "ife"})
        # Choose the operation based on cond.
        if cond:
            w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
        else:
            w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor))
        return w


# Lower to TIR via compiled-mode VM codegen and print the result.
after = get_tir_mod(Before)
after.show()
# Reference copy of the TIR printed by ``after.show()`` above.
# The printed form uses ``@I.ir_module``, but ``I`` is never imported in this
# file (NameError); ``tvm.script.ir_module`` is the decorator the ``Before``
# modules already use.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        # The Relax condition (r[0]) is read through vm.builtin.read_if_cond.
        if T.call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, 0)):
            # then-branch: r[4] <- test.vm.add(x, x); x lives in r[1].
            T.anylist_setitem_call_packed(r, 4, "test.vm.add", T.anylist_getitem(r, 1), T.anylist_getitem(r, 1))
            T.anylist_setitem_call_packed(r, 3, "vm.builtin.copy", T.anylist_getitem(r, 4))
        else:
            # else-branch: r[5] <- test.vm.mul(x, x).
            T.anylist_setitem_call_packed(r, 5, "test.vm.mul", T.anylist_getitem(r, 1), T.anylist_getitem(r, 1))
            T.anylist_setitem_call_packed(r, 3, "vm.builtin.copy", T.anylist_getitem(r, 5))
        # Both branches write r[3]; it is then copied into r[2].
        T.anylist_setitem_call_packed(r, 2, "vm.builtin.copy", T.anylist_getitem(r, 3))
## 测试 VM 编译模式下常量处理的代码生成
# Input Relax module returning a tuple that mixes constants and an argument.
@tvm.script.ir_module
class Before:
    @R.function
    def main(x: R.Tensor):
        R.func_attr({"global_symbol": "main"})
        y = R.const([1, 2])  # a constant array
        z = (y, R.const([3, 4]), x)  # tuple of two constants and the argument
        return z


# Lower to TIR via compiled-mode VM codegen and print the result.
after = get_tir_mod(Before)
after.show()
# Reference copy of the TIR printed by ``after.show()`` above.
# The printed form uses ``@I.ir_module``, but ``I`` is never imported in this
# file (NameError); ``tvm.script.ir_module`` is the decorator the ``Before``
# modules already use.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        # The two constants are read from c[0] and c[1]; the argument x from r[0].
        T.anylist_setitem_call_packed(r, 2, "vm.builtin.make_tuple", T.anylist_getitem(c, 0), T.anylist_getitem(c, 1), T.anylist_getitem(r, 0))
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2))
## 测试 VM 编译模式下将常量作为函数调用参数的代码生成
# Input Relax module that passes a constant as a packed-call argument.
@tvm.script.ir_module
class Before:
    @R.function(pure=False)
    def main(x: R.Tensor):
        R.func_attr({"global_symbol": "main"})
        y = R.const([1, 2])  # a constant array
        # Pass the constant as the second argument of the packed call.
        z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor))
        return z


# Lower to TIR via compiled-mode VM codegen and print the result.
after = get_tir_mod(Before)
after.show()
# Reference copy of the TIR printed by ``after.show()`` above.
# The printed form uses ``@I.ir_module``, but ``I`` is never imported in this
# file (NameError); ``tvm.script.ir_module`` is the decorator the ``Before``
# modules already use.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        # x comes from r[0]; the constant argument is read from c[0].
        T.anylist_setitem_call_packed(r, 2, "test.vm.add", T.anylist_getitem(r, 0), T.anylist_getitem(c, 0))
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2))