Tensor Program Abstraction

Before we dive into the details of TensorIR, let's first introduce the notion of a primitive tensor function. A primitive tensor function corresponds to a single "unit" of computation. For example, a convolution can be a primitive tensor function, and so can a fused convolution + relu. An implementation of a primitive tensor function typically contains three elements: multi-dimensional buffers, loop nests that drive the tensor computation, and the compute statements themselves.

from tvm.script import tir as T

@T.prim_func
def main(
    A: T.Buffer((128,), "float32"),
    B: T.Buffer((128,), "float32"),
    C: T.Buffer((128,), "float32"),
) -> None:
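    # Loop nest that drives the computation.
    for i in range(128):
        with T.block("C"):
            # Bind the block iterator vi to loop variable i as a spatial axis of extent 128.
            vi = T.axis.spatial(128, i)
            # Compute statement: element-wise sum.
            C[vi] = A[vi] + B[vi]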

Key Elements of Tensor Programs

The primitive tensor function above computes the element-wise sum of two vectors. The function:

  • Accepts three buffers of shape (128,) as parameters: it reads the inputs A and B and writes the result to the output C.

  • Contains a single loop, over i, that drives the computation.

  • Contains a single compute statement, C[vi] = A[vi] + B[vi], that produces one element of the sum.
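To make the three elements concrete, here is a minimal plain-Python sketch of the same computation. It is only an illustrative reference and not TVM code: the function name vector_add and the use of Python lists as buffers are assumptions made for this example.

```python
def vector_add(a, b, c):
    """Plain-Python stand-in for the primitive tensor function above.

    a, b: input buffers (lists of floats); c: output buffer, written in place.
    The name and the list-based buffers are illustrative assumptions.
    """
    n = len(c)
    assert len(a) == n and len(b) == n, "all buffers must have the same length"
    # Loop nest: a single loop over the index i.
    for i in range(n):
        # Compute statement: one element of the element-wise sum.
        c[i] = a[i] + b[i]


# Buffers: two inputs and one pre-allocated output, each of length 128.
a = [float(i) for i in range(128)]
b = [2.0 * i for i in range(128)]
c = [0.0] * 128
vector_add(a, b, c)
```

As in the TensorIR version, the output buffer is passed in as a parameter and filled in place rather than returned.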

Extra Structure in TensorIR

Crucially, we cannot apply arbitrary transformations to a program, because some computations depend on the order in which loop iterations run. Fortunately, most of the primitive tensor functions we care about have favorable properties, such as independence among loop iterations. TensorIR records such properties as extra structure in the program. For instance, the program above includes block and iteration annotations:

  • The block annotation with T.block("C") signifies that the block is the fundamental computation unit designated for scheduling. A block may encompass a single computation statement, multiple computation statements with loops, or opaque intrinsics such as Tensor Core instructions.

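  • The iteration annotation vi = T.axis.spatial(128, i) indicates that the block variable vi is bound to the loop variable i, that it ranges over an axis of extent 128, and that all iterations along this axis are independent.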

While this information isn't needed to execute this specific program, it is useful when transforming it. Because vi is a spatial axis, we can safely parallelize or reorder the loops bound to vi, provided we still visit every index from 0 to 127.
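The effect of iteration independence can be checked with a small plain-Python experiment. This is a sketch for illustration, not TVM code; the helper names run_spatial and run_prefix_sum are hypothetical. The first helper mirrors the element-wise sum, whose iterations are independent. The second is a prefix sum, chosen here as a counter-example in which each iteration reads a value written by the previous one.

```python
import random

N = 128
a = [float(i) for i in range(N)]
b = [2.0 * i for i in range(N)]


def run_spatial(order):
    """Element-wise sum, visiting the indices in the given order."""
    c = [0.0] * N
    for i in order:
        c[i] = a[i] + b[i]  # depends only on index i
    return c


def run_prefix_sum(order):
    """Prefix sum: iteration i reads s[i - 1], so the order matters."""
    s = [0.0] * N
    for i in order:
        s[i] = a[i] + (s[i - 1] if i > 0 else 0.0)
    return s


forward = list(range(N))
backward = forward[::-1]
shuffled = forward[:]
random.Random(0).shuffle(shuffled)  # fixed seed for a repeatable permutation

# Independent iterations: any order that covers 0..127 gives the same result.
spatial_ok = run_spatial(forward) == run_spatial(backward) == run_spatial(shuffled)

# Loop-carried dependence: reversing the loop changes the result.
prefix_ok = run_prefix_sum(forward) == run_prefix_sum(backward)
```

Here spatial_ok is True while prefix_ok is False. Without an annotation, a compiler would have to prove this kind of independence by analysis; a spatial axis states it directly.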