Relax BlockBuilder Core#
import tvm
import tvm.contrib.cblas
from tvm import te, tir, topi
from tvm import relax as rx
from tvm.ir.base import assert_structural_equal
from tvm.script import ir as I, relax as R, tir as T
from tvm.tir.function import PrimFunc
import pytest
Testing BlockBuilder error handling in function definitions#
This verifies that when a function definition declares no explicit parameters yet its body implicitly references outer variables (such as m, n, x, y), the system raises an appropriate error. This is a compile-time check that guarantees function definitions are complete and correct.
# Register the nop extern function
@tvm.register_func("test.blockbuilder.nop")
def nop():
...
from tvm.relax import ExternFunc
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with pytest.raises(RuntimeError):
with bb.function("func"):
gv0 = bb.emit(rx.Call(ExternFunc("test.blockbuilder.nop"), []))
bb.emit_func_output(gv0)
A simple test#
def test_block_builder():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
bb._begin_binding_block()
gv0 = bb.emit(rx.op.add(x, y))
bb._begin_dataflow_block()
lv0 = bb.emit(rx.op.multiply(gv0, y))
gv1 = bb.emit_output(rx.op.multiply(lv0, lv0))
b0 = bb._end_block()
bb._begin_dataflow_block()
lv1 = bb.emit(rx.op.multiply(gv0, y))
gv2 = bb.emit_output(rx.op.multiply(lv1, lv1))
b1 = bb._end_block()
gv3 = bb.emit(rx.op.add(x, y))
b2 = bb._end_block()
assert isinstance(b0, rx.DataflowBlock)
assert isinstance(b1, rx.DataflowBlock)
assert not isinstance(b2, rx.DataflowBlock)
test_block_builder()
Variable definitions:
- m and n are TIR (Tensor Intermediate Representation) variables representing int64 scalar extents.
- x and y are Relax variables: float16 tensors of shape [m, n] and [n], respectively.
- bb is an rx.BlockBuilder() instance used to build Relax blocks.
Block construction:
- bb._begin_binding_block(): begin a binding block (BindingBlock).
- gv0 = bb.emit(rx.op.add(x, y)): create global variable gv0, the sum of x and y.
- bb._begin_dataflow_block(): begin a dataflow block (DataflowBlock).
- lv0 = bb.emit(rx.op.multiply(gv0, y)): create local variable lv0, the product of gv0 and y.
- gv1 = bb.emit_output(...): create output variable gv1, the square of lv0.
- b0 = bb._end_block(): end the current block and return the block object.
- The code then repeats these steps to create a second dataflow block b1 and a plain binding block b2.
Test goals:
- Verify that BlockBuilder correctly creates and distinguishes the different block types.
- Test variable scope management (global variables gv0, gv1, gv2, gv3 and local variables lv0, lv1).
- Verify the creation mechanics of dataflow blocks and plain binding blocks.
- Exercise the basic operators (add, multiply).
Key concepts:
- Binding Block: a plain binding block, holding sequentially executed computation.
- Dataflow Block: a dataflow block, typically representing computation that may execute in parallel; variables defined inside it are local.
- BlockBuilder: the core Relax utility for constructing computational graphs.
This code demonstrates the basic pattern for building computational graphs in Relax, including block creation, variable definition, and operator emission, while verifying that BlockBuilder behaves correctly.
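The test above drives the internal _begin_*/_end_block API directly. Day-to-day code produces the same block structure through the public context managers; a minimal sketch (assuming the imports at the top of this page):
bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo([4], "float32"))
with bb.function("f", [x]):
    with bb.dataflow():                # opens a DataflowBlock
        lv = bb.emit(rx.op.add(x, x))  # lv is a local DataflowVar
        gv = bb.emit_output(lv)        # gv is visible outside the block
    bb.emit_func_output(gv)            # closes the function scope
bb.finalize()["f"].show()              # print the TVMScript for inspection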
Specifying names for variables in a block#
def test_emit_with_name():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
bb._begin_dataflow_block()
lv0 = bb.emit(rx.op.add(x, y), "add")
gv0 = bb.emit_output(rx.op.multiply(lv0, y), "multi")
b0 = bb._end_block()
assert b0.bindings[0].var.name_hint == "add"
assert b0.bindings[1].var.name_hint == "multi"
test_emit_with_name()
Testing function creation with BlockBuilder#
def test_function_single_block():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with bb.function("func", [x, y]):
with bb.dataflow():
lv0 = bb.emit(rx.op.add(x, y))
assert lv0.name_hint == "lv"
lv1 = bb.emit(rx.op.multiply(lv0, y))
assert lv1.name_hint == "lv1"
gv0 = bb.emit_output(lv1)
assert gv0.name_hint == "gv"
bb.emit_func_output(gv0)
func = bb.finalize()["func"]
assert func.params[0] == x
assert func.params[1] == y
assert func.body.body == gv0
assert_structural_equal(gv0.struct_info, rx.TensorStructInfo([m, n], "float16"))
assert len(func.body.blocks) == 1
assert len(func.body.blocks[0].bindings) == 3
Function creation and block building:
- with bb.function("func", [x, y]): create a function named "func" with parameters x and y.
- with bb.dataflow(): open a dataflow block inside the function.
- lv0 = bb.emit(rx.op.add(x, y)): create local variable lv0, the sum of x and y.
- lv1 = bb.emit(rx.op.multiply(lv0, y)): create local variable lv1, the product of lv0 and y.
- gv0 = bb.emit_output(lv1): create global variable gv0 as the dataflow block's output.
Verification logic:
- Assert that lv0's name hint (name_hint) is "lv".
- Assert that lv1's name hint is "lv1".
- Assert that gv0's name hint is "gv".
- func = bb.finalize()["func"]: retrieve the finalized function.
- Assert that the function parameters are correct.
- Assert that the function body structure is correct.
- Assert that gv0's struct info is correct.
- Assert that the function body contains one block with three bindings.
Test goals:
- Verify the basic workflow of creating a function with BlockBuilder.
- Test setting and passing function parameters.
- Verify dataflow block creation and operations.
- Test automatic generation of variable name hints.
- Verify setting and retrieving the function output.
- Confirm the function structure and the number of bindings.
Key concepts:
- function: a Relax function definition, consisting of parameters and a body.
- dataflow: a dataflow block, representing computation that may execute in parallel.
- emit: add an operation to the current block and return the resulting variable.
- emit_output: mark a variable as an output of the dataflow block.
- emit_func_output: set the function's return value.
- finalize: finish construction and return the built IRModule.
This code exercises every step of building a function with a single dataflow block, confirming that function creation, parameter passing, block operations, and output handling all behave correctly.
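For reference, the function built above prints roughly as the following TVMScript (a hedged sketch, not verbatim tool output; func_sketch is an illustrative name):
from tvm.script import relax as R

@R.function
def func_sketch(
    x: R.Tensor(("m", "n"), "float16"), y: R.Tensor(("n",), "float16")
) -> R.Tensor(("m", "n"), "float16"):
    with R.dataflow():
        lv = R.add(x, y)
        lv1 = R.multiply(lv, y)
        gv = lv1
        R.output(gv)
    return gv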
def test_function_multi_blocks():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with bb.function("func", [x, y]):
with bb.dataflow():
lv0 = bb.emit(rx.op.add(x, y))
assert lv0.name_hint == "lv"
gv0 = bb.emit_output(lv0)
assert gv0.name_hint == "gv"
gv1 = bb.emit(rx.op.add(gv0, gv0))
assert gv1.name_hint == "gv1"
with bb.dataflow():
lv1 = bb.emit(rx.op.add(gv1, gv1))
assert lv1.name_hint == "lv1"
gv2 = bb.emit_output(gv1)
bb.emit_func_output(gv2)
func = bb.finalize()["func"]
assert_structural_equal(gv2.struct_info, rx.TensorStructInfo([m, n], "float16"))
assert func.params[0] == x
assert func.params[1] == y
assert func.body.body == gv2
assert len(func.body.blocks) == 3
assert len(func.body.blocks[0].bindings) == 2
assert len(func.body.blocks[1].bindings) == 1
assert len(func.body.blocks[2].bindings) == 2
def test_multi_functions():
bb = rx.BlockBuilder()
m_1 = tir.Var("m", "int64")
n_1 = tir.Var("n", "int64")
x_1 = rx.Var("x", rx.TensorStructInfo([m_1, n_1], "float16"))
y_1 = rx.Var("y", rx.TensorStructInfo([n_1], "float16"))
with bb.function("func1", [x_1, y_1]):
with bb.dataflow():
lv0 = bb.emit(rx.op.add(x_1, y_1))
assert lv0.name_hint == "lv"
gv0 = bb.emit_output(lv0)
bb.emit_func_output(gv0)
m_2 = tir.Var("m", "int64")
n_2 = tir.Var("n", "int64")
x_2 = rx.Var("x", rx.TensorStructInfo([m_2, n_2], "float16"))
y_2 = rx.Var("y", rx.TensorStructInfo([n_2], "float16"))
with bb.function("func2", [x_2, y_2]):
with bb.dataflow():
lv0 = bb.emit(rx.op.add(y_2, x_2))
# TODO(@yuchen): enable block builder to reset local var unique name map
assert lv0.name_hint == "lv1"
gv0 = bb.emit_output(lv0)
bb.emit_func_output(gv0)
mod = bb.finalize()
func1 = mod["func1"]
assert func1.params[0] == x_1
assert func1.params[1] == y_1
assert len(func1.body.blocks) == 1
func2 = mod["func2"]
assert func2.params[0] == x_2
assert func2.params[1] == y_2
assert len(func2.body.blocks) == 1
Verifying shape and type deduction for binary operators (e.g. add, multiply) in the Relax BlockBuilder#
def test_binary_shape_type_deduction():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
k = tir.Var("k", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, 1], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
z = rx.Var("z", rx.TensorStructInfo([5], "float16"))
w = rx.Var("w", rx.TensorStructInfo([k], "float16"))
bb = rx.BlockBuilder()
with bb.function("func", [x, y, z, w]):
with bb.dataflow():
lv0 = bb.emit(rx.op.add(x, y))
assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float16"))
lv1 = bb.emit(rx.op.multiply(x, z))
assert_structural_equal(lv1.struct_info, rx.TensorStructInfo([m, 5], "float16"))
lv2 = bb.emit(rx.op.multiply(z, w))
assert isinstance(lv2.struct_info, rx.TensorStructInfo)
assert lv2.struct_info.ndim == 1
assert lv2.struct_info.dtype == "float16"
lv3 = bb.emit(rx.op.multiply(y, w))
assert isinstance(lv3.struct_info, rx.TensorStructInfo)
assert lv3.struct_info.ndim == 1
assert lv3.struct_info.dtype == "float16"
gv0 = bb.emit_output(lv3)
bb.emit_func_output(gv0)
assert isinstance(gv0.struct_info, rx.TensorStructInfo)
assert gv0.struct_info.ndim == 1
assert gv0.struct_info.dtype == "float16"
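The deductions above follow NumPy-style broadcasting: extents align from the right, an extent of 1 broadcasts against anything, and two unrelated symbolic extents (e.g. [5] vs. [k]) degrade the result to a known ndim/dtype with an unknown shape. The rule can be observed directly with normalize, which infers struct info without emitting into a block; a minimal sketch (assuming the imports above):
bb = rx.BlockBuilder()
k = tir.Var("k", "int64")
a = rx.Var("a", rx.TensorStructInfo([k, 1], "float16"))
b = rx.Var("b", rx.TensorStructInfo([4], "float16"))
out = bb.normalize(rx.op.add(a, b))
print(out.struct_info)  # R.Tensor((k, 4), "float16"): the 1 broadcasts to 4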
match_cast#
match_cast is a structural cast mechanism: it matches a variable against a given struct info, binding any symbolic dimensions in the process.
def test_emit_match_cast():
m = tir.Var("m", dtype="int64")
n = tir.Var("n", dtype="int64")
x = rx.Var("tensor_value", rx.TensorStructInfo(dtype="float32", ndim=-1))
y = rx.Var("shape_value", rx.ShapeStructInfo([16, 8]))
bb = rx.BlockBuilder()
with bb.function("func", [x, y]):
with bb.dataflow():
# lv0: Tensor((m, n), "float32") =
#   match_cast(x: Tensor(_, "float32"), [m, n])
lv0 = bb.match_cast(x, rx.TensorStructInfo([m, n], "float32"))
assert isinstance(lv0, rx.DataflowVar)
assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32"))
# lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n]))
lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]), "var_name")
assert lv1.struct_info == rx.ShapeStructInfo([m, n])
gv0 = bb.emit_output(lv1)
bb.emit_func_output(gv0)
func = bb.finalize()["func"]
block = func.body.blocks[0]
b0, b1 = block.bindings[:2]
assert isinstance(b0, rx.MatchCast)
assert isinstance(b1, rx.MatchCast)
assert b0.value == x
assert b0.struct_info == rx.TensorStructInfo([m, n], "float32")
assert b0.var == lv0
assert b1.value == y
assert b1.struct_info == rx.ShapeStructInfo([m, n])
assert b1.var == lv1
assert b1.var.name_hint == "var_name"
This test exercises match_cast end to end, covering:
- casting a tensor of arbitrary rank to a shape with symbolic dimensions;
- casting a fixed shape to a shape with symbolic dimensions;
- assigning a name to the cast result;
- verifying the underlying representation and attributes of the cast.
match_cast plays an important role in Relax: it lets developers state a variable's structure info explicitly while giving the framework a hook for static type checking and shape inference. These tests ensure match_cast handles the different kinds of casts correctly and that the results are as expected.
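In TVMScript the same operation is written with R.match_cast; a hedged sketch of the first cast above (match_cast_sketch is an illustrative name):
from tvm.script import relax as R, tir as T

@R.function
def match_cast_sketch(x: R.Tensor(dtype="float32")):
    m = T.int64()  # symbolic dims bound by the match
    n = T.int64()
    lv0 = R.match_cast(x, R.Tensor((m, n), "float32"))
    return lv0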
emit_normalized#
This test focuses on emit_normalized handling MatchCast bindings inside a dataflow block. It ensures that:
- the MatchCast binding is correctly appended to the dataflow block;
- the binding's attributes (source value, struct info, target variable) are set correctly;
- the dataflow block manages the binding properly.
emit_normalized is an important BlockBuilder method: it inserts an already-normalized binding directly, without going through the builder's other entry points (such as match_cast). This test confirms its behavior inside dataflow blocks, giving developers a more flexible way to construct programs.
def test_emit_match_cast_binding_in_dataflow_block():
bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo(dtype="float32", ndim=-1))
m = tir.Var("m", dtype="int64")
gv = rx.Var("gv", rx.TensorStructInfo(dtype="float32", ndim=-1))
match_cast = rx.MatchCast(gv, x, rx.TensorStructInfo((m,), "float32"))
with bb.function("main", [x]):
with bb.dataflow():
bb.emit_normalized(match_cast)
bb.emit_output(gv)
bb.emit_func_output(x)
func = bb.finalize()["main"]
block = func.body.blocks[0]
b0 = block.bindings[0]
assert isinstance(b0, rx.MatchCast)
assert b0.value == x
assert isinstance(b0.struct_info, rx.TensorStructInfo)
assert b0.struct_info.shape[0] == m
assert b0.var == gv
normalize#
This test verifies that the normalize method correctly normalizes different node kinds:
- Call nodes: the normalized call carries the correct shape information.
- Tuple nodes: the normalized tuple carries the correct struct info.
- Nested Tuple nodes: nesting is preserved in the struct info hierarchy.
normalize is a core BlockBuilder method: it ensures each node carries correct struct info after creation, which is essential for later static analysis, optimization, and code generation. Once normalized, every node's type and shape are explicit, enabling more precise handling.
def test_normalize():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
# Call node
add_call = rx.op.multiply(x, y)
bb.normalize(add_call)
shape = rx.get_shape_of(add_call)
assert isinstance(shape, rx.ShapeExpr)
assert shape[0] == m
assert shape[1] == n
# Tuple node
tuple_1 = rx.Tuple([x, y])
bb.normalize(tuple_1)
assert isinstance(tuple_1.struct_info, rx.TupleStructInfo)
assert isinstance(tuple_1.struct_info.fields[0], rx.TensorStructInfo)
assert isinstance(tuple_1.struct_info.fields[1], rx.TensorStructInfo)
# Nested Tuple
tuple_2 = rx.Tuple([x, rx.Tuple([x, y])])
bb.normalize(tuple_2)
assert isinstance(tuple_2.struct_info, rx.TupleStructInfo)
assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo)
assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo)
assert isinstance(tuple_2.struct_info.fields[1].fields[0], rx.TensorStructInfo)
assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo)
Handling struct info for tuple types#
These tests ensure that Relax handles struct info for tuple types correctly:
- tuple indexing correctly inherits the element's struct info;
- tuple unpacking works as expected;
- incorrect unpacking raises a clear error.
These capabilities are essential for static type analysis, shape deduction, and correctness checking.
def test_tuple_indexing():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
shape_x = rx.TensorStructInfo([m, n], "float16")
shape_y = rx.TensorStructInfo([n], "float16")
relax_tuple = rx.Var("relax_tuple", rx.TupleStructInfo([shape_x, shape_y]))
assert isinstance(relax_tuple.struct_info, rx.TupleStructInfo)
assert isinstance(relax_tuple.struct_info.fields[0], rx.TensorStructInfo)
assert isinstance(relax_tuple.struct_info.fields[1], rx.TensorStructInfo)
# TupleGetItem will initialize struct info from the
# TupleStructInfo, if present.
x = relax_tuple[0]
tvm.ir.assert_structural_equal(x.struct_info, shape_x)
y = relax_tuple[1]
tvm.ir.assert_structural_equal(y.struct_info, shape_y)
# Tuple unpacking produces TupleGetItem structs
x_unpack, y_unpack = relax_tuple
tvm.ir.assert_structural_equal(x, x_unpack)
tvm.ir.assert_structural_equal(y, y_unpack)
# When TupleStructInfo is available, tuple unpacking fails immediately
# for incorrect number of arguments.
with pytest.raises(ValueError):
x_unpack, y_unpack, z_unpack = relax_tuple
Calling TE functions#
These tests ensure that Relax integrates and calls TE (Tensor Expression) functions correctly, the key mechanism connecting the Relax frontend to low-level tensor computation in TVM. They verify:
- correct passing of positional and keyword arguments;
- correct nesting of TE computation inside Relax functions;
- correct function structure and binding relationships.
This machinery is essential for lowering high-level Relax programs to executable tensor computation.
def test_call_te():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
y = rx.Var("y", rx.TensorStructInfo([n, m], "float32"))
z = rx.Var("z", rx.TensorStructInfo([n, m], "float32"))
def te_func(args, args_dict, msg):
A, B = args
C = args_dict["C"]
D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j])
return E
with bb.function("rx_func", [x, y, z]):
with bb.dataflow():
out = bb.emit_output(bb.call_te(te_func, [x, y], {"C": z}, msg="hello"))
bb.emit_func_output(out)
mod = bb.finalize()
rx_func = mod["rx_func"]
assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert rx_func.params[2] == z
assert rx_func.body.body == out
assert len(rx_func.body.blocks) == 1
assert len(rx_func.body.blocks[0].bindings) == 1
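call_te traces te_func with te.placeholder inputs, adds the resulting PrimFunc to the module, and returns a call that lowers to R.call_tir. Assuming the body of test_call_te above has run and produced mod, both halves can be inspected (a hedged sketch):
mod["te_func"].show()  # the generated PrimFunc, named after the traced te_func
mod["rx_func"].show()  # the Relax fn whose binding is R.call_tir(te_func, ...)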
Name uniqueness for TE tensor arguments#
This test ensures that Relax generates unique names for the different tensor arguments when calling a TE function. Name uniqueness matters for:
- avoiding naming collisions during code generation and optimization;
- correctly identifying and referencing tensors and buffers;
- readability and maintainability of the generated code;
- correctness of the underlying computational graph.
The test uses matrix multiplication and verifies that the whole naming chain, from parameters down to buffers, stays unique.
def test_call_te_unique_tensor_name():
bb = rx.BlockBuilder()
x = rx.Var("x", R.Tensor((2, 3), "float32"))
y = rx.Var("y", R.Tensor((3, 4), "float32"))
with bb.function("main", [x, y]):
gv = bb.emit_te(topi.nn.matmul, x, y)
bb.emit_func_output(gv)
f_matmul = bb.finalize()["matmul"]
param_A = f_matmul.params[0]
param_B = f_matmul.params[1]
buffer_A = f_matmul.buffer_map[param_A]
buffer_B = f_matmul.buffer_map[param_B]
assert param_A.name != param_B.name
assert buffer_A.name != buffer_B.name
assert buffer_A.data.name != buffer_B.data.name
Validating argument types when calling TE functions#
This test ensures that Relax validates argument types when calling TE functions, shape arguments in particular:
- topi.reshape expects a concrete shape argument (such as (200,)), not a variable of ShapeStructInfo type;
- the test verifies that call_te detects this mismatch and raises an appropriate exception;
- the framework thus catches type errors early, before they can cause more serious failures during execution;
- type safety and correct API usage are maintained.
This test is an important robustness guarantee: users get a clear error message rather than a hidden type incompatibility.
def test_call_te_with_unsupported_shape_arg():
bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo((200,), "float32"))
s = rx.Var("s", rx.ShapeStructInfo((200,)))
with pytest.raises(AssertionError):
with bb.function("rx_func", [x]):
out = bb.emit(bb.call_te(topi.reshape, x, s))
bb.emit_func_output(out)
emit_te tests#
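emit_te is a convenience wrapper over call_te: as the tests below suggest, bb.emit_te(f, *args) behaves like bb.emit(bb.call_te(f, *args)), tracing the TE function and emitting the resulting call in one step. A minimal sketch (add_one and wrap are illustrative names, not part of the test suite):
def add_one(t):
    return te.compute(t.shape, lambda i: t[i] + 1.0, name="add_one")

bb = rx.BlockBuilder()
a = rx.Var("a", rx.TensorStructInfo([8], "float32"))
with bb.function("wrap", [a]):
    out = bb.emit_te(add_one, a)  # same effect as bb.emit(bb.call_te(add_one, a))
    bb.emit_func_output(out)
bb.finalize()["wrap"].show()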
def test_emit_te():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
y = rx.Var("y", rx.TensorStructInfo([n, m], "float32"))
z = rx.Var("z", rx.TensorStructInfo([n, m], "float32"))
def te_func(args, args_dict, msg):
A, B = args
C = args_dict["C"]
D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j])
return E
with bb.function("rx_func", [x, y, z]):
out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello")
bb.emit_func_output(out)
mod = bb.finalize()
rx_func = mod["rx_func"]
def get_tir_func():
A = te.placeholder((n, m), dtype="float32", name="A")
B = te.placeholder((n, m), dtype="float32", name="B")
C = te.placeholder((n, m), dtype="float32", name="C")
out = te_func((A, B), {"C": C}, "")
return tvm.te.create_prim_func([A, B, C, out], index_dtype_override="int64")
# check TIR structure matches expected
assert_structural_equal(mod["te_func"].body, get_tir_func().body)
# check Relax function calls TIR function with call_tir call
assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert rx_func.params[2] == z
assert rx_func.body.body == out
assert len(rx_func.body.blocks) == 1
assert len(rx_func.body.blocks[0].bindings) == 1
call_node = rx_func.body.blocks[0].bindings[0].value
assert isinstance(call_node, rx.Call)
assert len(call_node.args) == 2
assert call_node.args[0].name_hint == "te_func"
assert call_node.args[1][0] == x
assert call_node.args[1][1] == y
assert call_node.args[1][2] == z
def test_emit_te_multiple():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
y = rx.Var("y", rx.TensorStructInfo([n, m], "float32"))
z = rx.Var("z", rx.TensorStructInfo([128, m], "float32"))
def te_func(A):
B = te.compute((128, 128), lambda i, j: A[i, j] + 1)
return B
with bb.function("rx_func", [x, y, z]):
x1 = bb.emit_te(te_func, x)
y1 = bb.emit_te(te_func, y)
z1 = bb.emit_te(te_func, z)
bb.emit_func_output(z1)
mod = bb.finalize()
rx_func = mod["rx_func"]
prim_func = []
for gv in mod.get_global_vars():
if isinstance(mod[gv], PrimFunc):
prim_func.append(mod[gv])
# only two PrimFuncs were generated since two of them are equal so got deduped
assert len(prim_func) == 2
assert rx_func.body.blocks[0].bindings[0].value.args[0].name_hint == "te_func"
assert rx_func.body.blocks[0].bindings[1].value.args[0].name_hint == "te_func"
assert rx_func.body.blocks[0].bindings[2].value.args[0].name_hint == "te_func1"
def test_emit_te_multiple_output():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
def te_func(A):
B0, B1 = te.compute((n, m), lambda i, j: (A[i, j] + 1, A[i, j] * 2), name="B")
return (B0, B1)
with bb.function("rx_func", [x]):
y = bb.emit_te(te_func, x)
z = rx.TupleGetItem(y, 0)
bb.emit_func_output([y, z])
rx_func = bb.finalize()["rx_func"]
# check call tir output shape is a Tuple of ShapeExpr
assert rx_func.params[0] == x
call_node = rx_func.body.blocks[0].bindings[0].value
assert call_node.args[0].name_hint == "te_func"
assert isinstance(call_node.sinfo_args[0], rx.TupleStructInfo)
assert len(call_node.sinfo_args[0].fields) == 2
assert isinstance(call_node.sinfo_args[0].fields[0].shape, rx.ShapeExpr)
assert isinstance(call_node.sinfo_args[0].fields[1].shape, rx.ShapeExpr)
def test_emit_te_extern():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
y = rx.Var("y", rx.TensorStructInfo([m, n], "float32"))
with bb.function("rx_cblas_matmul", [x, y]):
out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False)
bb.emit_func_output(out)
mod = bb.finalize()
rx_func = mod["rx_cblas_matmul"]
# check Relax function calls TIR function with call_tir call
assert rx_func.params[0] == x
assert rx_func.params[1] == y
assert len(rx_func.body.blocks) == 1
call_node = rx_func.body.blocks[0].bindings[0].value
assert isinstance(call_node, rx.Call)
assert len(call_node.args) == 2
assert call_node.args[0].name_hint == "matmul"
assert call_node.args[1][0] == x
assert call_node.args[1][1] == y
assert call_node.sinfo_args[0].shape[0] == n
assert call_node.sinfo_args[0].shape[1] == n
def test_emit_te_prim_value():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", R.Tensor([n, m], "float32"))
a_min = rx.PrimValue(0)
a_max = rx.PrimValue(6)
with bb.function("rx_clip", [x]):
out = bb.emit_te(topi.clip, x, a_min, a_max)
bb.emit_func_output(out)
rx_func = bb.finalize()["rx_clip"]
# check Relax function calls TIR function with call_tir call
assert rx_func.params[0] == x
assert len(rx_func.body.blocks) == 1
call_node = rx_func.body.blocks[0].bindings[0].value
assert isinstance(call_node, rx.Call)
assert len(call_node.args) == 2
assert call_node.args[1][0] == x
def test_nested_function_fail():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with pytest.raises(RuntimeError):
with bb.function("func", [x, y]):
gv0 = bb.emit(rx.op.add(x, x))
with bb.function("func1", [x, y]):
gv1 = bb.emit(rx.op.add(x, x))
bb.emit_func_output(gv0)
def test_emit_func_output_twice_fail():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with pytest.raises(RuntimeError):
with bb.function("func", [x, y]):
gv0 = bb.emit(rx.op.add(x, y))
bb.emit_func_output(gv0)
bb.emit_func_output(gv0)
def test_func_params_twice_fail():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with pytest.raises(RuntimeError):
with bb.function("func", [x, y]):
gv0 = bb.emit(rx.op.add(x, y))
bb.emit_func_output(gv0, [x])
def test_no_func_params_fail():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()
with pytest.raises(RuntimeError):
with bb.function("func"):
gv0 = bb.emit(rx.Call(ExternFunc("test.blockbuilder.nop"), []))
bb.emit_func_output(gv0)
def test_block_builder_scope_recovery():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
y = rx.Var("y", rx.TensorStructInfo([m, n], "float32"))
with pytest.raises(RuntimeError):
# this line fails
with bb.function("func", [x, y]):
gv0 = bb.emit(rx.op.add(x, y))
# current should be recovered
assert rx.BlockBuilder.current() is None
# second attempt to do it correctly.
with bb.function("func", [x, y]):
gv0 = bb.emit(rx.op.add(x, y))
bb.emit_func_output(gv0)
@pytest.mark.parametrize("emit_nested_tuple", [True, False])
def test_emit_nested_tuple(emit_nested_tuple):
"""Convert nested tuples when emitting relax"""
def make_function(emit_nested_tuple: bool):
bb = rx.BlockBuilder()
n_sym = tir.Var("n", "int64")
m_sym = tir.Var("m", "int64")
n = rx.Var("n", rx.PrimStructInfo(value=n_sym))
m = rx.Var("m", rx.PrimStructInfo(value=m_sym))
x = rx.Var("x", rx.TensorStructInfo([n_sym, m_sym], "float32"))
y = rx.Var("y", rx.TensorStructInfo([m_sym, n_sym], "float32"))
with bb.function("func", [n, m, x, y]):
scalars = (n, m)
if not emit_nested_tuple:
scalars = bb.emit(scalars)
output = (scalars, x, y)
bb.emit_func_output(output)
return bb.finalize()["func"]
def make_expected(emit_nested_tuple: bool):
if emit_nested_tuple:
@R.function
def func(
n_1: R.Prim(value="n"),
m_1: R.Prim(value="m"),
x: R.Tensor(("n", "m"), dtype="float32"),
y: R.Tensor(("m", "n"), dtype="float32"),
):
return ((n_1, m_1), x, y)
else:
@R.function
def func(
n_1: R.Prim(value="n"),
m_1: R.Prim(value="m"),
x: R.Tensor(("n", "m"), dtype="float32"),
y: R.Tensor(("m", "n"), dtype="float32"),
):
gv = n_1, m_1
return (gv, x, y)
return func
expected = make_expected(emit_nested_tuple)
actual = make_function(emit_nested_tuple)
tvm.ir.assert_structural_equal(expected, actual)
@pytest.mark.skip_well_formed_check_before_transform
def test_finalize_public_private_name_conflict():
# tir call
bb = rx.BlockBuilder()
def te_zero():
return topi.full((), "int64", tir.IntImm("int64", 0))
def te_one():
return topi.full((), "int64", tir.IntImm("int64", 1))
with bb.function("func", []):
gv0 = bb.emit_te(te_zero, primfunc_name_hint="func")
gv1 = bb.emit_te(te_one, primfunc_name_hint="func")
bb.emit_func_output((gv0, gv1))
mod = bb.get()
assert not rx.analysis.well_formed(mod)
mod_final = bb.finalize()
assert rx.analysis.well_formed(mod_final)
# relax function call
bb = rx.BlockBuilder()
with bb.function("func", [], private=True):
gvar = bb.emit_func_output(rx.const(0, "int64"))
with bb.function("func", [], private=True):
gv0 = bb.emit(rx.Call(gvar, []))
gvar1 = bb.emit_func_output(gv0)
with bb.function("func", []):
gv0 = bb.emit(rx.Call(gvar1, []))
bb.emit_func_output(gv0)
mod = bb.get()
assert not rx.analysis.well_formed(mod)
mod_final = bb.finalize()
assert rx.analysis.well_formed(mod_final)
def test_emit_nested_seqexpr_in_binding_block():
"""May emit a SeqExpr inside a BindingBlock"""
bb = rx.BlockBuilder()
with bb.function("func", []):
lhs = bb.emit(rx.const(1, "int64"), "a")
rhs = bb.emit(rx.const(2, "int64"), "b")
out = bb.emit(rx.op.add(lhs, rhs), "c")
bb.emit_func_output(out)
seq_expr = bb.finalize()["func"].body
bb = rx.BlockBuilder()
with bb.function("func", [], private=True):
lhs = bb.emit(rx.const(3, "int64"), "d")
rhs = bb.emit(seq_expr, "e")
out = bb.emit(rx.op.add(lhs, rhs), "f")
bb.emit_func_output(out)
output = bb.finalize()["func"]
@R.function(private=True)
def expected():
d = R.const(3, "int64")
a = R.const(1, "int64")
b = R.const(2, "int64")
c = R.add(a, b)
e = c
f = R.add(d, e)
return f
tvm.ir.assert_structural_equal(expected, output)
def test_emit_nested_dataflow_seqexpr_in_dataflow_block():
"""May emit a SeqExpr with dataflow inside a DataflowBlock"""
bb = rx.BlockBuilder()
with bb.function("func", []):
with bb.dataflow():
lhs = bb.emit(rx.const(1, "int64"), "a")
rhs = bb.emit(rx.const(2, "int64"), "b")
out = bb.emit_output(rx.op.add(lhs, rhs), "c")
bb.emit_func_output(out)
seq_expr = bb.finalize()["func"].body
bb = rx.BlockBuilder()
with bb.function("func", [], private=True):
with bb.dataflow():
lhs = bb.emit(rx.const(3, "int64"), "d")
rhs = bb.emit(seq_expr, "e")
out = bb.emit_output(rx.op.add(lhs, rhs), "f")
bb.emit_func_output(out)
output = bb.finalize()["func"]
@R.function(private=True)
def expected():
with R.dataflow():
d = R.const(3, "int64")
a = R.const(1, "int64")
b = R.const(2, "int64")
c = R.add(a, b)
e = c
f = R.add(d, e)
R.output(c, f)
return f
tvm.ir.assert_structural_equal(expected, output)
def test_emit_ill_formed_nested_seqexpr_in_dataflow_block():
"""May emit a SeqExpr inside a DataflowBlock
This produces ill-formed code, but cannot be caught at the
normalizer. See also
test_emit_well_formed_nested_seqexpr_in_dataflow_block.
"""
bb = rx.BlockBuilder()
with bb.function("func", []):
lhs = bb.emit(rx.const(1, "int64"), "a")
rhs = bb.emit(rx.const(2, "int64"), "b")
out = bb.emit(rx.op.add(lhs, rhs), "c")
bb.emit_func_output(out)
seq_expr = bb.finalize()["func"].body
bb = rx.BlockBuilder()
with bb.function("func", [], private=True):
with bb.dataflow():
lhs = bb.emit(rx.const(3, "int64"), "d")
# This would be ill-formed, as it requires breaking up the
# DataflowBlock with a BindingBlock.
rhs = bb.emit(seq_expr, "e")
# We cannot throw an error at that point, because it is
# only the later usage of "d" that results in use of a
# DataflowVar outside of its home DataflowBlock.
out = bb.emit_output(rx.op.add(lhs, rhs), "f")
bb.emit_func_output(out)
output = bb.finalize()["func"]
assert not rx.analysis.well_formed(tvm.ir.IRModule.from_expr(output))
def test_emit_well_formed_nested_seqexpr_in_dataflow_block():
"""May emit a SeqExpr inside a DataflowBlock
This produces well-formed code, and should not have any output
produced by the normalizer. See also
test_emit_ill_formed_nested_seqexpr_in_dataflow_block.
"""
bb = rx.BlockBuilder()
with bb.function("func", []):
lhs = bb.emit(rx.const(1, "int64"), "a")
rhs = bb.emit(rx.const(2, "int64"), "b")
out = bb.emit(rx.op.add(lhs, rhs), "c")
bb.emit_func_output(out)
seq_expr = bb.finalize()["func"].body
bb = rx.BlockBuilder()
with bb.function("func", [], private=True):
with bb.dataflow():
lhs = bb.emit(rx.const(3, "int64"), "d")
# This similarly breaks up the DataflowBlock, with
# identical steps as the previous test up until this
# point.
rhs = bb.emit(seq_expr, "e")
# But the "d" variable isn't used, and so there aren't any
# usages of DataflowVar outside of their home
# DataflowBlock.
out = bb.emit_output(rhs, "f")
bb.emit_func_output(out)
output = bb.finalize()["func"]
assert rx.analysis.well_formed(tvm.ir.IRModule.from_expr(output))
@R.function(private=True)
def expected() -> R.Tensor((), dtype="int64"):
with R.dataflow():
d = R.const(3, "int64")
R.output()
a = R.const(1, "int64")
b = R.const(2, "int64")
c = R.add(a, b)
with R.dataflow():
e = c
f = e
R.output(f)
return f
tvm.ir.assert_structural_equal(expected, output)
def test_error_when_unwrapping_dataflowvar():
"""Checks for ill-formed use of DataflowVar at normalization
We can check for some illegal unwrapping of SeqExpr, though. If
the inlined non-dataflow SeqExpr uses a DataflowVar, that should
trigger an error when the SeqExpr is being unwrapped.
"""
bb = rx.BlockBuilder()
lhs = rx.Var("a", rx.TensorStructInfo(shape=[], dtype="int64"))
with bb.function("func", [lhs]):
rhs = rx.const(2, "int64")
out = bb.emit(rx.op.add(lhs, rhs))
bb.emit_func_output(out)
func = bb.finalize()["func"]
bb = rx.BlockBuilder()
with bb.function("func", [], private=True):
with bb.dataflow():
local_lhs = bb.emit(rx.const(3, "int64"), "local_a")
rhs = bb.emit(func.bind_params({lhs: local_lhs}).body, "f")
out = bb.emit_output(rhs, "f")
with pytest.raises(tvm.TVMError, match="Malformed AST"):
bb.emit_func_output(out)
def test_deduplication_when_input_contains_duplicates():
"""De-duplication of IRModules
A well-formed IRModule may contain duplicate function definitions.
This is rare, as most functions can be disambiguated by the
function attribute `tvm::attr::kGlobalSymbol`. However, private
functions do not have this attribute, and a well-formed IRModule
may contain multiple copies of the same function.
This is a regression test. Previous implementation de-duplicated
using a `Dict[Function, GlobalVar]`, which has the failure mode
shown below. This was resolved by de-duplicating using a
`Dict[Function, Set[GlobalVar]]` instead.
"""
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor):
B = Module.subroutine_a(A)
C = Module.subroutine_b(B)
return C
@R.function(private=True)
def subroutine_a(arg: R.Tensor) -> R.Tensor:
return R.add(arg, arg)
@R.function(private=True)
def subroutine_b(arg: R.Tensor) -> R.Tensor:
return R.add(arg, arg)
@R.function(private=True)
def subroutine_c(arg: R.Tensor) -> R.Tensor:
return R.multiply(arg, arg)
# This test case is only valid when the two subroutines are
# structurally equal, and therefore allowed to be de-duplicated by
# the BlockBuilder.
tvm.ir.assert_structural_equal(Module["subroutine_a"], Module["subroutine_b"])
gvar_a = Module.get_global_var("subroutine_a")
gvar_b = Module.get_global_var("subroutine_b")
subroutine_c = Module["subroutine_c"]
bb = rx.BlockBuilder(Module)
# Add a function to the module. What we add doesn't matter, as
# this is only to initialize the de-duplication map.
bb.add_func(subroutine_c, "_unused")
# The deduplication table now maps `subroutine_ab` to either
# `gvar_a` or `gvar_b`.
# Update gvar_a.
bb.update_func(gvar_a, subroutine_c)
# The deduplication map no longer has an entry for
# `subroutine_ab`.
# Update gvar_b. The deduplication map is present (because we
# called `add_func`), but doesn't contain an entry for
# `subroutine_ab` (because it was just removed). This throws an
# error.
bb.update_func(gvar_b, subroutine_c)
# The tests here depend on tvmscript
import tvm
from tvm import te, tir
from tvm import relax as rx
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T
def test_emit_te_with_symbolic_arg():
bb = rx.BlockBuilder()
m = tir.Var("m", "int64")
x = rx.Var("x", R.Tensor([10], "float32"))
y = rx.Var("y", R.Shape([m]))
def te_func(A, offset):
return te.compute(A.shape, lambda i: A[i + offset], name="B")
with bb.function("main", [x, y]):
out = bb.emit_te(te_func, x, m)
bb.emit_func_output(out)
after = bb.get()
@I.ir_module
class Expected:
@T.prim_func(private=True)
def te_func(
A: T.Buffer((T.int64(10),), "float32"),
B: T.Buffer((T.int64(10),), "float32"),
m: T.int64,
):
T.func_attr({"tir.noalias": True})
for i in range(T.int64(10)):
with T.block("B"):
v_i = T.axis.spatial(T.int64(10), i)
T.writes(B[v_i])
B[v_i] = A[v_i + m]
@R.function
def main(
x: R.Tensor((10,), dtype="float32"), y: R.Shape(["m"])
) -> R.Tensor((10,), dtype="float32"):
m = T.int64()
cls = Expected
gv = R.call_tir(
cls.te_func,
(x,),
out_sinfo=R.Tensor((10,), dtype="float32"),
tir_vars=R.shape([m]),
)
return gv
assert_structural_equal(after, Expected)
def test_symbolic_shape_in_prim_value():
"""Symbolic vars may be provided to TE in R.Prim"""
def te_slice(tensor, i):
return tvm.te.compute([tensor.shape[1]], lambda j: tensor[i, j], name="slice")
def from_builder():
bb = rx.BlockBuilder()
A = rx.Var("A", R.Tensor([16, 16], "float32"))
tir_i = tvm.tir.Var("tir_i", "int64")
relax_i = rx.Var("relax_i", R.Prim(value=tir_i))
with bb.function("main", params=[A, relax_i]):
A_sliced = bb.emit_te(te_slice, A, relax_i)
bb.emit_func_output(A_sliced)
return bb.get()
@I.ir_module
class Expected:
@T.prim_func(private=True)
def te_slice(
A: T.Buffer([T.int64(16), T.int64(16)], "float32"),
Output: T.Buffer(T.int64(16), "float32"),
row_index: T.int64,
):
T.func_attr({"tir.noalias": True})
for i in range(A.shape[1]):
with T.block("slice"):
vi = T.axis.remap("S", [i])
Output[vi] = A[row_index, vi]
@R.function
def main(
A: R.Tensor([16, 16], "float32"),
arg_row_index: R.Prim(value="row_index"),
):
cls = Expected
row_index = T.int64()
gv = R.call_tir(
cls.te_slice,
A,
tir_vars=[row_index],
out_sinfo=R.Tensor([16], "float32"),
)
return gv
tvm.ir.assert_structural_equal(from_builder(), Expected)