extern#
import os
os.environ['PATH'] += ':/usr/local/cuda/bin' # 保证 nvcc 可以被找到
import tvm
from tvm import te
import numpy as np
import tvm.testing
验证 TVM 在不同目标设备上的向量化代码生成能力。测试分为 CPU 和 GPU 两个版本:CPU 版本使用循环展开策略处理向量化计算,GPU 版本通过线程块和线程索引实现并行。核心逻辑通过手动构建 TIR 中间表示,验证生成代码在 LLVM/OpenCL/CUDA 后端的正确性。测试使用 `te.extern` 创建外部计算节点,并检查输出结果是否符合预期。

CPU 版本:使用 SIMD 向量化策略(`float32x2`),每次迭代处理 2 个元素,实现 2 倍循环展开。
def extern_generator(ins, outs):
    """Hand-write the extern body as TIR: outs[0] = ins[0] + 1, two lanes at a time."""
    builder = tvm.tir.ir_builder.create()
    # One iteration per float32x2 vector; ceil(n / 2) covers all n elements.
    num_vectors = (n + 1) // 2
    with builder.for_range(0, num_vectors) as k:
        base = k * 2
        vec = ins[0].vload(base, "float32x2")
        one = tvm.tir.const(1, "float32x2")
        builder.emit(outs[0].vstore(base, vec + one))
    return builder.get()
GPU 版本:通过 `blockIdx.x` 和 `threadIdx.x` 实现两级并行,适配 GPU 的 SIMT 架构。
def extern_generator_gpu(ins, outs):
    """Manually write the IR for the extern function, add pipeline.

    Each thread stores one float32x2 vector (2 elements) at ``idx * 2``, so
    only ceil(nn / 2) threads are needed. The original code launched ``nn``
    threads guarded by ``idx < n``, which let threads 32..63 vstore at
    element offsets 64..126 of a 64-element buffer — an out-of-bounds
    write. The extent and guard below are expressed in vector units to
    bound every store inside the buffer.
    """
    ib = tvm.tir.ir_builder.create()
    bx = te.thread_axis("blockIdx.x")
    tx = te.thread_axis("threadIdx.x")
    # Work is counted in float32x2 vectors, not scalar elements.
    num_vectors = (nn + 1) // 2
    ib.scope_attr(bx, "thread_extent", (num_vectors + max_threads - 1) // max_threads)
    ib.scope_attr(tx, "thread_extent", max_threads)
    idx = bx.var * max_threads + tx.var
    # Guard the ragged last block: idx indexes vectors.
    with ib.if_scope(ib.likely(idx < num_vectors)):
        ib.emit(
            outs[0].vstore(
                idx * 2, ins[0].vload(idx * 2, "float32x2") + tvm.tir.const(1, "float32x2")
            )
        )
    return ib.get()
`te.extern` 创建外部计算节点,分离计算定义与实现;`vload`/`vstore` 实现显式向量化内存访问。内存对齐:向量化访问要求 64 位对齐(`float32x2` 对应 2×4B = 8B)。
nn = 64  # number of scalar elements in the 1-D buffers
max_threads = 4  # GPU threads per block
n = tvm.runtime.convert(nn)  # lift the Python int into a TIR expression
A = te.placeholder((n,), name="A")
# Same computation, two extern bodies: loop-based (CPU) and thread-based (GPU).
C_cpu = te.extern(A.shape, [A], extern_generator, name="C")
C_gpu = te.extern(A.shape, [A], extern_generator_gpu, name="C")
# Create IRModules directly
mod_cpu = tvm.IRModule.from_expr(te.create_prim_func([A, C_cpu]))
mod_gpu = tvm.IRModule.from_expr(te.create_prim_func([A, C_gpu]))
Show code cell content
mod_cpu.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
A = T.match_buffer(var_A, (64,), offset_factor=1)
C = T.match_buffer(var_C, (64,), offset_factor=1)
with T.block("C"):
T.reads()
T.writes()
for i in range(32):
C[i * 2:i * 2 + 2] = A[i * 2:i * 2 + 2] + T.Broadcast(T.float32(1.0), 2)
Show code cell content
mod_gpu.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
A = T.match_buffer(var_A, (64,), offset_factor=1)
C = T.match_buffer(var_C, (64,), offset_factor=1)
with T.block("C"):
T.reads()
T.writes()
blockIdx_x = T.launch_thread("blockIdx.x", 16)
threadIdx_x = T.launch_thread("threadIdx.x", 4)
if T.likely(blockIdx_x * 4 + threadIdx_x < 64):
C[(blockIdx_x * 4 + threadIdx_x) * 2:(blockIdx_x * 4 + threadIdx_x) * 2 + 2] = A[(blockIdx_x * 4 + threadIdx_x) * 2:(blockIdx_x * 4 + threadIdx_x) * 2 + 2] + T.Broadcast(T.float32(1.0), 2)
跨设备统一验证:
def check_target(target):
    """Build the module matching *target*, run it, and verify C == A + 1."""
    if not tvm.testing.device_enabled(target):
        return
    is_gpu = target in ("opencl", "cuda")
    mod = mod_gpu if is_gpu else mod_cpu
    C = C_gpu if is_gpu else C_cpu
    # Build the kernel and select the matching device.
    built = tvm.compile(mod, target=target)
    dev = tvm.device(target, 0)
    # Allocate input/output arrays and launch the kernel.
    size = nn
    a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), dev)
    c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev)
    built(a, c)
    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)


