# StorageRewrite
Reference: tvm/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
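StorageRewrite performs static memory planning on TIR: it packs `Allocate` nodes whose live ranges do not overlap into shared storage, folds in-place updates back into their source buffers, and rewrites buffer accesses to match the new layout. The tests below exercise these behaviors one at a time; a toy sketch of the basic merge follows the imports.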
import tvm
import tvm.testing
from tvm import te
from tvm.driver.build_module import schedule_to_module
from tvm.script import tir as T
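# Illustrative sketch (ours, not part of the upstream test file): two buffers
# with disjoint lifetimes are packed into one allocation by StorageRewrite.
def _demo_alloc_merge():
    ib = tvm.tir.ir_builder.create()
    with ib.for_range(0, 10, name="i") as i:
        A = ib.allocate("float32", 100, name="A", scope="global")
        A[i] = 1.0
    with ib.for_range(0, 10, name="i") as i:
        B = ib.allocate("float32", 100, name="B", scope="global")
        B[i] = 2.0
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get()))
    # After the pass a single 100-element allocation remains; the stores to B
    # are redirected into the surviving buffer.
    return tvm.tir.transform.StorageRewrite()(mod)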
def test_storage_share():
m = te.var("m")
l = te.var("l")
A = te.placeholder((m, l), name="A")
num_stage = 5
B = A
for t in range(num_stage):
B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t)
s = te.create_schedule(B.op)
mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
    # verify that only one allocation remains:
    # in-place folding lets all five stages share a single buffer
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 1
def register_mem(scope_tb, max_bits):
    # Register memory info for the given scope tag
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.ir.make_node(
"MemoryInfo", unit_bits=16, max_simd_bits=32, max_num_bits=max_bits, head_address=None
)
def test_alloc_seq():
scope_tb = "local.L0A"
max_bits = 1024 * 1024 * 1024
register_mem(scope_tb, max_bits)
ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="B", scope=scope_tb)
A[j] = 1.3
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 200
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
def test_alloc_different_dtypes():
    def stmt_generator(dtype_list, length):
ib = tvm.tir.ir_builder.create()
base_dtype = dtype_list[0]
global_a = te.placeholder((length,), name="global_a", dtype=base_dtype)
assert len(dtype_list) == 4
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[0]
A = ib.allocate(dtype, length, name="A", scope="local.L0A")
A[j] = tvm.tir.const(1, dtype=dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[1]
B = ib.allocate(dtype, length, name="B", scope="local.L0A")
B[j] = tvm.tir.const(1, dtype=dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[2]
C = ib.allocate(dtype, length, name="C", scope="local.L0A")
C[j] = tvm.tir.const(1, dtype=dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[3]
D = ib.allocate(dtype, length, name="D", scope="local.L0A")
D[j] = tvm.tir.const(1, dtype=dtype)
with ib.for_range(0, length, name="j") as j:
dtype = "int8"
E = ib.allocate(dtype, length, name="E", scope="local.L0A")
E[j] = A[j].astype(dtype) + B[j].astype(dtype) + C[j].astype(dtype) + D[j].astype(dtype)
return ib.get()
def dtype_bit_len(dtype):
index = 0
for i in dtype:
if i.isdigit():
break
index += 1
return int(dtype[index:])
    def offset_generator(dtype_list, length):
dtype_len_list = [dtype_bit_len(i) for i in dtype_list]
base_len = dtype_len_list[0]
return sum([i * length / base_len for i in dtype_len_list])
def dtype_test(dtype_list, length):
def verify(n):
if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == offset
        body = stmt_generator(dtype_list, length)
        offset = offset_generator(dtype_list, length)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
tvm.tir.stmt_functor.post_order_visit(body, verify)
length = 1024
dtype_list = ["float16", "int32", "uint16", "int8"]
dtype_test(dtype_list, length)
dtype_list = ["float32", "int32", "uint16", "int8"]
dtype_test(dtype_list, length)
dtype_list = ["float64", "int32", "uint16", "int8"]
dtype_test(dtype_list, length)
dtype_list = ["int8", "int32", "uint16", "uint8"]
dtype_test(dtype_list, length)
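# Worked example for the offset arithmetic above (a sanity check of ours):
# with dtype_list = ["float16", "int32", "uint16", "int8"] and length = 1024,
# the merged extent is measured in units of the first dtype (16 bits), so
# (16 + 32 + 16 + 8) * 1024 / 16 = 4608 elements.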
def test_inplace_rule():
m = 10
A = te.placeholder((m,), name="A")
A0 = te.compute((m,), lambda i: A[i], name="A0")
A1 = te.compute((m,), lambda i: A[i] + 1, name="A1")
AA = te.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name="AA")
B = te.compute((m,), lambda i: AA[i] + 1, name="B")
s = te.create_schedule(B.op)
mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
    # verify that exactly two allocations remain:
    # AA folds in place into A0, but A1 must stay separate since AA reads A1[0]
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 2
def test_storage_combine():
n = 8
A = te.placeholder((4,), name="A")
num_stage = 5
B = A
stages = []
for t in range(num_stage):
B = te.compute((n,), lambda i: B[i] + B[0] + (t + 1), name="A%d" % t)
stages.append(B)
s = te.create_schedule(B.op)
for S in stages[:-1]:
s[S].set_scope("global:tag")
mod = schedule_to_module(s, [A, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 16
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 1
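# Note on the test above: each stage reads both B[i] and B[0] of its
# predecessor, so a stage's input and output buffers must be live at the same
# time and in-place folding is blocked. The five tagged 8-element stage buffers
# therefore collapse to a single 16-element allocation, presumably reused in
# alternation.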
def test_storage_combine_with_vectorization():
n = 1024
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
C = te.compute((n,), lambda i: A[i] + B[i], name="C")
s = te.create_schedule(C.op)
AA = s.cache_read(A, "global:tag", readers=[C])
BB = s.cache_read(B, "global:tag", readers=[C])
CC = s.cache_write(C, "global:tag")
s[CC].vectorize(s[CC].op.axis[0])
mod = schedule_to_module(s, [A, B, C])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.VectorizeLoop()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
mod = tvm.tir.transform.Simplify()(mod)
stmt = mod["main"].body
num_alloc = [0]
def verify(v):
# find add op
if (
isinstance(v, tvm.tir.Add)
and isinstance(v.a, tvm.tir.BufferLoad)
and isinstance(v.b, tvm.tir.BufferLoad)
):
lhs_ramp = v.a.indices[0]
rhs_ramp = v.b.indices[0]
            # these two ramp loads must not overlap
assert lhs_ramp.lanes == n
assert rhs_ramp.lanes == n
assert lhs_ramp.base >= rhs_ramp.base + n or rhs_ramp.base >= lhs_ramp.base + n
elif isinstance(v, tvm.tir.Allocate):
num_alloc[0] += 1
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 1
def test_address_of():
# In this test, the storage rewrite pass is allowed to
# combine buffers B and D, but not C
@T.prim_func
def before(A: T.Buffer(8, "float32"), E: T.Buffer(8, "float32")):
B_data = T.allocate([8], "float32")
B = T.Buffer(8, data=B_data, align=32)
for i in range(8):
B[i] = (
T.call_extern("deref", T.address_of(A[i]), dtype="float32")
+ T.call_extern("deref", T.address_of(A[0]), dtype="float32")
+ T.float32(1)
)
C_data = T.allocate([8], "float32")
C = T.Buffer(8, data=C_data, align=32)
for i in range(8):
C[i] = (
T.call_extern("deref", T.address_of(B[i]), dtype="float32")
+ T.call_extern("deref", T.address_of(B[0]), dtype="float32")
+ T.float32(2)
)
D_data = T.allocate([8], "float32")
D = T.Buffer(8, data=D_data, align=32)
for i in range(8):
D[i] = (
T.call_extern("deref", T.address_of(C[i]), dtype="float32")
+ T.call_extern("deref", T.address_of(C[0]), dtype="float32")
+ T.float32(2)
)
for i in range(8):
E[i] = (
T.call_extern("deref", T.address_of(D[i]), dtype="float32")
+ T.call_extern("deref", T.address_of(D[0]), dtype="float32")
+ T.float32(3)
)
def verify(n):
if isinstance(n, tvm.tir.Allocate):
total_alloc[0] += n.extents[0].value
total_alloc = [0]
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod.show()
tvm.tir.stmt_functor.post_order_visit(mod["main"].body, verify)
assert total_alloc[0] == 24
total_alloc[0] = 0
mod = tvm.tir.transform.StorageRewrite()(mod)
mod.show()
tvm.tir.stmt_functor.post_order_visit(mod["main"].body, verify)
assert total_alloc[0] == 16
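# Arithmetic behind the asserts: before the pass, B, C, and D are three
# 8-element intermediates (24 elements total). B is dead once C has been
# computed, so B and D can share storage, while C overlaps both; hence 16
# elements remain after the pass.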
def test_storage_share_gpu():
m = te.var("m")
    A = [te.placeholder((m,), name="A")]
num_stage = 5
for t in range(num_stage):
A.append(te.compute((m,), lambda i: A[-1][i] + (t + 1), name="A%d_s" % t))
A.append(te.compute((m,), lambda i: A[-1][i], name="A%d" % t))
s = te.create_schedule(A[-1].op)
for t in range(num_stage):
x = A[2 * t + 2].op.axis[0]
bx, tx = s[A[2 * t + 2]].split(x, factor=32)
s[A[2 * t + 2]].bind(bx, te.thread_axis("blockIdx.x"))
s[A[2 * t + 2]].bind(tx, te.thread_axis("threadIdx.x"))
s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx)
s[A[2 * t + 1]].set_scope("shared")
mod = schedule_to_module(s, [A[0], A[-1]])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
alloc_stats = {"global": 0, "shared": 0}
def verify(n):
if isinstance(n, tvm.tir.Allocate):
scope = n.buffer_var.type_annotation.storage_scope
alloc_stats[scope] += 1
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert alloc_stats["global"] == 2
assert alloc_stats["shared"] == num_stage
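# Note on the test above: the four intermediate global stages (A0..A3; A4 is
# the output parameter) are packed into two buffers reused in alternation,
# while each shared buffer lives inside its own kernel launch scope, so all
# num_stage of them remain.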
def test_parallel_alloc():
ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i", kind="parallel") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", n, name="A", scope="global")
A[j] = A[j] + 2
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"]
assert isinstance(body.body.body, tvm.tir.Allocate)
ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="t") as i:
ib.scope_attr(
tvm.tir.const(1, "int32"), "pragma_scope", tvm.tir.StringImm("parallel_launch_point")
)
with ib.for_range(0, n, name="i", kind="parallel") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", n, name="A", scope="global")
A[j] = A[j] + 2
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"]
assert isinstance(body.body.body.body.body, tvm.tir.Allocate)
def test_while_alloc():
def get_mod(kind="serial"):
ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i", kind=kind) as i:
j = ib.allocate("int32", 1, name="j", scope="global")
j[0] = 0
with ib.while_loop(j[0] < 10):
A = ib.allocate("float32", n, name="A", scope="global")
A[j[0]] = A[j[0]] + 2
j[0] += j[0] + 1
body = ib.get()
return tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
mod = get_mod(kind="parallel")
# parallel (i, 0, n) {
# allocate j[int32 * 1]
# j[0] = 0
# while((j[0] < 10)){
# // attr [A] storage_scope = "global"
# allocate A[float32 * n]
# A[j[0]] = (A[j[0]] + 2f)
# j[0] = (j[0] + (j[0] + 1))
# }
# }
body = tvm.tir.transform.StorageRewrite()(mod)["main"]
# parallel (i, 0, n) {
# allocate j[int32 * 1]
# allocate A[float32 * n]
# j[0] = 0
# while((j[0] < 10)){
# A[j[0]] = (A[j[0]] + 2f)
# j[0] = (j[0] + (j[0] + 1))
# }
# }
assert isinstance(body.body.body, tvm.tir.Allocate) # j
assert isinstance(body.body.body.body, tvm.tir.Allocate) # A
mod = get_mod(kind="serial")
# for (i, 0, n) {
# allocate j[int32 * 1]
# j[0] = 0
# while((j[0] < 10)){
# // attr [A] storage_scope = "global"
# allocate A[float32 * n]
# A[j[0]] = (A[j[0]] + 2f)
# j[0] = (j[0] + (j[0] + 1))
# }
# }
body = tvm.tir.transform.StorageRewrite()(mod)["main"]
# allocate j[int32 * 1]
# allocate A[float32 * n]
# for (i, 0, n) {
# j[0] = 0
# while((j[0] < 10)){
# A[j[0]] = (A[j[0]] + 2f)
# j[0] = (j[0] + (j[0] + 1))
# }
# }
assert isinstance(body.body, tvm.tir.Allocate) # j
assert isinstance(body.body.body, tvm.tir.Allocate) # A
def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024):
# Test Buffer
register_mem(scope_tb, max_bits)
m = 10
A = te.placeholder((m,), name="A")
C = te.placeholder((m,), name="C")
D = te.placeholder((m,), name="D")
A0 = te.compute((m,), lambda i: A[i] + C[i], name="A0")
A1 = te.compute((m,), lambda i: D[i] * D[i], name="A1")
A2 = te.compute((m,), lambda i: A0[i] + A1[i], name="A2")
B = te.compute((m,), lambda i: A2[i], name="B")
s = te.create_schedule(B.op)
A0L = s.cache_read(A0, scope_tb, [A2])
A1L = s.cache_read(A1, scope_tb, [A2])
A2L = s.cache_read(A2, scope_tb, [B])
mod = schedule_to_module(s, [A, B, C, D])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
    # verify that exactly two allocations remain
    # after in-place folding
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 2
def test_exceed_mem():
max_bits = 639
# The critical max_num_bits is between 639 and 640
loc = -1
try:
test_inplace_rule2("local_TEM", max_bits)
except Exception as e:
estr = str(e)
loc = estr.find("Allocation exceed bound of memory")
assert loc != -1
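# Why the boundary sits at 639/640 bits: the folded layout of
# test_inplace_rule2 keeps two 10-element float32 buffers live at once in the
# tagged scope, i.e. 2 * 10 * 32 = 640 bits, so 639 is the largest bound that
# must fail.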
def test_inplace_rule3():
# Test Buffer
scope_tb = "local_TB3"
max_bits = 1024 * 1024 * 1024
register_mem(scope_tb, max_bits)
m = 10
B0 = te.placeholder((m,), name="B0")
B1 = te.placeholder((m,), name="B1")
B2 = te.placeholder((m,), name="B2")
B3 = te.placeholder((m,), name="B3")
B4 = te.placeholder((m,), name="B4")
B5 = te.placeholder((m,), name="B5")
B6 = te.compute((m,), lambda i: B1[i] * B5[i], name="B6")
B7 = te.compute((m,), lambda i: B2[i] * B4[i], name="B7")
B8 = te.compute((m,), lambda i: B6[i] - B7[i], name="B8")
B9 = te.compute((m,), lambda i: B2[i] * B3[i], name="B9")
B10 = te.compute((m,), lambda i: B0[i] * B5[i], name="B10")
B11 = te.compute((m,), lambda i: B9[i] - B10[i], name="B11")
B12 = te.compute((m,), lambda i: B0[i] * B4[i], name="B12")
B13 = te.compute((m,), lambda i: B1[i] * B3[i], name="B13")
B14 = te.compute((m,), lambda i: B12[i] - B13[i], name="B14")
B = te.compute((m,), lambda i: B8[i] * B11[i] + B14[i], name="B")
s = te.create_schedule(B.op)
B1L = s.cache_read(B1, scope_tb, [B6, B13])
B5L = s.cache_read(B5, scope_tb, [B6, B10])
B2L = s.cache_read(B2, scope_tb, [B7, B9])
B4L = s.cache_read(B4, scope_tb, [B7, B12])
B3L = s.cache_read(B3, scope_tb, [B9, B13])
B0L = s.cache_read(B0, scope_tb, [B10, B12])
B8L = s.cache_write(B8, scope_tb)
B11L = s.cache_write(B11, scope_tb)
B14L = s.cache_write(B14, scope_tb)
B6L = s.cache_write(B6, scope_tb)
B7L = s.cache_write(B7, scope_tb)
B9L = s.cache_write(B9, scope_tb)
B10L = s.cache_write(B10, scope_tb)
B12L = s.cache_write(B12, scope_tb)
B13L = s.cache_write(B13, scope_tb)
s[B12].compute_inline()
s[B13].compute_inline()
s[B8].compute_inline()
s[B11].compute_inline()
s[B14].compute_inline()
s[B6].compute_inline()
s[B7].compute_inline()
s[B9].compute_inline()
s[B10].compute_inline()
s = s.normalize()
mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B])
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
    # verify that in-place folding works:
    # every remaining allocation should have extent 70
def verify(n):
if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == 70
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
def test_alloc_seq_type():
ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A")
A1 = ib.allocate("float32", 200, name="A1", scope="local.L0A")
A[j] = 1.2
A1[j] = 1.3
B = ib.allocate("int16", 200, name="B", scope="local.L0A")
B[j] = tvm.tir.const(1, "int16")
C = ib.allocate("int16", 200, name="C", scope="local.L0A")
C[j] = tvm.tir.const(1, "int16")
D = ib.allocate("int16", 200, name="D", scope="local.L0A")
D[j] = B[j] + C[j]
A2 = ib.allocate("float32", 200, name="A2", scope="local.L0A")
A2[j] = A[j]
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 500
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
def test_alloc_seq_type2():
scope_tb = "local.L0A2"
max_bits = 1024 * 1024 * 1024
register_mem(scope_tb, max_bits)
ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2
with ib.for_range(0, 20, name="j") as j:
B = ib.allocate("int16", 400, name="B", scope=scope_tb)
B[j] = tvm.tir.const(1, "int16")
with ib.for_range(0, 10, name="j") as j:
C = ib.allocate("float32", 200, name="C", scope=scope_tb)
C[j] = 1.2
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 200
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
def test_reuse_small_buffer():
ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("int16", 200, name="A", scope="local.L0A")
A[j] = tvm.tir.const(1, "int16")
B = ib.allocate("int16", 200, name="B", scope="local.L0A")
B[j] = tvm.tir.const(1, "int16")
B1 = ib.allocate("int16", 200, name="B1", scope="local.L0A")
B1[j] = A[j] + B[j]
C = ib.allocate("int16", 400, name="C", scope="local.L0A")
C[j] = tvm.tir.const(1, "int16")
D = ib.allocate("int16", 400, name="D", scope="local.L0A")
D[j] = tvm.tir.const(1, "int16")
E = ib.allocate("int16", 400, name="E", scope="local.L0A")
E[j] = C[j]
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 800
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
def test_replace_dataflow():
shape = (255,)
A = te.placeholder(shape, name="A")
B = te.compute(shape, lambda i: A[i] + A[i], name="B")
C = te.compute(shape, lambda i: A[i] + B[i], name="C")
D = te.compute(shape, lambda i: A[i] + C[i], name="D")
E = te.compute(shape, lambda i: A[i] + D[i], name="E")
s = te.create_schedule(E.op)
s.cache_read(A, "local", [B, C, D, E])
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
def test_large_input():
@te.hybrid.script
def compute(a, b):
n = 16384
c = output_tensor((n, n), "int32")
for i in range(n):
for j in range(n):
c[i, j] = a[i, j] - b[i, j]
return c
n = 16384
shape = (n, n)
a = te.placeholder(shape, name="a", dtype="int32")
b = te.placeholder(shape, name="b", dtype="int32")
c = te.compute(shape, lambda i, j: compute(a, b)[i, j])
c = te.compute(shape, lambda i, j: 1 + c[i, j])
s = te.create_schedule(c.op)
stmt = tvm.lower(s, [a, b, c])["main"].body
def verify(n):
if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == 268435456
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
def test_access_in_let_value():
@T.prim_func
def func(A: T.Buffer((8,), "float32")):
for i in range(8):
B_data = T.allocate((1,), "float32", "global")
B = T.Buffer(shape=[1], dtype="float32", data=B_data)
B[0] = 3.14
x: T.float32 = T.exp(B[0], dtype="float32")
A[i] = (x + 1.0) / (x - 1.0)
@T.prim_func
def func_rewritten(A: T.Buffer((8,), "float32")) -> None:
B_data = T.allocate((1,), "float32", "global")
B = T.Buffer(shape=[1], dtype="float32", data=B_data)
for i in range(8):
B[0] = 3.14
x: T.float32 = T.exp(B[0], dtype="float32")
A[i] = (x + 1.0) / (x - 1.0)
mod = tvm.tir.transform.StorageRewrite()(
tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
)
tvm.ir.assert_structural_equal(mod["main"], func_rewritten.with_attr("global_symbol", "main"))
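# Note on the test above: the only read of B occurs inside the *value* of a
# let-binding (x = T.exp(B[0])). The structural comparison checks that
# StorageRewrite still tracks that access and hoists a single allocation out
# of the loop rather than mis-planning B.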
class BaseCompare(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.StorageRewrite()
class TestLetBufferRewrite(BaseCompare):
"""StorageRewrite replaces the bound var of backing allocations
If StorageRewrite replaces the backing variable of an array, such
as when vectorizing the storage type, the variable must be
replaced in the LetStmt that defines it. Currently, StmtMutator
only visits usage of variables, and does not visit definitions of
variables, so the definition in a LetStmt must be explicitly
handled.
"""
def before() -> None:
A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle")
A = T.Buffer([8], "int32", data=A_data)
A[0:8] = T.broadcast(42, 8)
def expected() -> None:
A_data: T.handle("int32x8") = T.call_extern("dummy_func", dtype="handle")
A = T.Buffer([1], "int32x8", data=A_data)
A[0] = T.broadcast(42, 8)
class TestRewriteInPlaceUseOfNonFlatBuffer(BaseCompare):
"""A non-flat buffer may be re-used for in-place operations"""
def before(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")):
B_data = T.allocate(
[16, 16],
dtype="float32",
scope="global",
)
B = T.Buffer(
[16, 16],
dtype="float32",
axis_separators=[1],
data=B_data,
)
C_data = T.allocate(
[16, 16],
dtype="float32",
scope="global",
)
C = T.Buffer(
[16, 16],
dtype="float32",
axis_separators=[1],
data=C_data,
)
for i, j in T.grid(16, 16):
B[i, j] = A[i, j]
for i, j in T.grid(16, 16):
C[i, j] = 2.0 * B[i, j]
for i, j in T.grid(16, 16):
D[i, j] = C[i, j]
def expected(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")):
B_data = T.allocate(
[16, 16],
dtype="float32",
scope="global",
)
B = T.Buffer([16, 16], dtype="float32", axis_separators=[1], data=B_data)
C = T.Buffer(
[16, 16],
dtype="float32",
axis_separators=[1],
data=B.data,
)
for i, j in T.grid(16, 16):
B[i, j] = A[i, j]
for i, j in T.grid(16, 16):
C[i, j] = 2.0 * B[i, j]
for i, j in T.grid(16, 16):
D[i, j] = C[i, j]
class TestNoRewriteOfSharedNonFlatBuffer(BaseCompare):
"""In general, sharing of non-flat buffer isn't supported
The current packing algorithms in StorageRewrite assume a flat
memory space, and do not support packing of N-d buffers. For
buffers with axis separators, normal buffer sharing should be
disabled.
Like TestRewriteInPlaceUseOfNonFlatBuffer, except that B and C do
not have matching shapes.
"""
def before(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")):
B_data = T.allocate(
[16, 16],
dtype="float32",
scope="global",
)
B = T.Buffer(
[16, 16],
dtype="float32",
axis_separators=[1],
data=B_data,
)
C_data = T.allocate(
[20, 20],
dtype="float32",
scope="global",
)
C = T.Buffer(
[20, 20],
dtype="float32",
axis_separators=[1],
data=C_data,
)
for i, j in T.grid(16, 16):
B[i, j] = A[i, j]
for i, j in T.grid(16, 16):
C[i, j] = 2.0 * B[i, j]
for i, j in T.grid(16, 16):
D[i, j] = C[i, j]
expected = before
class TestRewriteDeclBuffer(BaseCompare):
"""A DeclBuffer node may appear in StorageRewrite's input"""
def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32")
for i in range(16):
B[i] = A[i]
for i in range(16):
C[i] = 2.0 * B[i]
for i in range(16):
D[i] = C[i]
def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32", data=B.data)
for i in range(16):
B[i] = A[i]
for i in range(16):
C[i] = 2.0 * B[i]
for i in range(16):
D[i] = C[i]
class TestNoOrphanedDeclBuffer(BaseCompare):
"""A DeclBuffer of an unused Allocate should be removed
StorageRewrite removes any allocations that are unused. When it
does so, any DeclBuffer that refers to that allocation should also
be removed.
"""
def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32")
Unused = T.decl_buffer(16, dtype="float32")
for i in range(16):
B[i] = A[i]
for i in range(16):
C[i] = 2.0 * B[i]
for i in range(16):
D[i] = C[i]
def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32", data=B.data)
for i in range(16):
B[i] = A[i]
for i in range(16):
C[i] = 2.0 * B[i]
for i in range(16):
D[i] = C[i]
def test_vulkan_smem_reuse():
target = tvm.target.Target(
{
"keys": ["vulkan", "gpu"],
"kind": "vulkan",
"max_num_threads": 256,
"max_threads_per_block": 256,
"supports_float32": T.bool(True),
"supports_int32": T.bool(True),
"tag": "",
"thread_warp_size": 1,
}
)
@T.prim_func(private=True)
def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
A_shared = T.allocate([4], "float32", "shared")
A_local = T.allocate([4], "float32", "local")
B_shared = T.allocate([4], "float16", "shared")
A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_1 = T.Buffer((4,), data=A.data)
A_shared_1[threadIdx_x] = A_1[threadIdx_x]
A_local_1 = T.Buffer((4,), data=A_local, scope="local")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
threadIdx_x = T.launch_thread("threadIdx.x", 4)
B_1 = T.Buffer((4,), "float16", data=B.data)
B_1[threadIdx_x] = B_shared_1[threadIdx_x]
@T.prim_func(private=True)
def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
A_shared = T.allocate([4], "float32", "shared")
A_local = T.allocate([4], "float32", "local")
A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_1 = T.Buffer((4,), data=A.data)
A_shared_1[threadIdx_x] = A_1[threadIdx_x]
A_local_1 = T.Buffer((4,), data=A_local, scope="local")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
threadIdx_x = T.launch_thread("threadIdx.x", 4)
B_1 = T.Buffer((4,), "float16", data=B.data)
B_1[threadIdx_x] = A_shared_2[threadIdx_x]
@T.prim_func(private=True)
def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
T.func_attr({"target": target, "tir.noalias": T.bool(True)})
A_shared_1 = T.allocate([4], "float32", "shared")
A_local_1 = T.allocate([4], "float32", "local")
B_shared_1 = T.allocate([4], "float16", "shared")
A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_1 = T.Buffer((4,), data=A.data)
A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x])
threadIdx_x = T.launch_thread("threadIdx.x", 4)
B_1 = T.Buffer((4,), "float16", data=B.data)
B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]
    # Shared memory is reused when lowering without a target.
mod = tvm.IRModule({"main": func})
tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)
    # No shared-memory reuse when lowering with the Vulkan target.
mod = tvm.tir.transform.BindTarget(target)(mod)
tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)
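# Note on the test above: without a bound target, lowering reuses A_shared for
# the float16 result (A_shared_2 aliases A_shared's storage). With the Vulkan
# target bound, a separate B_shared allocation is kept instead, presumably
# because SPIR-V shared storage cannot be reinterpreted at a different element
# type.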