TIR VM 代码生成

# 导入必要的TVM库和模块
import tvm
import tvm.testing
from tvm import relax
from tvm.ir import assert_structural_equal  # 用于比较IR结构是否相等
from tvm.script import relax as R  # Relax脚本支持
from tvm.script import tir as T  # TIR脚本支持
# Helper used by every test cell below: lower a Relax module to its VM TIR form.
def get_tir_mod(mod):
    """Lower a Relax ``IRModule`` with the VM code generator in compiled mode.

    Args:
        mod: The Relax ``IRModule`` to lower.

    Returns:
        The ``IRModule`` produced by ``relax.vm_build._vmcodegen`` for
        ``exec_mode="compiled"``, in which the Relax functions of ``mod``
        are emitted as TIR functions.

    Note:
        ``_vmcodegen`` is an underscore-prefixed (private) TVM helper, so its
        signature is presumably not a stable public API -- verify when
        upgrading TVM.
    """
    # A fresh ExecBuilder per call; `_vmcodegen` takes it as its first argument.
    return relax.vm_build._vmcodegen(relax.ExecBuilder(), mod, exec_mode="compiled")

测试VM编译模式下的基本加法算子代码生成

# Input Relax IRModule: one impure function `foo` that passes its argument
# `x` twice to the packed function "test.vm.add".
@tvm.script.ir_module
class Before:
    # pure=False: the body uses R.call_packed, which Relax presumably cannot
    # treat as side-effect free (every R.call_packed test in this file sets it).
    @R.function(pure=False)
    def foo(x: R.Tensor):
        R.func_attr({"global_symbol": "foo"})  # set the global symbol name
        # Call the external add function "test.vm.add" on (x, x).
        # sinfo_args declares the result struct info; note `(R.Tensor)` is
        # just R.Tensor in parentheses, not a tuple.
        z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
        return z
# Run VM codegen (compiled mode) on `Before` and print the generated TIR
# module; the printed result is reproduced right after this cell.
after = get_tir_mod(Before)
after.show()
# Printed output of `after.show()` for the addition test (generated TVMScript).
# NOTE: `I` appears only in the commented-out header below; re-parsing this
# text would require `from tvm.script import ir as I`.
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    # Relax `foo` was lowered to the TIR function `__vmtir__foo`.
    # From the indices used, `r` appears to be the register list:
    # r[0] = argument `x`, r[2] = result of "test.vm.add", and r[1] = the
    # value copied out at the end (presumably the return slot) -- TODO confirm.
    @T.prim_func
    def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        T.anylist_setitem_call_packed(r, 2, "test.vm.add", T.anylist_getitem(r, 0), T.anylist_getitem(r, 0))
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2))

测试VM编译模式下调用TIR函数的代码生成

# Input module mixing a TIR PrimFunc with a Relax function that calls it.
@tvm.script.ir_module
class Before:
    # A TIR primitive function operating on a 4-element int64 buffer.
    @T.prim_func
    def shape_func(H: T.Buffer(T.int64(4), "int64")):
        T.func_attr({"global_symbol": "shape_func"})
        # Body: increment the first element of H by 1, in place.
        H[T.int64(0)] = H[T.int64(0)] + T.int64(1)

    @R.function(pure=False)
    def foo(x: R.Tensor([4], "int64")):
        R.func_attr({"global_symbol": "foo"})
        # Call the TIR function defined above directly on `x`. The call result
        # is bound to `_` and unused; `x` (which shape_func writes into) is
        # returned -- presumably why foo is marked pure=False.
        _ = Before.shape_func(x)
        return x
# Lower with VM codegen (compiled mode) and print the generated TIR module.
after = get_tir_mod(Before)
after.show()
# Printed output of `after.show()` for the TIR-call test (generated TVMScript).
# NOTE: `I` appears only in the commented-out header below; re-parsing this
# text would require `from tvm.script import ir as I`.
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    # The Relax call `Before.shape_func(x)` became a direct
    # T.call_cpacked("shape_func", ...) on r[0] (the argument `x`); r[0] is
    # then copied into r[1], matching `foo` returning `x`.
    @T.prim_func
    def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        T.call_cpacked("shape_func", T.anylist_getitem(r, 0))
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 0))

    # The original PrimFunc is carried into the output module; its body is
    # unchanged.
    @T.prim_func
    def shape_func(H: T.Buffer((T.int64(4),), "int64")):
        H[T.int64(0)] = H[T.int64(0)] + T.int64(1)

