extern#

import os
os.environ['PATH'] += ':/usr/local/cuda/bin' # make sure nvcc can be found
import tvm
from tvm import te
import numpy as np
import tvm.testing

Verify TVM's ability to generate vectorized code for different target devices. The test has a CPU and a GPU variant: the CPU version handles the vectorized computation with a loop-unrolling strategy, while the GPU version parallelizes through block and thread indices. The core logic builds the TIR intermediate representation by hand and verifies that the generated code is correct on the LLVM/OpenCL/CUDA backends. The test uses te.extern to create external compute nodes and checks that the output matches expectations.

CPU version: a SIMD vectorization strategy (float32x2) processes 2 elements per iteration, giving a 2x loop unroll.

def extern_generator(ins, outs):
    """Manually write the IR for the extern function, add pipeline"""
    ib = tvm.tir.ir_builder.create()
    with ib.for_range(0, (n + 1) // 2) as i:
        ib.emit(
            outs[0].vstore(
                i * 2, ins[0].vload(i * 2, "float32x2") + tvm.tir.const(1, "float32x2")
            )
        )
    return ib.get()

GPU version: two-level parallelism via blockIdx.x and threadIdx.x, matching the GPU's SIMT architecture.

def extern_generator_gpu(ins, outs):
    """Manually write the IR for the extern function, add pipeline"""
    ib = tvm.tir.ir_builder.create()
    bx = te.thread_axis("blockIdx.x")
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(bx, "thread_extent", (nn + max_threads - 1) // max_threads)
    ib.scope_attr(tx, "thread_extent", max_threads)
    idx = bx.var * max_threads + tx.var
    with ib.if_scope(ib.likely(idx < n)):
        ib.emit(
            outs[0].vstore(
                idx * 2, ins[0].vload(idx * 2, "float32x2") + tvm.tir.const(1, "float32x2")
            )
        )
    return ib.get()
  • te.extern creates an external compute node, separating the computation's definition from its implementation

  • vload/vstore perform explicit vectorized memory access (see the sketch after this list)

  • Memory alignment: vectorized access requires 64-bit alignment (float32x2 spans 2 × 4 B = 8 B)
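
To make the vload/vstore point concrete, here is a minimal self-contained sketch (the names x, y, and gen below are mine, not part of the test): it copies two lanes with a single float32x2 access and prints the resulting TIR, where the vector access shows up as a two-element slice.

import tvm
from tvm import te

x = te.placeholder((4,), name="x")

def gen(ins, outs):
    ib = tvm.tir.ir_builder.create()
    # one float32x2 access: read lanes [0, 1] of x, write lanes [0, 1] of the output
    ib.emit(outs[0].vstore(0, ins[0].vload(0, "float32x2")))
    return ib.get()

y = te.extern(x.shape, [x], gen, name="y")
tvm.IRModule.from_expr(te.create_prim_func([x, y])).show()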

nn = 64
max_threads = 4
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name="A")

C_cpu = te.extern(A.shape, [A], extern_generator, name="C")
C_gpu = te.extern(A.shape, [A], extern_generator_gpu, name="C")

# Create IRModules directly
mod_cpu = tvm.IRModule.from_expr(te.create_prim_func([A, C_cpu]))
mod_gpu = tvm.IRModule.from_expr(te.create_prim_func([A, C_gpu]))
mod_cpu.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": True})
        A = T.match_buffer(var_A, (64,), offset_factor=1)
        C = T.match_buffer(var_C, (64,), offset_factor=1)
        with T.block("C"):
            T.reads()
            T.writes()
            for i in range(32):
                C[i * 2:i * 2 + 2] = A[i * 2:i * 2 + 2] + T.Broadcast(T.float32(1.0), 2)
mod_gpu.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": True})
        A = T.match_buffer(var_A, (64,), offset_factor=1)
        C = T.match_buffer(var_C, (64,), offset_factor=1)
        with T.block("C"):
            T.reads()
            T.writes()
            blockIdx_x = T.launch_thread("blockIdx.x", 16)
            threadIdx_x = T.launch_thread("threadIdx.x", 4)
            if T.likely(blockIdx_x * 4 + threadIdx_x < 64):
                C[(blockIdx_x * 4 + threadIdx_x) * 2:(blockIdx_x * 4 + threadIdx_x) * 2 + 2] = A[(blockIdx_x * 4 + threadIdx_x) * 2:(blockIdx_x * 4 + threadIdx_x) * 2 + 2] + T.Broadcast(T.float32(1.0), 2)

Unified verification across devices:

def check_target(target):
    if not tvm.testing.device_enabled(target):
        return
    mod = mod_gpu if target in ["opencl", "cuda"] else mod_cpu
    C = C_gpu if target in ["opencl", "cuda"] else C_cpu
    # build and invoke the kernel.
    f = tvm.compile(mod, target=target)
    dev = tvm.device(target, 0)
    # launch the kernel.
    n = nn
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)
    f(a, c)
    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)
check_target("llvm")
check_target("opencl")
check_target("cuda")

Packed buffers#

def extern_generator(ins, outs):
    """Manually write the IR for the extern function, add pipeline."""
    return tvm.tir.call_packed("my_extern_array_func1", ins[0], outs[0])
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name="A")
C = te.extern(A.shape, [A], extern_generator, name="C")

# Create IRModule directly
mod = tvm.IRModule.from_expr(te.create_prim_func([A, C]))
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": True})
        A = T.match_buffer(var_A, (1024,), offset_factor=1)
        C = T.match_buffer(var_C, (1024,), offset_factor=1)
        with T.block("C"):
            T.reads()
            T.writes()
            T.call_packed("my_extern_array_func1", T.tvm_stack_make_array(A.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), A.elem_offset), T.tvm_stack_make_array(C.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), C.elem_offset))
@tvm.register_func
def my_extern_array_func1(aa, bb):
    aa.copyto(bb)
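
As an aside, anything registered with tvm.register_func lands in TVM's global PackedFunc registry, so the same callback that call_packed invokes can also be fetched and called directly from Python. A quick sanity check (this snippet is my addition, not part of the original test):

fn = tvm.get_global_func("my_extern_array_func1")
x = tvm.nd.array(np.arange(4, dtype="float32"))
y = tvm.nd.empty((4,), dtype="float32")
fn(x, y)  # invokes the Python callback above, which copies x into y
tvm.testing.assert_allclose(y.numpy(), x.numpy())
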
def check_target(target):
    if not tvm.testing.device_enabled(target):
        return
    # build and invoke the kernel.
    f = tvm.compile(mod, target=target)
    dev = tvm.cpu(0)
    # launch the kernel.
    n = nn
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)

    f(a, c)
    tvm.testing.assert_allclose(c.numpy(), a.numpy())

check_target("llvm")

Packing an intermediate buffer#

def extern_generator(ins, outs):
    """Manually write the IR for the extern function, add pipeline."""
    return tvm.tir.call_packed("my_extern_array_func2", *ins, outs[0])
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1, name="B")
C = te.extern(B.shape, [B], extern_generator, name="C")
# D = te.compute((n,), lambda i: C[i] + 1, name="D")
mod = tvm.IRModule.from_expr(te.create_prim_func([A, C]))
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024,), "float32"), var_C: T.handle):
        T.func_attr({"tir.noalias": True})
        C = T.match_buffer(var_C, (1024,), offset_factor=1)
        # with T.block("root"):
        B = T.alloc_buffer((1024,))
        for i in range(1024):
            with T.block("B"):
                v_i = T.axis.spatial(1024, i)
                T.reads(A[v_i])
                T.writes(B[v_i])
                B[v_i] = A[v_i] + T.float32(1.0)
        with T.block("C"):
            T.reads()
            T.writes()
            elem_offset = T.int32()
            T.call_packed("my_extern_array_func2", T.tvm_stack_make_array(B.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), elem_offset), T.tvm_stack_make_array(C.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), C.elem_offset))
def check_target(target):
    if not tvm.testing.device_enabled(target):
        return
    # build and invoke the kernel.
    f = tvm.compile(mod, target=target)
    dev = tvm.cpu(0)
    # launch the kernel.
    n = nn
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)

    @tvm.register_func
    def my_extern_array_func2(aa, cc):
        assert aa.shape == a.shape
        tvm.testing.assert_allclose(aa.numpy(), a.numpy()+1)
        aa.copyto(cc)

    f(a, c)
    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)
check_target("llvm")
---------------------------------------------------------------------------
InternalError                             Traceback (most recent call last)
Cell In[17], line 1
----> 1 check_target("llvm")

Cell In[16], line 5, in check_target(target)
      3     return
      4 # build and invoke the kernel.
----> 5 f = tvm.compile(mod, target=target)
      6 dev = tvm.cpu(0)
      7 # launch the kernel.

    ... (pass-pipeline frames elided) ...

File /media/pc/data/lxw/ai/tvm/src/tir/transforms/make_packed_api.cc:387, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)()
    385 
    386   Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
--> 387   ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
    388                                  << " are used, but are not passed in as API arguments";

InternalError: Check failed: undefined.size() == 0 (1 vs. 0) : In PrimFunc main variables [elem_offset] are used, but are not passed in as API arguments

The failure is explained by the printed TIR above: te.extern declares each input buffer with a symbolic elem_offset (the elem_offset = T.int32() line), and MakePackedAPI binds that variable only for buffers that are function parameters. B is an intermediate buffer allocated inside the function, so its elem_offset is never defined, and the pass rejects the module.
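
One possible way around this (an untested sketch of mine, not from the original notebook): hand te.extern an explicit input buffer whose elem_offset is the constant 0 via its in_buffers parameter, so no unbound offset variable is created for B.

# Hypothetical workaround sketch (untested): declare B's input buffer manually
# with elem_offset fixed to 0 instead of the symbolic Var te.extern creates.
B_buf = tvm.tir.decl_buffer(B.shape, B.dtype, name="B_buf", elem_offset=0)
C2 = te.extern(B.shape, [B], extern_generator, name="C", in_buffers=[B_buf])
mod2 = tvm.IRModule.from_expr(te.create_prim_func([A, C2]))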