解读 tvm.tir.trace()
import set_env
import tvm
from tvm import te
import numpy as np
tvm.tir.trace??
Signature: tvm.tir.trace(args, trace_action='tvm.default_trace_action')
Source:
def trace(args, trace_action="tvm.default_trace_action"):
    """Trace tensor data at runtime.

    The trace function allows tracing a specific tensor value at
    runtime. The value being traced must be passed as the last element
    of ``args``. A trace action may be specified; by default
    ``tvm.default_trace_action`` is used.

    Parameters
    ----------
    args : list of Expr or Buffers.
        Positional arguments. The last element is the traced value; its
        dtype becomes the dtype of the returned call expression.

    trace_action : str.
        The name of the trace action (the name of a registered packed
        function).

    Returns
    -------
    call : PrimExpr
        The call expression.

    Raises
    ------
    Exception
        If ``args`` is not a ``list``.

    See Also
    --------
    tvm.tir.call_packed : Creates packed function.
    """
    # Only a plain list is accepted; tuples and other sequences are rejected.
    if not isinstance(args, list):
        raise Exception("tvm.tir.trace consumes the args as list type")
    # Buffer arguments are converted with _pack_buffer; everything else is passed through as-is.
    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
    # The trace action name is prepended as the first call argument.
    call_args.insert(0, trace_action)
    # The result dtype follows the last (traced) argument; an empty list fails here at args[-1].
    return tvm.tir.Call(args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
File: /media/pc/data/lxw/ai/tvm/python/tvm/tir/op.py
Type: function
tvm.tir.trace 的参数 trace_action 默认值为 'tvm.default_trace_action'。使用默认 action 时，每执行一次 trace 调用就会打印一次传入的参数列表（见下方输出）。示例如下：
# Trace every element of an n*n*n float32 tensor with the default trace
# action; one line per element is printed while `f` runs.
n = 2
x = te.placeholder((n,) * 3, name="X", dtype="float32")
y = te.compute(
    x.shape,
    # Trace arguments: loop indices first, the traced value last.
    lambda i, j, k: tvm.tir.trace([i, j, k, x[i][j][k]]),
)
s = te.create_schedule(y.op)
f = tvm.build(s, [x, y], target="llvm")
xnd, ynd = (
    tvm.nd.array(np.ones((n,) * 3, dtype=x.dtype)),
    tvm.nd.array(np.zeros((n,) * 3, dtype=y.dtype)),
)
f(xnd, ynd)
[0, 0, 0, 1.0]
[0, 0, 1, 1.0]
[0, 1, 0, 1.0]
[0, 1, 1, 1.0]
[1, 0, 0, 1.0]
[1, 0, 1, 1.0]
[1, 1, 0, 1.0]
[1, 1, 1, 1.0]
trace expr 赋值（回调不返回值时，被追踪的值原样写入输出）
@tvm.register_func("tvm.tir.trace_callback2", override=True)
def trace_buffer(x):
    """Trace action ``tvm.tir.trace_callback2``: ignore the value, return None.

    Registered with ``override=True`` so re-running this notebook cell does
    not fail because the global function name is already registered.
    """
    return
def check_assign(dtype):
    """Chain two traced copies X -> Y -> Z and check that all hold ones.

    Both stages wrap the loaded element in ``tvm.tir.trace`` with the
    ``tvm.tir.trace_callback2`` action. The test expects Y and Z to end up
    equal to X (all ones).
    """
    n = 4
    shape = (n, n, n)
    x = te.placeholder(shape, name="X", dtype=dtype)

    def traced_copy(src):
        # Elementwise copy of `src`, routing every element through the trace action.
        return te.compute(
            x.shape,
            lambda i, j, k: tvm.tir.trace([src[i][j][k]], "tvm.tir.trace_callback2"),
        )

    y = traced_copy(x)
    z = traced_copy(y)
    s = te.create_schedule(z.op)
    f = tvm.build(s, [x, y, z], "llvm")

    x_nd = tvm.nd.array(np.ones(shape, dtype=x.dtype))
    y_nd, z_nd = (tvm.nd.array(np.zeros(shape, dtype=t.dtype)) for t in (y, z))
    f(x_nd, y_nd, z_nd)

    expected = np.ones(shape)
    for arr in (x_nd, y_nd, z_nd):
        assert np.array_equal(arr.numpy(), expected)
# Run the traced-assignment check once per supported dtype.
for t in ("float64", "float32", "int64", "int32"):
    check_assign(t)
trace expr 加法（trace 调用嵌在加法表达式的各个操作数上）
@tvm.register_func("tvm.tir.trace_callback3", override=True)
def trace_buffer(x):
    """Trace action ``tvm.tir.trace_callback3``: ignore the value, return None.

    Registered with ``override=True`` so re-running this notebook cell does
    not fail because the global function name is already registered.
    """
    return
def check_expr_sum(dtype):
    """Check that adding two traced loads yields the elementwise sum.

    Each operand of the addition is wrapped in ``tvm.tir.trace`` with the
    ``tvm.tir.trace_callback3`` action, so the trace calls sit inside the
    expression rather than around the stored result.
    """
    n = 4
    shape = (n, n, n)
    a = te.placeholder(shape, name="a", dtype=dtype)
    b = te.placeholder(shape, name="b", dtype=dtype)
    c = te.compute(
        a.shape,
        lambda i, j, k: tvm.tir.trace([a[i][j][k]], "tvm.tir.trace_callback3")
        + tvm.tir.trace([b[i][j][k]], "tvm.tir.trace_callback3"),
    )
    s = te.create_schedule(c.op)
    # Explicit target, matching the other builds in this notebook.
    f = tvm.build(s, [a, b, c], "llvm")
    # np.ones/np.zeros already return fresh arrays; no extra np.array copy is needed.
    a_nd = tvm.nd.array(np.ones(shape, dtype=a.dtype))
    b_nd = tvm.nd.array(np.ones(shape, dtype=b.dtype))
    c_nd = tvm.nd.array(np.zeros(shape, dtype=c.dtype))
    f(a_nd, b_nd, c_nd)
    assert np.array_equal(c_nd.numpy(), a_nd.numpy() + b_nd.numpy())
# Run the traced-addition check once per supported dtype.
for t in ("float64", "float32", "int64", "int32"):
    check_expr_sum(t)
def test_trace_expr_sum_args():
    """Trace four operands together with their indices and check the sum.

    Each operand is traced as ``[i, j, k, value]`` through the
    ``tvm.tir.trace_silent`` action, whose callback ignores its arguments
    and returns None. The result must equal ``a + b + d + e``.

    Defining this function does not run it; call
    ``test_trace_expr_sum_args()`` to execute the checks.
    """

    # override=True lets this function (or its notebook cell) run more than
    # once without failing on an already-registered global function name.
    @tvm.register_func("tvm.tir.trace_silent", override=True)
    def silent(*args):
        return

    def check_expr_sum(dtype):
        n = 4
        shape = (n, n, n)
        a = te.placeholder(shape, name="a", dtype=dtype)
        b = te.placeholder(shape, name="b", dtype=dtype)
        e = te.placeholder(shape, name="e", dtype=dtype)
        d = te.placeholder(shape, name="d", dtype=dtype)
        c = te.compute(
            a.shape,
            lambda i, j, k: tvm.tir.trace([i, j, k, a[i][j][k]], "tvm.tir.trace_silent")
            + tvm.tir.trace([i, j, k, b[i][j][k]], "tvm.tir.trace_silent")
            + tvm.tir.trace([i, j, k, d[i][j][k]], "tvm.tir.trace_silent")
            + tvm.tir.trace([i, j, k, e[i][j][k]], "tvm.tir.trace_silent"),
        )
        s = te.create_schedule(c.op)
        # Explicit target, matching the other builds in this notebook.
        f = tvm.build(s, [a, b, d, e, c], "llvm")
        # np.ones/np.zeros already return fresh arrays; no extra np.array copy is needed.
        a_nd = tvm.nd.array(np.ones(shape, dtype=a.dtype))
        b_nd = tvm.nd.array(np.ones(shape, dtype=b.dtype))
        d_nd = tvm.nd.array(np.ones(shape, dtype=d.dtype))
        e_nd = tvm.nd.array(np.ones(shape, dtype=e.dtype))
        c_nd = tvm.nd.array(np.zeros(shape, dtype=c.dtype))
        f(a_nd, b_nd, d_nd, e_nd, c_nd)
        assert np.array_equal(
            c_nd.numpy(), a_nd.numpy() + b_nd.numpy() + d_nd.numpy() + e_nd.numpy()
        )

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)
def test_trace_expr_sum_custom():
    """Trace two 2-D operands with non-uniform data and check their sum.

    Both inputs are n x n identity matrices. Each element is traced through
    ``tvm.tir.trace_callback4``, whose callback returns None. The result must
    equal ``npa + npb``.

    Defining this function does not run it; call
    ``test_trace_expr_sum_custom()`` to execute the checks.
    """

    # override=True lets this function (or its notebook cell) run more than
    # once without failing on an already-registered global function name.
    @tvm.register_func("tvm.tir.trace_callback4", override=True)
    def trace_buffer(x):
        return

    def check_expr_sum_custom(dtype):
        n = 4
        a = te.placeholder((n, n), name="a", dtype=dtype)
        b = te.placeholder((n, n), name="b", dtype=dtype)
        c = te.compute(
            a.shape,
            lambda i, j: tvm.tir.trace([a[i][j]], "tvm.tir.trace_callback4")
            + tvm.tir.trace([b[i][j]], "tvm.tir.trace_callback4"),
        )
        s = te.create_schedule(c.op)
        # Explicit target, matching the other builds in this notebook.
        f = tvm.build(s, [a, b, c], "llvm")
        # Identity inputs sized from n, so they always match the placeholder
        # shape; each uses its own placeholder's dtype.
        npa = np.eye(n, dtype=a.dtype)
        npb = np.eye(n, dtype=b.dtype)
        xnd = tvm.nd.array(npa)
        ynd = tvm.nd.array(npb)
        znd = tvm.nd.array(np.zeros((n, n), dtype=c.dtype))
        f(xnd, ynd, znd)
        assert np.array_equal(znd.numpy(), npa + npb)

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum_custom(t)
trace 回调修改被追踪的值（回调返回的常量替换原值）
def test_trace_can_change_traced_value_int():
    """Check that an int trace callback's return value replaces the traced value.

    X holds ones. Y traces X through a callback returning 13, and Z traces Y
    through a callback returning 14. The test expects Y == 13 and Z == 14
    everywhere.

    Defining this function does not run it; call
    ``test_trace_can_change_traced_value_int()`` to execute the checks.
    """

    # override=True lets this function (or its notebook cell) run more than
    # once without failing on already-registered global function names.
    @tvm.register_func("tvm.tir.trace_change_int_first", override=True)
    def trace_change_first(x):
        return 13

    @tvm.register_func("tvm.tir.trace_change_int_second", override=True)
    def trace_change_second(x):
        return 14

    def check_assign(dtype):
        n = 4
        x = te.placeholder((n,), name="X", dtype=dtype)
        y = te.compute(x.shape, lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_int_first"))
        z = te.compute(x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_int_second"))
        s = te.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")
        xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype))
        znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype))
        f(xnd, ynd, znd)
        # Expected arrays are derived from n rather than hard-coded to length 4.
        assert np.array_equal(ynd.numpy(), np.full(n, 13))
        assert np.array_equal(znd.numpy(), np.full(n, 14))

    for t in ["int64", "int32"]:
        check_assign(t)