for tgt in ("llvm", "opencl", "cuda"):
    check_target(tgt)
打包 buffer#
def extern_generator(ins, outs):
    """Manually write the IR for the extern function, add pipeline."""
    # Delegate the entire computation to a registered packed function;
    # ins[0]/outs[0] are lowered to packed array arguments at call time.
    return tvm.tir.call_packed("my_extern_array_func1", ins[0], outs[0])
nn = 1024  # buffer length for the packed-call example
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name="A")
# Extern node whose body is a single call_packed into the runtime registry.
C = te.extern(A.shape, [A], extern_generator, name="C")
# Create IRModule directly
mod = tvm.IRModule.from_expr(te.create_prim_func([A, C]))
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
A = T.match_buffer(var_A, (1024,), offset_factor=1)
C = T.match_buffer(var_C, (1024,), offset_factor=1)
with T.block("C"):
T.reads()
T.writes()
T.call_packed("my_extern_array_func1", T.tvm_stack_make_array(A.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), A.elem_offset), T.tvm_stack_make_array(C.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), C.elem_offset))
@tvm.register_func
def my_extern_array_func1(aa, bb):
    # Runtime implementation of the packed call above: copy input -> output.
    aa.copyto(bb)
def check_target(target):
    """Compile *mod* for *target* and check that the packed call copies A into C."""
    if not tvm.testing.device_enabled(target):
        return
    # Build the kernel; its extern body dispatches to my_extern_array_func1.
    kernel = tvm.compile(mod, target=target)
    dev = tvm.cpu(0)
    size = nn
    src = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), dev)
    dst = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev)
    kernel(src, dst)
    # The registered function is a plain copy, so output must equal input.
    tvm.testing.assert_allclose(dst.numpy(), src.numpy())


check_target("llvm")
打包缓冲区中间表示#
def extern_generator(ins, outs):
    """Manually write the IR for the extern function, add pipeline."""
    # Forward every input buffer plus the single output to the packed func.
    return tvm.tir.call_packed("my_extern_array_func2", *ins, outs[0])
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1, name="B")  # intermediate stage feeding the extern op
C = te.extern(B.shape, [B], extern_generator, name="C")
# D = te.compute((n,), lambda i: C[i] + 1, name="D")
# Only A and C are exposed as function parameters; B stays internal.
mod = tvm.IRModule.from_expr(te.create_prim_func([A, C]))
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((1024,), "float32"), var_C: T.handle):
T.func_attr({"tir.noalias": True})
C = T.match_buffer(var_C, (1024,), offset_factor=1)
# with T.block("root"):
B = T.alloc_buffer((1024,))
for i in range(1024):
with T.block("B"):
v_i = T.axis.spatial(1024, i)
T.reads(A[v_i])
T.writes(B[v_i])
B[v_i] = A[v_i] + T.float32(1.0)
with T.block("C"):
T.reads()
T.writes()
elem_offset = T.int32()
T.call_packed("my_extern_array_func2", T.tvm_stack_make_array(B.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), elem_offset), T.tvm_stack_make_array(C.data, T.tvm_stack_make_shape(1024), 0, 1, T.float32(0.0), C.elem_offset))
def check_target(target):
    # NOTE(review): per the captured traceback below, building this module
    # currently fails in MakePackedAPI ("variables [elem_offset] are used,
    # but are not passed in as API arguments") — the extern op packs the
    # internal buffer B whose elem_offset var is left unbound by
    # create_prim_func. The call sequence is kept as a reproduction.
    if not tvm.testing.device_enabled(target):
        return
    # build and invoke the kernel.
    f = tvm.compile(mod, target=target)
    dev = tvm.cpu(0)
    # launch the kernel.
    n = nn
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)  # unused by f; B is allocated inside the kernel
    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)
    @tvm.register_func
    def my_extern_array_func2(aa, cc):
        # aa receives the intermediate B (= A + 1); verify it, then copy to output.
        assert aa.shape == a.shape
        tvm.testing.assert_allclose(aa.numpy(), a.numpy()+1)
        aa.copyto(cc)
    f(a, c)
    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)
