te tensor computation

import numpy as np
import tvm
import tvm.testing
from tvm import te, topi
m = 1024
factor = 16
dtype = "float32"

def intrin_vadd(n):
    """Declare a length-n vector-add tensor intrinsic backed by an external "vadd" call."""
    x = te.placeholder((n,), name="vx")
    y = te.placeholder((n,), name="vy")
    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")

    def intrin_func(ins, outs):
        # Emit a single extern call in place of the scalar elementwise loop.
        ib = tvm.tir.ir_builder.create()
        ib.emit(
            tvm.tir.call_extern(
                outs[0].dtype,
                "vadd",
                ins[0].access_ptr("r"),
                ins[1].access_ptr("r"),
                outs[0].access_ptr("wr"),
            )
        )
        return ib.get()

    return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={"offset_factor": n})

vadd = intrin_vadd(factor)

A = te.placeholder((m // factor, factor), name="A", dtype=dtype)
B = te.placeholder((m // factor, factor), name="B", dtype=dtype)
# Each output row is produced by one vadd intrinsic call over a length-`factor`
# slice, so this is a tensor compute op rather than a scalar compute.
C = te.compute((m // factor, factor), lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))

s = te.create_schedule(C.op)
# check lowering with the CSE pass disabled, as it would otherwise common up subexpressions
with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
    stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body, tvm.tir.Evaluate)
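
To see the effect, the lowered TIR can be printed; the per-row loop body should now be a single call to the external `vadd` (just an inspection aid; the textual form varies across TVM versions):

with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
    print(tvm.lower(s, [A, B, C], simple_mode=True))
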
M = 2048
N = 1024
L = 1024
factor = 16
factor1 = 32
factor2 = 32
dtype = "float32"

def intrin_gemm(m, n, l):
    """Declare an (m, n, l) GEMM tensor intrinsic with body/reset/update forms."""
    k = te.reduce_axis((0, l), name="k")
    x = te.placeholder((m, l), name="x")
    y = te.placeholder((n, l), name="y")
    # x and y are otherwise unrelated tensors; they only share the reduce axis k.
    z = te.compute((m, n), lambda i, j: te.sum(x[i][k] * y[j][k], axis=k))

    def intrin_func(ins, outs):
        x_ptr = ins[0].access_ptr("r")
        y_ptr = ins[1].access_ptr("r")
        z_ptr = outs[0].access_ptr("w")
        # body: one-shot GEMM; reset: zero the accumulator; update: accumulate a
        # partial GEMM when the reduction is split across intrinsic calls.
        body = tvm.tir.call_packed("gemv", x_ptr, y_ptr, z_ptr, m, n, l)
        reset = tvm.tir.call_packed("fill_zero", z_ptr, m, n)
        update = tvm.tir.call_packed("gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
        return body, reset, update

    return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={"offset_factor": n})

vgemm = intrin_gemm(factor1, factor2, factor)

A = te.placeholder((M // factor1, L // factor, factor1, factor), name="A", dtype=dtype)
B = te.placeholder((N // factor2, L // factor, factor2, factor), name="B", dtype=dtype)
k = te.reduce_axis((0, L // factor), name="k")
C = te.compute(
    (M // factor1, N // factor2, factor1, factor2),
    lambda i, j: vgemm(
        A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k
    ),
)

s = te.create_schedule(C.op)
# check lowering with the CSE pass disabled, as it would otherwise common up subexpressions
with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
    stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate)
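
Likewise, printing the lowered TIR makes the reset/update split visible: a `fill_zero` call initializes the accumulator before the reduction loop over `k`, and a `gemv_add` call accumulates inside it (again only an inspection aid; output varies by TVM version):

with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
    print(tvm.lower(s, [A, B, C], simple_mode=True))
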
def test_extern():
    m = te.size_var("m")
    A = te.placeholder((m,), name="A")

    def extern_func(ins, outs):
        assert isinstance(ins[0], tvm.te.schedule.Buffer)
        return tvm.tir.call_packed("myadd", ins[0].data, outs[0].data, m)

    B = te.extern((m,), [A], extern_func)
    assert tuple(B.shape) == (m,)
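
Because `extern_func` only references "myadd" symbolically through `call_packed`, the extern stage can be lowered and inspected without a real implementation. A minimal sketch (the trailing underscores are only to avoid clobbering names used earlier on this page):

m_ = te.size_var("m")
A_ = te.placeholder((m_,), name="A")
B_ = te.extern(
    (m_,),
    [A_],
    lambda ins, outs: tvm.tir.call_packed("myadd", ins[0].data, outs[0].data, m_),
    name="B",
)
s_ = te.create_schedule(B_.op)
print(tvm.lower(s_, [A_, B_], simple_mode=True))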


def test_extern_multi_out():
    m = te.size_var("m")
    A = te.placeholder((m,), name="A")
    B = te.compute((m,), lambda i: A[i] * 10)

    def extern_func(ins, outs):
        assert isinstance(ins[0], tvm.te.schedule.Buffer)
        return tvm.tir.call_packed("myadd", ins[0].data, outs[0].data, outs[1].data, m)

    res = te.extern([A.shape, A.shape], [A, B], extern_func)
    assert len(res) == 2
    assert res[1].value_index == 1
def test_tuple_inputs():
    m = te.size_var("m")
    n = te.size_var("n")
    A0 = te.placeholder((m, n), name="A0")
    A1 = te.placeholder((m, n), name="A1")
    T0, T1 = te.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name="T")
    s = te.create_schedule(T0.op)

    for i in range(len(T0.shape)):
        assert T0.shape[i] == T1.shape[i]
    assert T0.op == T1.op
    assert T0.value_index == 0
    assert T1.value_index == 1
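
Both outputs of a tuple-valued `te.compute` share one op and one loop nest. Lowering a small concrete instance makes this visible (a sketch; the underscored names are local to this snippet):

A0_ = te.placeholder((8, 8), name="A0")
A1_ = te.placeholder((8, 8), name="A1")
T0_, T1_ = te.compute((8, 8), lambda i, j: (A0_[i, j] * 2, A1_[i, j] * 3), name="T")
s_ = te.create_schedule(T0_.op)
print(tvm.lower(s_, [A0_, A1_, T0_, T1_], simple_mode=True))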


def test_tuple_with_different_deps():
    m = te.size_var("m")
    n = te.size_var("n")
    A0 = te.placeholder((m, n), name="A1")
    A1 = te.placeholder((m, n), name="A2")
    B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name="B")
    C = te.compute((m, n), lambda i, j: B0[i, j] + 4, name="C")

    s = te.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=10)
    s[B0.op].compute_at(s[C], xo)
    sch = s.normalize()
    bounds = tvm.te.schedule.InferBound(sch)
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

    def get_B1_realize(x):
        if (
            isinstance(x, tvm.tir.ProducerRealize)
            and x.producer.op == B1.op
            and x.producer.value_index == 1
        ):
            ret.append(x)

    ret = []
    tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize)

    assert stmt.producer == C and len(ret) == 1


def test_tensor_pool():
    def intrin_pool():
        A = te.placeholder((64, 16, 16), name="A")
        kh = te.reduce_axis((0, 3), name="kh")
        kw = te.reduce_axis((0, 3), name="kw")
        P = te.compute(
            (64, 14, 14),
            lambda c, oh, ow: tvm.te.max(A[c, oh + kh, ow + kw], axis=[kh, kw]),
            name="p",
        )

        def intrin_func(ins, outs):
            dinp = ins[0]
            dout = outs[0]
            return tvm.tir.call_packed("op", dinp, dout)

        return te.decl_tensor_intrin(P.op, intrin_func, default_buffer_params={"offset_factor": 1})

    A = te.placeholder((1, 64, 16, 16), name="A")
    P = topi.nn.pool2d(
        data=A, kernel=(3, 3), stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), pool_type="max"
    )
    s = te.create_schedule(P.op)
    _, oh, _, _ = P.op.axis
    intrin = intrin_pool()
    s[P].tensorize(oh, intrin)
    tvm.lower(s, [A, P])


def test_tensor_scalar_mixed():
    # test te with tensor and scalar
    a = np.array(np.random.uniform(size=(10,)), "float32")
    b = np.array(np.random.uniform(size=(1))[0], "float32")
    c = np.array(np.random.uniform(size=(10,)), "float32")

    @tvm.register_func("tvm.test_tensor_scalar_scale")
    def my_scale(tensor, scalar, out):
        out_np = tensor.numpy() * scalar.numpy()
        tvm.nd.array(out_np).copyto(out)

    A = te.placeholder(a.shape, name="A")
    B = te.placeholder(b.shape, name="B")
    C = te.extern(
        a.shape,
        [A, B],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.test_tensor_scalar_scale", ins[0], ins[1], outs[0]
        ),
        name="C",
    )
    s = te.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], "llvm")

    ta = tvm.nd.array(a)
    tb = tvm.nd.array(b)
    tc = tvm.nd.array(c)
    f(ta, tb, tc)
    tvm.testing.assert_allclose(a * b, tc.numpy())


def test_tensor_scalar():
    # test te with scalar shape
    a = np.array(np.random.uniform(size=(1))[0], "float32")
    b = np.array(0.0, "float32")

    @tvm.register_func("tvm.test_tensor_scalar_copy")
    def mycopy(x, y):
        x.copyto(y)

    A = te.placeholder(a.shape, name="A")
    B = te.extern(
        a.shape,
        [A],
        lambda ins, outs: tvm.tir.call_packed("tvm.test_tensor_scalar_copy", ins[0], outs[0]),
        name="B",
    )
    s = te.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm")

    ta = tvm.nd.array(a)
    tb = tvm.nd.array(b)
    f(ta, tb)
    tvm.testing.assert_allclose(ta.numpy(), tb.numpy())
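
The checks above can be run directly; the last two tests build for the "llvm" target, so an LLVM-enabled TVM build is assumed:

test_extern()
test_extern_multi_out()
test_tuple_inputs()
test_tuple_with_different_deps()
test_tensor_pool()
test_tensor_scalar_mixed()
test_tensor_scalar()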