def test_trace_can_change_traced_value_float():
    """Check that a float trace callback's return value replaces the traced value.

    X holds ones. Y traces X through a callback returning 13.0, and Z traces
    Y through a callback returning 14.0. The test expects Y == 13.0 and
    Z == 14.0 everywhere.

    Defining this function does not run it; call
    ``test_trace_can_change_traced_value_float()`` to execute the checks.
    """

    # override=True lets this function (or its notebook cell) run more than
    # once without failing on already-registered global function names.
    @tvm.register_func("tvm.tir.trace_change_float_first", override=True)
    def trace_change_first(x):
        return 13.0

    @tvm.register_func("tvm.tir.trace_change_float_second", override=True)
    def trace_change_second(x):
        return 14.0

    def check_assign(dtype):
        n = 4
        x = te.placeholder((n,), name="X", dtype=dtype)
        y = te.compute(x.shape, lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_float_first"))
        z = te.compute(
            x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_float_second")
        )
        s = te.create_schedule(z.op)
        f = tvm.build(s, [x, y, z], "llvm")
        xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype))
        ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype))
        znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype))
        f(xnd, ynd, znd)
        # Expected arrays are derived from n rather than hard-coded to length 4.
        assert np.array_equal(ynd.numpy(), np.full(n, 13.0))
        assert np.array_equal(znd.numpy(), np.full(n, 14.0))

    for t in ["float64", "float32"]:
        check_assign(t)