check_target("llvm")
---------------------------------------------------------------------------
InternalError Traceback (most recent call last)
Cell In[17], line 1
----> 1 check_target("llvm")
Cell In[16], line 5, in check_target(target)
3 return
4 # build and invoke the kernel.
----> 5 f = tvm.compile(mod, target=target)
6 dev = tvm.cpu(0)
7 # launch the kernel.
File /media/pc/data/lxw/ai/tvm/python/tvm/driver/build_module.py:110, in compile(mod, target, relax_pipeline, tir_pipeline)
103 if _contains_relax(mod):
104 return tvm.relax.build(
105 mod,
106 target,
107 relax_pipeline=relax_pipeline,
108 tir_pipeline=tir_pipeline,
109 )
--> 110 lib = tvm.tir.build(mod, target, pipeline=tir_pipeline)
111 return Executable(lib)
File /media/pc/data/lxw/ai/tvm/python/tvm/tir/build.py:173, in build(mod, target, pipeline)
170 else:
171 # default pipeline depends on the target
172 pipeline = tvm.tir.get_default_tir_pipeline(target)
--> 173 mod = pipeline(mod)
175 # Step 5: Get host and device modules
176 host_mod, device_mod_dict = split_host_device_mods(mod)
File /media/pc/data/lxw/ai/tvm/python/tvm/ir/transform.py:238, in Pass.__call__(self, mod)
224 def __call__(self, mod):
225 """Execute the pass. Note that for sequential pass, the dependency among
226 different passes will be resolved in the backend.
227
(...)
236 The updated module after applying this pass.
237 """
--> 238 return _ffi_transform_api.RunPass(self, mod)
File /media/pc/data/lxw/ai/tvm/python/tvm/ffi/cython/function.pxi:228, in tvm.ffi.core.Function.__call__()
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:576, in operator()()
574
575 TVM_REGISTER_GLOBAL("transform.RunPass")
--> 576 .set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
577
578 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:297, in tvm::transform::Pass::operator()(tvm::IRModule) const()
295
296 IRModule Pass::operator()(IRModule mod) const {
--> 297 return this->operator()(std::move(mod), PassContext::Current());
298 }
299
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:313, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const()
311 ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
312 } else {
--> 313 ret = node->operator()(std::move(mod), pass_ctx);
314 }
315 pass_ctx.InstrumentAfterPass(ret, pass_info);
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:419, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const()
417 VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level;
418
--> 419 mod = pass_func(std::move(mod), pass_ctx);
420
421 ICHECK(mod.defined()) << "The return value of a module pass must be set.";
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:570, in operator()()
568 PassInfo pass_info) {
569 auto wrapped_pass_func = [pass_func](IRModule mod, PassContext ctx) {
--> 570 return pass_func(ffi::RValueRef<IRModule>(std::move(mod)), ctx);
571 };
572 return ModulePass(wrapped_pass_func, pass_info);
File /media/pc/data/lxw/ai/tvm/python/tvm/ffi/cython/function.pxi:281, in tvm.ffi.core.tvm_ffi_callback()
File /media/pc/data/lxw/ai/tvm/python/tvm/tir/pipeline.py:122, in _pipeline()
109 passes.append(tir.transform.InjectPTXLDG32())
110 passes.extend(
111 [
112 tir.transform.AnnotateDeviceRegions(),
(...)
120 ]
121 )
--> 122 mod = tvm.ir.transform.Sequential(passes)(mod)
123 return mod
File /media/pc/data/lxw/ai/tvm/python/tvm/ir/transform.py:238, in __call__()
224 def __call__(self, mod):
225 """Execute the pass. Note that for sequential pass, the dependency among
226 different passes will be resolved in the backend.
227
(...)
236 The updated module after applying this pass.
237 """
--> 238 return _ffi_transform_api.RunPass(self, mod)
File /media/pc/data/lxw/ai/tvm/python/tvm/ffi/cython/function.pxi:228, in tvm.ffi.core.Function.__call__()
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:576, in operator()()
574
575 TVM_REGISTER_GLOBAL("transform.RunPass")
--> 576 .set_body_typed([](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
577
578 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:297, in tvm::transform::Pass::operator()(tvm::IRModule) const()
295
296 IRModule Pass::operator()(IRModule mod) const {
--> 297 return this->operator()(std::move(mod), PassContext::Current());
298 }
299
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:313, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const()
311 ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
312 } else {
--> 313 ret = node->operator()(std::move(mod), pass_ctx);
314 }
315 pass_ctx.InstrumentAfterPass(ret, pass_info);
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:521, in tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const()
519
520 } else {
--> 521 mod = pass(std::move(mod), pass_ctx);
522 }
523 }
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:313, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const()
311 ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
312 } else {
--> 313 ret = node->operator()(std::move(mod), pass_ctx);
314 }
315 pass_ctx.InstrumentAfterPass(ret, pass_info);
File /media/pc/data/lxw/ai/tvm/src/ir/transform.cc:419, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const()
417 VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level;
418
--> 419 mod = pass_func(std::move(mod), pass_ctx);
420
421 ICHECK(mod.defined()) << "The return value of a module pass must be set.";
File /media/pc/data/lxw/ai/tvm/src/tir/transforms/make_packed_api.cc:424, in operator()()
422 }
423
--> 424 func = MakePackedAPI(std::move(func));
425
426 if (!func.same_as(orig_func)) {
File /media/pc/data/lxw/ai/tvm/src/tir/transforms/make_packed_api.cc:387, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)()
385
386 Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
--> 387 ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
388 << " are used, but are not passed in as API arguments";
389
InternalError: Check failed: undefined.size() == 0 (1 vs. 0) : In PrimFunc main variables [elem_offset] are used, but are not passed in as API arguments