条件表达式
%cd ../..
import set_env
import numpy as np
import tvm
from tvm import te
在 numpy 中使用 where()
处理数组的条件表达式:
# Element-wise selection: keep entries below 5, scale the others by 10.
a = np.arange(0, 10)
np.where(a < 5, a, a * 10)
array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
# Entries that are not below 4 are replaced by the scalar -1.
a = np.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]])
np.where(a < 4, a, -1)  # the scalar -1 is broadcast against a
array([[ 0, 1, 2],
[ 0, 2, -1],
[ 0, 3, -1]])
在 TVM 中使用 if_then_else 实现它。与 where() 类似,它接受三个参数:第一个是条件,条件为真时返回第二个参数,否则返回第三个参数。
下面以实现下三角矩阵为例:
# The shape is left as variables: A has m rows and n columns.
n, m = te.var('n'), te.var('m')
A = te.placeholder((m, n))
# Keep A[i, j] where row index i >= column index j (on or below the main
# diagonal); every other element of B is set to 0.0.
B = te.compute(A.shape,
lambda i, j: te.if_then_else(i >= j, A[i, j], 0.0))
# Turn the TE description into a PrimFunc and print it as TVMScript.
te_func = te.create_prim_func([A, B])
te_func.show()
# Compile the PrimFunc for the LLVM target.
mod = tvm.build(te_func, target="llvm")
# from tvm.script import tir as T
@T.prim_func
def func(var_placeholder: T.handle, var_compute: T.handle):
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
placeholder = T.match_buffer(var_placeholder, [m, n], dtype="float32")
compute = T.match_buffer(var_compute, [m, n], dtype="float32")
# body
# with T.block("root")
for i0, i1 in T.grid(m, n):
with T.block("compute"):
i, j = T.axis.remap("SS", [i0, i1])
T.reads(placeholder[i, j])
T.writes(compute[i, j])
compute[i, j] = T.if_then_else(j <= i, placeholder[i, j], T.float32(0), dtype="float32")
# Reference data: a 3x4 float32 matrix holding 1..12 and its lower triangle.
a_np = np.arange(1, 13, dtype=np.float32).reshape(3, 4)
b_np = np.tril(a_np)
b_np
array([[ 1., 0., 0., 0.],
[ 5., 6., 0., 0.],
[ 9., 10., 11., 0.]], dtype=float32)
# Wrap the input and an uninitialized output buffer of the same shape/dtype
# as TVM NDArrays, then call the compiled module; the result is read back
# from b_nd afterwards.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(np.empty_like(a_np))
mod(a_nd, b_nd)
b_nd
<tvm.nd.NDArray shape=(3, 4), cpu(0)>
array([[ 1., 0., 0., 0.],
[ 5., 6., 0., 0.],
[ 9., 10., 11., 0.]], dtype=float32)