代码生成:静态初始化

代码生成:静态初始化

import set_env
import tvm
from tvm import te
import ctypes
import numpy as np


def test_static_callback():
    """Run a kernel guarded by a ``TVMBackendRunOnce`` scope twice and check the result.

    The kernel adds 1 to every element of an ``int64`` buffer. The loop is
    emitted after a ``coproc_uop_scope`` attribute whose value is the string
    ``"TVMBackendRunOnce"``. The built function is called twice on a
    zero-filled array of length 10, and the array is then asserted to equal
    all ones, i.e. the increment must have been observed exactly once.
    """
    dtype = "int64"
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), dtype)
    ib = tvm.tir.ir_builder.create()
    A = ib.buffer_ptr(Ab)

    # Attach the scope attribute naming the run-once hook before emitting the loop.
    cp = te.thread_axis((0, 1), "cop")
    finit = tvm.tir.StringImm("TVMBackendRunOnce")
    ib.scope_attr(cp, "coproc_uop_scope", finit)

    # The loop variable is supplied by ``for_range``; it needs no prior declaration.
    with ib.for_range(0, n, "i", kind="parallel") as i:
        A[i] = A[i] + 1
    stmt = ib.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
    f = tvm.driver.build(mod, target="llvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    # Two invocations on purpose: the assertion below expects a single increment.
    f(a)
    f(a)
    np.testing.assert_equal(a.numpy(), np.ones(a.shape[0]))


def test_static_init():
    """Build a function that passes a static handle to a Python callback and run it.

    The generated function emits a single packed call to the global function
    ``"test_static_callback"``, passing the result of the
    ``tir.tvm_static_handle`` intrinsic and the buffer. The Python callback
    asserts that the handle arrives as a ``ctypes.c_void_p`` and returns it.
    """
    dtype = "int64"
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), dtype)
    ib = tvm.tir.ir_builder.create()
    handle = tvm.tir.call_intrin("handle", "tir.tvm_static_handle")
    ib.emit(tvm.tir.call_packed("test_static_callback", handle, Ab))

    # override=True keeps this test re-runnable within one process: without it,
    # registering the same global name a second time raises an error.
    @tvm.register_func("test_static_callback", override=True)
    def test_cb(sh, A):
        assert isinstance(sh, ctypes.c_void_p)
        return sh

    stmt = ib.get()
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
    f = tvm.driver.build(mod, target="llvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    f(a)