测试VM编译模式下条件分支(if-else)的代码生成

# Input module with data-dependent control flow.
@tvm.script.ir_module
class Before:
    @R.function(pure=False)
    def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor:
        R.func_attr({"global_symbol": "ife"})
        # Branch on the rank-0 bool tensor `cond`: "test.vm.add" on the true
        # branch, "test.vm.mul" on the false branch. Both branches bind the
        # same variable `w`, which is returned.
        if cond:
            w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
        else:
            w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor))
        return w
# Lower with VM codegen (compiled mode) and print the generated TIR module.
after = get_tir_mod(Before)
after.show()
# Printed output of `after.show()` for the if-else test (generated TVMScript).
# NOTE: `I` appears only in the commented-out header below; re-parsing this
# text would require `from tvm.script import ir as I`.
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    # r[0] = `cond`, r[1] = `x` (argument order). The Relax `if` becomes a TIR
    # `if` whose condition comes from the packed builtin
    # "vm.builtin.read_if_cond". Each branch writes its call result to its own
    # slot (r[4] / r[5]) and copies it into r[3], the slot both branches share
    # for `w`. r[3] is finally copied into r[2] (presumably the return slot --
    # TODO confirm).
    @T.prim_func
    def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        if T.call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, 0)):
            T.anylist_setitem_call_packed(r, 4, "test.vm.add", T.anylist_getitem(r, 1), T.anylist_getitem(r, 1))
            T.anylist_setitem_call_packed(r, 3, "vm.builtin.copy", T.anylist_getitem(r, 4))
        else:
            T.anylist_setitem_call_packed(r, 5, "test.vm.mul", T.anylist_getitem(r, 1), T.anylist_getitem(r, 1))
            T.anylist_setitem_call_packed(r, 3, "vm.builtin.copy", T.anylist_getitem(r, 5))
        T.anylist_setitem_call_packed(r, 2, "vm.builtin.copy", T.anylist_getitem(r, 3))

测试VM编译模式下处理常量的代码生成

# Input module returning a tuple built from constants and the parameter.
@tvm.script.ir_module
class Before:
    # Declared with the default purity (no pure=False); the body makes no
    # packed calls.
    @R.function
    def main(x: R.Tensor):
        R.func_attr({"global_symbol": "main"})
        y = R.const([1, 2])  # constant tensor bound to a variable
        z = (y, R.const([3, 4]), x)  # tuple of two constants and the parameter
        return z
# Lower with VM codegen (compiled mode) and print the generated TIR module.
after = get_tir_mod(Before)
after.show()
# Printed output of `after.show()` for the constants test (generated TVMScript).
# NOTE: `I` appears only in the commented-out header below; re-parsing this
# text would require `from tvm.script import ir as I`.
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    # Constants are read from `c` in order of appearance: c[0] = [1, 2] (`y`)
    # and c[1] = [3, 4]. r[0] is the argument `x`. The tuple is built by
    # "vm.builtin.make_tuple" into r[2] and then copied into r[1].
    @T.prim_func
    def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        T.anylist_setitem_call_packed(r, 2, "vm.builtin.make_tuple", T.anylist_getitem(c, 0), T.anylist_getitem(c, 1), T.anylist_getitem(r, 0))
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2))

测试VM编译模式下使用常量作为函数调用参数的代码生成

# Input module that passes a constant directly as a packed-call argument.
@tvm.script.ir_module
class Before:
    @R.function(pure=False)
    def main(x: R.Tensor):
        R.func_attr({"global_symbol": "main"})
        y = R.const([1, 2])  # constant tensor
        # Pass the constant as the second argument of the packed function.
        z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor))
        return z
# Lower with VM codegen (compiled mode) and print the generated TIR module.
after = get_tir_mod(Before)
after.show()
# Printed output of `after.show()` for the constant-argument test (generated
# TVMScript).
# NOTE: `I` appears only in the commented-out header below; re-parsing this
# text would require `from tvm.script import ir as I`.
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    # The constant argument `y` is read straight from c[0] rather than through
    # a register. r[0] is `x`, and r[2] holds the "test.vm.add" result, which
    # is then copied into r[1].
    @T.prim_func
    def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle):
        T.anylist_setitem_call_packed(r, 2, "test.vm.add", T.anylist_getitem(r, 0), T.anylist_getitem(c, 0))
        T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2))