Relax 构建块核心

import tvm
import tvm.contrib.cblas

from tvm import te, tir, topi
from tvm import relax as rx
from tvm.ir.base import assert_structural_equal

from tvm.script import ir as I, relax as R, tir as T
from tvm.tir.function import PrimFunc
import pytest

测试 BlockBuilder 在处理函数定义时的错误情况

验证当使用 bb.function 定义函数却没有显式传入参数列表时,BlockBuilder 是否会抛出 RuntimeError。注意:函数体中只调用了一个无参数的外部函数(test.blockbuilder.nop),并没有引用 m、n、x、y 这些变量。这是一种构建期的错误检查机制,用于确保函数定义的完整性和正确性。

# Register a no-op packed function so the check below can call it through
# ``ExternFunc``. ``override=True`` lets this cell be executed more than once
# (e.g. re-running a notebook) without failing because the global name is
# already registered.
@tvm.register_func("test.blockbuilder.nop", override=True)
def nop():
    """No-op packed function used only as an external call target."""
    ...


# Kept at module scope: later tests (test_no_func_params_fail) use ExternFunc.
from tvm.relax import ExternFunc

m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
bb = rx.BlockBuilder()

# The function scope is opened without a parameter list. Its body only calls
# the zero-argument extern function (m, n, x and y are not referenced), and
# the builder is expected to reject this with a RuntimeError.
with pytest.raises(RuntimeError):
    with bb.function("func"):
        gv0 = bb.emit(rx.Call(ExternFunc("test.blockbuilder.nop"), []))
        bb.emit_func_output(gv0)

简单测试

def test_block_builder():
    """Check that BlockBuilder yields the right block kind for each scope."""
    dim_m, dim_n = tir.Var("m", "int64"), tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([dim_m, dim_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([dim_n], "float16"))
    builder = rx.BlockBuilder()

    # Open the outer (plain) binding block and bind one value in it.
    builder._begin_binding_block()
    gv0 = builder.emit(rx.op.add(x, y))

    # Two consecutive dataflow blocks, each with one local and one output.
    dataflow_blocks = []
    for _ in range(2):
        builder._begin_dataflow_block()
        local = builder.emit(rx.op.multiply(gv0, y))
        builder.emit_output(rx.op.multiply(local, local))
        dataflow_blocks.append(builder._end_block())

    # Back in the outer block: one more binding, then close it.
    builder.emit(rx.op.add(x, y))
    outer_block = builder._end_block()

    first_df, second_df = dataflow_blocks
    assert isinstance(first_df, rx.DataflowBlock)
    assert isinstance(second_df, rx.DataflowBlock)
    assert not isinstance(outer_block, rx.DataflowBlock)

# Run the test directly when this file is executed as a script or notebook
# cell, but not when it is merely imported (e.g. during pytest collection,
# where pytest already calls test_block_builder itself).
if __name__ == "__main__":
    test_block_builder()
  1. 变量定义:

    • m 和 n 是 TIR(Tensor IR)变量,表示 int64 类型的符号维度

    • x 和 y 是 Relax 变量,分别表示形状为 [m, n] 和 [n] 的 float16 张量

    • bb 是 rx.BlockBuilder() 实例,用于构建 Relax 构建块。

  2. 块构建过程:

    • bb._begin_binding_block():开始绑定块(Binding Block)

    • gv0 = bb.emit(rx.op.add(x, y)):创建全局变量 gv0,表示 x 与 y 的加法运算

    • bb._begin_dataflow_block():开始数据流块(Dataflow Block)

    • lv0 = bb.emit(rx.op.multiply(gv0, y)):创建局部变量 lv0,表示 gv0 与 y 的乘法

    • gv1 = bb.emit_output(...):创建输出变量 gv1,表示 lv0 的平方

    • b0 = bb._end_block():结束当前块并返回块对象

    • 随后以相同方式创建第二个数据流块 b1;最后在外层绑定块中再 emit 出 gv3,并调用 _end_block() 结束最初开启的普通绑定块,得到 b2

测试目的:

  • 验证BlockBuilder能够正确创建和区分不同类型的块

  • 测试变量作用域管理(全局变量gv0、gv1、gv2、gv3和局部变量lv0、lv1)

  • 验证数据流块和普通绑定块的创建机制

  • 测试基本算子(add、multiply)的使用

关键概念:

  • Binding Block:普通绑定块,用于顺序执行的计算

  • Dataflow Block:数据流块,用于表示无副作用、无控制流的计算;块内通过 emit 创建的变量是局部的 DataflowVar,只有通过 emit_output 输出的变量才能在块外使用

  • BlockBuilder:Relax中用于构建计算图的核心工具

这段代码展示了Relax模块中构建计算图的基本模式,包括块的创建、变量的定义和操作的执行,同时验证了BlockBuilder的正确性。

为构建块中变量指定名称

def test_emit_with_name():
    """Explicit name hints given to emit/emit_output end up on the bound vars."""
    m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
    builder = rx.BlockBuilder()

    builder._begin_dataflow_block()
    summed = builder.emit(rx.op.add(x, y), "add")
    builder.emit_output(rx.op.multiply(summed, y), "multi")
    block = builder._end_block()

    first_binding, second_binding = block.bindings[0], block.bindings[1]
    assert first_binding.var.name_hint == "add"
    assert second_binding.var.name_hint == "multi"

# Run the test directly when this file is executed as a script or notebook
# cell, but not when it is merely imported (e.g. during pytest collection,
# where pytest already calls test_emit_with_name itself).
if __name__ == "__main__":
    test_emit_with_name()

测试 BlockBuilder 创建函数的功能

def test_function_single_block():
    """Build a function whose body is a single dataflow block and inspect it."""
    sym_m = tir.Var("m", "int64")
    sym_n = tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    builder = rx.BlockBuilder()

    with builder.function("func", [x, y]):
        with builder.dataflow():
            summed = builder.emit(rx.op.add(x, y))
            assert summed.name_hint == "lv"
            product = builder.emit(rx.op.multiply(summed, y))
            assert product.name_hint == "lv1"
            result = builder.emit_output(product)
        # The emitted output is usable after the dataflow scope closes.
        assert result.name_hint == "gv"
        builder.emit_func_output(result)

    func = builder.finalize()["func"]
    for index, param in enumerate((x, y)):
        assert func.params[index] == param
    assert func.body.body == result
    expected_sinfo = rx.TensorStructInfo([sym_m, sym_n], "float16")
    assert_structural_equal(result.struct_info, expected_sinfo)
    assert len(func.body.blocks) == 1
    assert len(func.body.blocks[0].bindings) == 3
  1. 函数创建与块构建:

    • with bb.function("func", [x, y]):创建名为"func"的函数,参数为x和y

    • with bb.dataflow():在函数内创建数据流块

    • lv0 = bb.emit(rx.op.add(x, y)):在数据流块中创建局部变量lv0,表示x和y的加法

    • lv1 = bb.emit(rx.op.multiply(lv0, y)):创建局部变量lv1,表示lv0和y的乘法

    • gv0 = bb.emit_output(lv1):创建全局变量gv0,表示数据流块的输出

  2. 验证逻辑:

    • 断言lv0的名称提示(name_hint)为"lv"

    • 断言lv1的名称提示为"lv1"

    • 断言gv0的名称提示为"gv"

    • func = bb.finalize()["func"]:获取最终的函数

    • 断言函数参数正确

    • 断言函数体结构正确

    • 断言gv0的结构信息正确

    • 断言函数体包含1个块,且该块包含3个绑定

测试目的:

  • 验证使用BlockBuilder创建函数的基本流程

  • 测试函数参数的设置和传递

  • 验证数据流块的创建和操作

  • 测试变量名称提示的自动生成

  • 验证函数输出的设置和获取

  • 确认函数结构和绑定数量的正确性

关键概念:

  • function:Relax中的函数定义,包含参数和函数体

  • dataflow:数据流块,用于表示无副作用、无控制流的计算,块内局部变量只有经 emit_output 输出后才能在块外使用

  • emit:向当前块中添加操作并返回生成的变量

  • emit_output:标记数据流块的输出变量

  • emit_func_output:设置函数的返回值

  • finalize:完成构建并返回创建的函数

这段代码全面测试了使用BlockBuilder创建包含单个数据流块的函数的各个环节,确保了函数创建、参数传递、块操作和输出设置等功能的正确性。

def test_function_multi_blocks():
    """A function body may alternate dataflow blocks and plain binding blocks."""
    sym_m = tir.Var("m", "int64")
    sym_n = tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    builder = rx.BlockBuilder()

    with builder.function("func", [x, y]):
        # Block 0: dataflow block with one local and one output.
        with builder.dataflow():
            local0 = builder.emit(rx.op.add(x, y))
            assert local0.name_hint == "lv"
            out0 = builder.emit_output(local0)
        assert out0.name_hint == "gv"

        # Block 1: a plain binding between the two dataflow blocks.
        out1 = builder.emit(rx.op.add(out0, out0))
        assert out1.name_hint == "gv1"

        # Block 2: second dataflow block.
        with builder.dataflow():
            local1 = builder.emit(rx.op.add(out1, out1))
            assert local1.name_hint == "lv1"
            out2 = builder.emit_output(out1)
        builder.emit_func_output(out2)

    func = builder.finalize()["func"]

    expected_sinfo = rx.TensorStructInfo([sym_m, sym_n], "float16")
    assert_structural_equal(out2.struct_info, expected_sinfo)
    assert func.params[0] == x
    assert func.params[1] == y
    assert func.body.body == out2
    blocks = func.body.blocks
    assert len(blocks) == 3
    for block, expected_count in zip(blocks, (2, 1, 2)):
        assert len(block.bindings) == expected_count
def test_multi_functions():
    """Two functions built by one BlockBuilder land in the same module."""
    builder = rx.BlockBuilder()

    def make_params():
        # Fresh symbolic dims and (x, y) params for each function.
        sym_m = tir.Var("m", "int64")
        sym_n = tir.Var("n", "int64")
        return (
            rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16")),
            rx.Var("y", rx.TensorStructInfo([sym_n], "float16")),
        )

    x_1, y_1 = make_params()
    with builder.function("func1", [x_1, y_1]):
        with builder.dataflow():
            local = builder.emit(rx.op.add(x_1, y_1))
            assert local.name_hint == "lv"
            out = builder.emit_output(local)
        builder.emit_func_output(out)

    x_2, y_2 = make_params()
    with builder.function("func2", [x_2, y_2]):
        with builder.dataflow():
            local = builder.emit(rx.op.add(y_2, x_2))
            # TODO(@yuchen): enable block builder to reset local var unique name map
            assert local.name_hint == "lv1"
            out = builder.emit_output(local)
        builder.emit_func_output(out)

    mod = builder.finalize()
    expected_params = (("func1", x_1, y_1), ("func2", x_2, y_2))
    for func_name, x_param, y_param in expected_params:
        func = mod[func_name]
        assert func.params[0] == x_param
        assert func.params[1] == y_param
        assert len(func.body.blocks) == 1

验证 Relax BlockBuilder 对二元算子(如加法、乘法)的形状与类型推导功能

def test_binary_shape_type_deduction():
    """Binary ops emitted through BlockBuilder get broadcast struct info."""
    sym_m, sym_n, sym_k = (tir.Var(name, "int64") for name in ("m", "n", "k"))
    x = rx.Var("x", rx.TensorStructInfo([sym_m, 1], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    z = rx.Var("z", rx.TensorStructInfo([5], "float16"))
    w = rx.Var("w", rx.TensorStructInfo([sym_k], "float16"))
    builder = rx.BlockBuilder()

    def check_rank1_fp16(var):
        # Only rank and dtype are checked for these results, not the shape.
        sinfo = var.struct_info
        assert isinstance(sinfo, rx.TensorStructInfo)
        assert sinfo.ndim == 1
        assert sinfo.dtype == "float16"

    with builder.function("func", [x, y, z, w]):
        with builder.dataflow():
            lv0 = builder.emit(rx.op.add(x, y))
            assert_structural_equal(
                lv0.struct_info, rx.TensorStructInfo([sym_m, sym_n], "float16")
            )

            lv1 = builder.emit(rx.op.multiply(x, z))
            assert_structural_equal(
                lv1.struct_info, rx.TensorStructInfo([sym_m, 5], "float16")
            )

            check_rank1_fp16(builder.emit(rx.op.multiply(z, w)))

            lv3 = builder.emit(rx.op.multiply(y, w))
            check_rank1_fp16(lv3)

            gv0 = builder.emit_output(lv3)
        builder.emit_func_output(gv0)

        check_rank1_fp16(gv0)

match_cast

match_cast 将一个值与给定的结构信息(StructInfo)进行匹配,并把其中的符号变量(如 m、n)绑定到实际维度上,从而得到一个结构信息更精确的新变量。

def test_emit_match_cast():
    """match_cast refines struct info and is recorded as MatchCast bindings.

    Covers casting a tensor of unknown rank to the symbolic shape ``(m, n)``,
    casting a concrete shape value to the symbolic ``Shape([m, n])`` with an
    explicit variable name, and inspecting the resulting bindings.
    """
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")
    x = rx.Var("tensor_value", rx.TensorStructInfo(dtype="float32", ndim=-1))
    y = rx.Var("shape_value", rx.ShapeStructInfo([16, 8]))
    bb = rx.BlockBuilder()

    with bb.function("func", [x, y]):
        with bb.dataflow():
            # lv0: R.Tensor((m, n), "float32") =
            #   match_cast(x: R.Tensor(ndim=-1, dtype="float32"), R.Tensor((m, n), "float32"))
            lv0 = bb.match_cast(x, rx.TensorStructInfo([m, n], "float32"))
            assert isinstance(lv0, rx.DataflowVar)
            assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32"))

            # var_name: R.Shape([m, n]) = match_cast(y: R.Shape([16, 8]), R.Shape([m, n]))
            lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]), "var_name")
            assert_structural_equal(lv1.struct_info, rx.ShapeStructInfo([m, n]))
            gv0 = bb.emit_output(lv1)

        bb.emit_func_output(gv0)
    func = bb.finalize()["func"]
    block = func.body.blocks[0]
    b0, b1 = block.bindings[:2]
    assert isinstance(b0, rx.MatchCast)
    assert isinstance(b1, rx.MatchCast)

    # Structural comparison for struct info (consistent with the checks above);
    # identity comparison for the bound value and variable.
    assert b0.value == x
    assert_structural_equal(b0.struct_info, rx.TensorStructInfo([m, n], "float32"))
    assert b0.var == lv0

    assert b1.value == y
    assert_structural_equal(b1.struct_info, rx.ShapeStructInfo([m, n]))
    assert b1.var == lv1
    assert b1.var.name_hint == "var_name"

该测试函数全面验证了 match_cast 操作的功能,包括:

  • 将张量从任意维度转换为指定的符号维度形状

  • 将固定形状转换为符号维度形状

  • 为转换后的变量指定名称

  • 验证转换操作的底层表示和属性

match_cast 操作在 Relax 中扮演着重要角色,它允许开发者在编写程序时明确指定变量的结构信息,同时为框架提供了进行静态类型检查和形状推断的能力。这些测试确保了 match_cast 操作能够正确处理不同类型的变量转换,并验证了转换结果的正确性。

emit_normalized

测试函数专注于验证 emit_normalized 方法在数据流块中处理 MatchCast 操作的能力。它确保:

  1. MatchCast 操作能够被正确地添加到数据流块中

  2. MatchCast 操作的属性(源值、结构信息、目标变量)被正确设置

  3. 数据流块能够正确管理和绑定 MatchCast 操作

emit_normalized 是 BlockBuilder 中的一个重要方法:它允许直接添加已经构造好的(规范化的)绑定,例如手动创建的 rx.MatchCast,而不必借助 match_cast 等辅助方法来生成。这个测试确保了该方法在数据流块中的正确行为,为开发者提供了更灵活的编程方式。

def test_emit_match_cast_binding_in_dataflow_block():
    """A hand-built MatchCast can be emitted with emit_normalized in a dataflow block."""
    builder = rx.BlockBuilder()

    x = rx.Var("x", rx.TensorStructInfo(dtype="float32", ndim=-1))
    sym_m = tir.Var("m", dtype="int64")
    gv = rx.Var("gv", rx.TensorStructInfo(dtype="float32", ndim=-1))
    binding = rx.MatchCast(gv, x, rx.TensorStructInfo((sym_m,), "float32"))

    with builder.function("main", [x]):
        with builder.dataflow():
            builder.emit_normalized(binding)
            builder.emit_output(gv)
        builder.emit_func_output(x)

    first = builder.finalize()["main"].body.blocks[0].bindings[0]
    assert isinstance(first, rx.MatchCast)

    assert first.value == x
    cast_sinfo = first.struct_info
    assert isinstance(cast_sinfo, rx.TensorStructInfo)
    assert cast_sinfo.shape[0] == sym_m
    assert first.var == gv

normalize

测试函数全面验证了 normalize 方法对不同类型节点的规范化能力,包括:

  1. Call 节点:验证操作节点在规范化后具有正确的形状信息

  2. Tuple 节点:验证元组节点在规范化后具有正确的结构信息

  3. 嵌套 Tuple 节点:验证嵌套元组在规范化后保持正确的层次结构信息

normalize 方法是 BlockBuilder 中的一个核心方法,它确保各种类型的节点在创建后具有正确的结构信息,这对于后续的静态分析、优化和代码生成至关重要。通过规范化,框架可以明确每个节点的类型和形状,从而进行更精确的处理。

def test_normalize():
    """normalize() fills in struct info for calls, tuples and nested tuples."""
    sym_m = tir.Var("m", "int64")
    sym_n = tir.Var("n", "int64")

    x = rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    builder = rx.BlockBuilder()

    # Call node.
    mul_call = rx.op.multiply(x, y)
    builder.normalize(mul_call)
    shape = rx.get_shape_of(mul_call)
    assert isinstance(shape, rx.ShapeExpr)
    assert shape[0] == sym_m
    assert shape[1] == sym_n

    # Flat tuple node.
    flat = rx.Tuple([x, y])
    builder.normalize(flat)
    flat_sinfo = flat.struct_info
    assert isinstance(flat_sinfo, rx.TupleStructInfo)
    for idx in (0, 1):
        assert isinstance(flat_sinfo.fields[idx], rx.TensorStructInfo)

    # Nested tuple node.
    nested = rx.Tuple([x, rx.Tuple([x, y])])
    builder.normalize(nested)
    outer = nested.struct_info
    assert isinstance(outer, rx.TupleStructInfo)
    assert isinstance(outer.fields[0], rx.TensorStructInfo)
    inner = outer.fields[1]
    assert isinstance(inner, rx.TupleStructInfo)
    for idx in (0, 1):
        assert isinstance(inner.fields[idx], rx.TensorStructInfo)

处理元组类型的结构信息

测试确保Relax框架能够正确处理元组类型的结构信息,包括:

  • 元组索引操作能正确继承元素的结构信息

  • 元组解包功能按预期工作

  • 当解包的变量个数与元组字段数不一致时,立即抛出 ValueError

这些功能对于静态类型分析、形状推导和代码正确性验证至关重要。

def test_tuple_indexing():
    """Indexing and unpacking a tuple-typed Var propagate the field struct info."""
    sym_m = tir.Var("m", "int64")
    sym_n = tir.Var("n", "int64")

    field_sinfos = [
        rx.TensorStructInfo([sym_m, sym_n], "float16"),
        rx.TensorStructInfo([sym_n], "float16"),
    ]
    relax_tuple = rx.Var("relax_tuple", rx.TupleStructInfo(field_sinfos))

    assert isinstance(relax_tuple.struct_info, rx.TupleStructInfo)
    for idx in range(2):
        assert isinstance(relax_tuple.struct_info.fields[idx], rx.TensorStructInfo)

    # Indexing builds TupleGetItem nodes whose struct info comes from the
    # TupleStructInfo, when one is present.
    items = [relax_tuple[idx] for idx in range(2)]
    for item, expected_sinfo in zip(items, field_sinfos):
        tvm.ir.assert_structural_equal(item.struct_info, expected_sinfo)

    # Unpacking yields the same TupleGetItem nodes as indexing.
    first_unpacked, second_unpacked = relax_tuple
    tvm.ir.assert_structural_equal(items[0], first_unpacked)
    tvm.ir.assert_structural_equal(items[1], second_unpacked)

    # With a known TupleStructInfo, unpacking into the wrong number of
    # targets fails right away.
    with pytest.raises(ValueError):
        first_unpacked, second_unpacked, third_unpacked = relax_tuple

调用 TE 函数

测试确保Relax框架能够正确集成和调用TE函数,这是TVM中连接Relax前端与底层张量计算的重要机制。测试验证了:

  • 位置参数和关键字参数的正确传递

  • TE计算与Relax函数的正确嵌套

  • 函数结构和绑定关系的正确性

这些功能对于将高级 Relax 程序转换为可执行的底层张量计算至关重要。

def test_call_te():
    """call_te forwards positional, dict and keyword arguments to the TE function."""
    builder = rx.BlockBuilder()
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    x, y, z = (rx.Var(name, rx.TensorStructInfo([n, m], "float32")) for name in "xyz")

    def te_func(args, args_dict, msg):
        lhs, rhs = args
        subtrahend = args_dict["C"]
        summed = te.compute((128, 128), lambda i, j: lhs[i, j] + rhs[i, j])
        return te.compute((128, 128), lambda i, j: summed[i, j] - subtrahend[i, j])

    with builder.function("rx_func", [x, y, z]):
        with builder.dataflow():
            te_call = builder.call_te(te_func, [x, y], {"C": z}, msg="hello")
            out = builder.emit_output(te_call)
        builder.emit_func_output(out)

    rx_func = builder.finalize()["rx_func"]

    for idx, param in enumerate((x, y, z)):
        assert rx_func.params[idx] == param
    assert rx_func.body.body == out
    assert len(rx_func.body.blocks) == 1
    assert len(rx_func.body.blocks[0].bindings) == 1

TE 函数的名称唯一性

测试确保Relax框架在调用TE函数时能够为不同的张量参数生成唯一的名称。名称唯一性对于以下方面至关重要:

  • 避免代码生成和优化过程中的命名冲突

  • 确保张量和缓冲区的正确识别和引用

  • 提高代码的可读性和可维护性

  • 保证底层计算图的正确性

这个测试特别关注矩阵乘法操作,验证了从参数到缓冲区的完整命名链都能保持唯一性。

def test_call_te_unique_tensor_name():
    """Generated TIR params, buffers and buffer data vars get distinct names."""
    builder = rx.BlockBuilder()
    lhs = rx.Var("x", R.Tensor((2, 3), "float32"))
    rhs = rx.Var("y", R.Tensor((3, 4), "float32"))
    with builder.function("main", [lhs, rhs]):
        builder.emit_func_output(builder.emit_te(topi.nn.matmul, lhs, rhs))

    f_matmul = builder.finalize()["matmul"]
    param_a, param_b = f_matmul.params[0], f_matmul.params[1]
    buf_a, buf_b = f_matmul.buffer_map[param_a], f_matmul.buffer_map[param_b]
    assert param_a.name != param_b.name
    assert buf_a.name != buf_b.name
    assert buf_a.data.name != buf_b.data.name

调用 TE 函数时能够正确验证参数类型

测试确保Relax框架在调用TE函数时能够正确验证参数类型,特别是形状参数。具体来说:

  • topi.reshape 期望形状参数是具体的元组(如 (200,)),而测试中传入的是结构信息为 ShapeStructInfo 的 Relax 变量 s

  • 测试验证了 call_te 能够检测到这种类型不匹配并抛出 AssertionError

  • 确保框架在早期阶段就能捕获类型错误,避免后续执行过程中出现更严重的问题

  • 维护了类型安全和API使用的正确性

这个测试是框架健壮性的重要保障,确保用户在使用API时能够得到明确的错误提示,而不是遇到隐藏的类型不兼容问题。

def test_call_te_with_unsupported_shape_arg():
    """A Relax shape Var passed where topi.reshape expects a concrete shape is rejected."""
    builder = rx.BlockBuilder()
    x = rx.Var("x", rx.TensorStructInfo((200,), "float32"))
    shape_var = rx.Var("s", rx.ShapeStructInfo((200,)))

    with pytest.raises(AssertionError):
        with builder.function("rx_func", [x]):
            reshaped = builder.emit(builder.call_te(topi.reshape, x, shape_var))
            builder.emit_func_output(reshaped)

emit_te 测试

def test_emit_te():
    """emit_te lowers a TE function to a PrimFunc and calls it from Relax."""
    builder = rx.BlockBuilder()
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
    y = rx.Var("y", rx.TensorStructInfo([n, m], "float32"))
    z = rx.Var("z", rx.TensorStructInfo([n, m], "float32"))

    def te_func(args, args_dict, msg):
        first, second = args
        third = args_dict["C"]
        summed = te.compute((128, 128), lambda i, j: first[i, j] + second[i, j])
        return te.compute((128, 128), lambda i, j: summed[i, j] - third[i, j])

    with builder.function("rx_func", [x, y, z]):
        out = builder.emit_te(te_func, [x, y], {"C": z}, msg="hello")
        builder.emit_func_output(out)

    mod = builder.finalize()
    rx_func = mod["rx_func"]

    def get_tir_func():
        # Lower the same TE function directly to obtain the reference PrimFunc.
        placeholders = [
            te.placeholder((n, m), dtype="float32", name=name) for name in ("A", "B", "C")
        ]
        result = te_func(placeholders[:2], {"C": placeholders[2]}, "")
        return tvm.te.create_prim_func(
            [*placeholders, result], index_dtype_override="int64"
        )

    # The generated TIR must match the directly lowered reference.
    assert_structural_equal(mod["te_func"].body, get_tir_func().body)

    # The Relax function calls the TIR function through a single call_tir binding.
    for idx, param in enumerate((x, y, z)):
        assert rx_func.params[idx] == param
    assert rx_func.body.body == out
    blocks = rx_func.body.blocks
    assert len(blocks) == 1
    assert len(blocks[0].bindings) == 1

    call_node = blocks[0].bindings[0].value
    assert isinstance(call_node, rx.Call)
    assert len(call_node.args) == 2
    assert call_node.args[0].name_hint == "te_func"
    tir_args = call_node.args[1]
    for idx, arg in enumerate((x, y, z)):
        assert tir_args[idx] == arg
def test_emit_te_multiple():
    """Structurally identical PrimFuncs produced by emit_te are deduplicated."""
    builder = rx.BlockBuilder()
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
    y = rx.Var("y", rx.TensorStructInfo([n, m], "float32"))
    z = rx.Var("z", rx.TensorStructInfo([128, m], "float32"))

    def te_func(A):
        return te.compute((128, 128), lambda i, j: A[i, j] + 1)

    with builder.function("rx_func", [x, y, z]):
        for arg in (x, y):
            builder.emit_te(te_func, arg)
        last = builder.emit_te(te_func, z)
        builder.emit_func_output(last)

    mod = builder.finalize()
    rx_func = mod["rx_func"]

    prim_funcs = [mod[gv] for gv in mod.get_global_vars() if isinstance(mod[gv], PrimFunc)]

    # Only two PrimFuncs exist: the ones generated for x and y are equal and
    # were deduplicated, so both calls refer to "te_func".
    assert len(prim_funcs) == 2
    bindings = rx_func.body.blocks[0].bindings
    for idx, callee in enumerate(("te_func", "te_func", "te_func1")):
        assert bindings[idx].value.args[0].name_hint == callee
def test_emit_te_multiple_output():
    """A TE function returning two tensors yields a tuple-typed call."""
    builder = rx.BlockBuilder()
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))

    def te_func(A):
        B0, B1 = te.compute((n, m), lambda i, j: (A[i, j] + 1, A[i, j] * 2), name="B")
        return (B0, B1)

    with builder.function("rx_func", [x]):
        outputs = builder.emit_te(te_func, x)
        builder.emit_func_output([outputs, rx.TupleGetItem(outputs, 0)])

    rx_func = builder.finalize()["rx_func"]

    # The call's output struct info is a 2-tuple whose fields carry ShapeExprs.
    assert rx_func.params[0] == x
    call_node = rx_func.body.blocks[0].bindings[0].value
    assert call_node.args[0].name_hint == "te_func"
    out_sinfo = call_node.sinfo_args[0]
    assert isinstance(out_sinfo, rx.TupleStructInfo)
    assert len(out_sinfo.fields) == 2
    for idx in (0, 1):
        assert isinstance(out_sinfo.fields[idx].shape, rx.ShapeExpr)
def test_emit_te_extern():
    """emit_te also works with extern TE ops such as cblas.matmul."""
    builder = rx.BlockBuilder()
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    lhs = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
    rhs = rx.Var("y", rx.TensorStructInfo([m, n], "float32"))

    with builder.function("rx_cblas_matmul", [lhs, rhs]):
        product = builder.emit_te(
            tvm.contrib.cblas.matmul, lhs, rhs, transa=False, transb=False
        )
        builder.emit_func_output(product)

    rx_func = builder.finalize()["rx_cblas_matmul"]

    # The Relax function calls the generated TIR function with call_tir.
    assert rx_func.params[0] == lhs
    assert rx_func.params[1] == rhs
    assert len(rx_func.body.blocks) == 1
    call_node = rx_func.body.blocks[0].bindings[0].value
    assert isinstance(call_node, rx.Call)
    assert len(call_node.args) == 2
    assert call_node.args[0].name_hint == "matmul"
    tir_args = call_node.args[1]
    assert tir_args[0] == lhs
    assert tir_args[1] == rhs
    # (n, m) x (m, n) -> (n, n)
    out_shape = call_node.sinfo_args[0].shape
    assert out_shape[0] == n
    assert out_shape[1] == n


def test_emit_te_prim_value():
    """emit_te accepts rx.PrimValue arguments (here the clip bounds)."""
    builder = rx.BlockBuilder()
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    x = rx.Var("x", R.Tensor([n, m], "float32"))
    bounds = (rx.PrimValue(0), rx.PrimValue(6))

    with builder.function("rx_clip", [x]):
        clipped = builder.emit_te(topi.clip, x, *bounds)
        builder.emit_func_output(clipped)

    rx_func = builder.finalize()["rx_clip"]

    # The Relax function calls the generated TIR function with call_tir.
    assert rx_func.params[0] == x
    assert len(rx_func.body.blocks) == 1
    call_node = rx_func.body.blocks[0].bindings[0].value
    assert isinstance(call_node, rx.Call)
    assert len(call_node.args) == 2
    assert call_node.args[1][0] == x


def test_nested_function_fail():
    """Nesting builder.function scopes is rejected with RuntimeError."""
    sym_m = tir.Var("m", "int64")
    sym_n = tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    builder = rx.BlockBuilder()

    with pytest.raises(RuntimeError):
        with builder.function("func", [x, y]):
            outer_result = builder.emit(rx.op.add(x, x))
            with builder.function("func1", [x, y]):
                builder.emit(rx.op.add(x, x))
            builder.emit_func_output(outer_result)


def test_emit_func_output_twice_fail():
    """Calling emit_func_output twice in one function scope raises RuntimeError."""
    sym_m, sym_n = tir.Var("m", "int64"), tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    builder = rx.BlockBuilder()

    with pytest.raises(RuntimeError):
        with builder.function("func", [x, y]):
            result = builder.emit(rx.op.add(x, y))
            for _ in range(2):
                builder.emit_func_output(result)


def test_func_params_twice_fail():
    """Giving params to both function() and emit_func_output() raises RuntimeError."""
    sym_m, sym_n = tir.Var("m", "int64"), tir.Var("n", "int64")
    x = rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    builder = rx.BlockBuilder()

    with pytest.raises(RuntimeError):
        with builder.function("func", [x, y]):
            result = builder.emit(rx.op.add(x, y))
            builder.emit_func_output(result, [x])


def test_no_func_params_fail():
    """A function scope opened without a parameter list raises RuntimeError."""
    sym_m, sym_n = tir.Var("m", "int64"), tir.Var("n", "int64")
    # x and y are built but deliberately not passed to builder.function.
    x = rx.Var("x", rx.TensorStructInfo([sym_m, sym_n], "float16"))
    y = rx.Var("y", rx.TensorStructInfo([sym_n], "float16"))
    builder = rx.BlockBuilder()

    with pytest.raises(RuntimeError):
        with builder.function("func"):
            nop_call = rx.Call(ExternFunc("test.blockbuilder.nop"), [])
            builder.emit_func_output(builder.emit(nop_call))


def test_block_builder_scope_recovery():
    """A failed function scope must not leave the builder registered as current."""
    builder = rx.BlockBuilder()

    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    x = rx.Var("x", rx.TensorStructInfo([n, m], "float32"))
    y = rx.Var("y", rx.TensorStructInfo([m, n], "float32"))

    # First attempt: the scope is left without emit_func_output and fails.
    with pytest.raises(RuntimeError):
        with builder.function("func", [x, y]):
            builder.emit(rx.op.add(x, y))

    # The current-builder state must have been restored.
    assert rx.BlockBuilder.current() is None

    # Second attempt on the same builder, done correctly this time.
    with builder.function("func", [x, y]):
        builder.emit_func_output(builder.emit(rx.op.add(x, y)))


@pytest.mark.parametrize("emit_nested_tuple", [True, False])
def test_emit_nested_tuple(emit_nested_tuple):
    """Convert nested tuples when emitting relax

    With ``emit_nested_tuple=True`` the inner ``(n, m)`` tuple stays inline in
    the returned tuple; with ``False`` it is first bound to its own variable
    through ``bb.emit``. Each variant is compared structurally against the
    matching hand-written TVMScript function.
    """

    def make_function(emit_nested_tuple: bool):
        # Build the function under test with BlockBuilder.
        bb = rx.BlockBuilder()

        # Symbolic sizes, exposed both as R.Prim params (n, m) and as the
        # tensor dimensions of x and y.
        n_sym = tir.Var("n", "int64")
        m_sym = tir.Var("m", "int64")
        n = rx.Var("n", rx.PrimStructInfo(value=n_sym))
        m = rx.Var("m", rx.PrimStructInfo(value=m_sym))
        x = rx.Var("x", rx.TensorStructInfo([n_sym, m_sym], "float32"))
        y = rx.Var("y", rx.TensorStructInfo([m_sym, n_sym], "float32"))

        with bb.function("func", [n, m, x, y]):
            scalars = (n, m)
            if not emit_nested_tuple:
                # Bind the inner tuple to its own variable ("gv" in the expected IR).
                scalars = bb.emit(scalars)
            output = (scalars, x, y)
            bb.emit_func_output(output)

        return bb.finalize()["func"]

    def make_expected(emit_nested_tuple: bool):
        # Expected IR for each variant, written in TVMScript.
        if emit_nested_tuple:

            @R.function
            def func(
                n_1: R.Prim(value="n"),
                m_1: R.Prim(value="m"),
                x: R.Tensor(("n", "m"), dtype="float32"),
                y: R.Tensor(("m", "n"), dtype="float32"),
            ):
                return ((n_1, m_1), x, y)

        else:

            @R.function
            def func(
                n_1: R.Prim(value="n"),
                m_1: R.Prim(value="m"),
                x: R.Tensor(("n", "m"), dtype="float32"),
                y: R.Tensor(("m", "n"), dtype="float32"),
            ):
                gv = n_1, m_1
                return (gv, x, y)

        return func

    expected = make_expected(emit_nested_tuple)
    actual = make_function(emit_nested_tuple)

    tvm.ir.assert_structural_equal(expected, actual)


@pytest.mark.skip_well_formed_check_before_transform
def test_finalize_public_private_name_conflict():
    """Only the finalized module is well-formed when function names conflict.

    Two scenarios are covered: PrimFuncs generated by emit_te with the same
    ``primfunc_name_hint`` ("func") as the Relax function, and several Relax
    functions (private and public) all named "func". In both cases the module
    returned by ``bb.get()`` is not well-formed, while the one returned by
    ``bb.finalize()`` is.
    """
    # NOTE(review): the marker presumably disables an automatic
    # well-formedness check, since the module is deliberately ill-formed
    # before finalize() -- confirm against the conftest that defines it.

    # tir call
    bb = rx.BlockBuilder()

    def te_zero():
        return topi.full((), "int64", tir.IntImm("int64", 0))

    def te_one():
        return topi.full((), "int64", tir.IntImm("int64", 1))

    with bb.function("func", []):
        gv0 = bb.emit_te(te_zero, primfunc_name_hint="func")
        gv1 = bb.emit_te(te_one, primfunc_name_hint="func")
        bb.emit_func_output((gv0, gv1))

    mod = bb.get()
    assert not rx.analysis.well_formed(mod)
    mod_final = bb.finalize()
    assert rx.analysis.well_formed(mod_final)

    # relax function call
    # Three Relax functions all named "func"; each later one calls the
    # GlobalVar returned by emit_func_output for the previous one.
    bb = rx.BlockBuilder()

    with bb.function("func", [], private=True):
        gvar = bb.emit_func_output(rx.const(0, "int64"))

    with bb.function("func", [], private=True):
        gv0 = bb.emit(rx.Call(gvar, []))
        gvar1 = bb.emit_func_output(gv0)

    with bb.function("func", []):
        gv0 = bb.emit(rx.Call(gvar1, []))
        bb.emit_func_output(gv0)

    mod = bb.get()
    assert not rx.analysis.well_formed(mod)
    mod_final = bb.finalize()
    assert rx.analysis.well_formed(mod_final)


def test_emit_nested_seqexpr_in_binding_block():
    """May emit a SeqExpr inside a BindingBlock

    The body of a first function (bindings a, b, c) is emitted as the value
    of ``e`` in a second function. As the expected TVMScript shows, the inner
    bindings are inlined into the outer block ahead of ``e``, and ``e`` is
    bound to the inner result ``c``.
    """

    bb = rx.BlockBuilder()

    # Build a helper function and keep only its SeqExpr body.
    with bb.function("func", []):
        lhs = bb.emit(rx.const(1, "int64"), "a")
        rhs = bb.emit(rx.const(2, "int64"), "b")
        out = bb.emit(rx.op.add(lhs, rhs), "c")
        bb.emit_func_output(out)

    seq_expr = bb.finalize()["func"].body

    # Emit that SeqExpr as an ordinary binding inside a new function.
    bb = rx.BlockBuilder()
    with bb.function("func", [], private=True):
        lhs = bb.emit(rx.const(3, "int64"), "d")
        rhs = bb.emit(seq_expr, "e")
        out = bb.emit(rx.op.add(lhs, rhs), "f")
        bb.emit_func_output(out)

    output = bb.finalize()["func"]

    @R.function(private=True)
    def expected():
        d = R.const(3, "int64")
        a = R.const(1, "int64")
        b = R.const(2, "int64")
        c = R.add(a, b)
        e = c
        f = R.add(d, e)
        return f

    tvm.ir.assert_structural_equal(expected, output)


def test_emit_nested_dataflow_seqexpr_in_dataflow_block():
    """May emit a SeqExpr with dataflow inside a DataflowBlock

    The inner function's dataflow bindings (a, b, c) are inlined into the
    outer DataflowBlock. As the expected TVMScript shows, the inner output
    ``c`` also becomes an output of the combined block (``R.output(c, f)``).
    """
    bb = rx.BlockBuilder()

    # Build a helper function whose body is a single dataflow block.
    with bb.function("func", []):
        with bb.dataflow():
            lhs = bb.emit(rx.const(1, "int64"), "a")
            rhs = bb.emit(rx.const(2, "int64"), "b")
            out = bb.emit_output(rx.op.add(lhs, rhs), "c")
        bb.emit_func_output(out)

    seq_expr = bb.finalize()["func"].body

    # Emit that SeqExpr from within another function's dataflow block.
    bb = rx.BlockBuilder()
    with bb.function("func", [], private=True):
        with bb.dataflow():
            lhs = bb.emit(rx.const(3, "int64"), "d")
            rhs = bb.emit(seq_expr, "e")
            out = bb.emit_output(rx.op.add(lhs, rhs), "f")
        bb.emit_func_output(out)

    output = bb.finalize()["func"]

    @R.function(private=True)
    def expected():
        with R.dataflow():
            d = R.const(3, "int64")
            a = R.const(1, "int64")
            b = R.const(2, "int64")
            c = R.add(a, b)
            e = c
            f = R.add(d, e)
            R.output(c, f)
        return f

    tvm.ir.assert_structural_equal(expected, output)


def test_emit_ill_formed_nested_seqexpr_in_dataflow_block():
    """May emit a SeqExpr inside a DataflowBlock

    The result is ill-formed, but the normalizer cannot detect it at the
    point of emission; only the well-formedness analysis afterwards does.
    See also test_emit_well_formed_nested_seqexpr_in_dataflow_block.
    """
    source_builder = rx.BlockBuilder()

    # A plain (non-dataflow) function body: a, b, c = a + b.
    with source_builder.function("func", []):
        first = source_builder.emit(rx.const(1, "int64"), "a")
        second = source_builder.emit(rx.const(2, "int64"), "b")
        source_builder.emit_func_output(source_builder.emit(rx.op.add(first, second), "c"))

    seq_expr = source_builder.finalize()["func"].body

    builder = rx.BlockBuilder()
    with builder.function("func", [], private=True):
        with builder.dataflow():
            d_var = builder.emit(rx.const(3, "int64"), "d")
            # Inlining the non-dataflow SeqExpr forces the DataflowBlock to be
            # split around a plain BindingBlock.
            e_var = builder.emit(seq_expr, "e")
            # Using "d" here is what makes the result ill-formed: the
            # DataflowVar ends up referenced outside its home DataflowBlock.
            # No error can be raised at the emit of "e" itself.
            f_var = builder.emit_output(rx.op.add(d_var, e_var), "f")
        builder.emit_func_output(f_var)

    output = builder.finalize()["func"]

    assert not rx.analysis.well_formed(tvm.ir.IRModule.from_expr(output))


def test_emit_well_formed_nested_seqexpr_in_dataflow_block():
    """A SeqExpr emitted inside a DataflowBlock can still yield well-formed IR.

    The setup matches test_emit_ill_formed_nested_seqexpr_in_dataflow_block
    up to the point where the SeqExpr is emitted.  Here, however, the
    DataflowVar "d" is never referenced afterwards.  Splitting the
    DataflowBlock therefore leaves no DataflowVar used outside its own
    block, and the normalized function is both well-formed and equal to
    the hand-written expectation below.
    """
    # Build a plain (non-dataflow) function and keep only its body,
    # which is a SeqExpr binding "a", "b" and "c".
    inner_builder = rx.BlockBuilder()
    with inner_builder.function("func", []):
        first = inner_builder.emit(rx.const(1, "int64"), "a")
        second = inner_builder.emit(rx.const(2, "int64"), "b")
        total = inner_builder.emit(rx.op.add(first, second), "c")
        inner_builder.emit_func_output(total)
    nested_body = inner_builder.finalize()["func"].body

    outer_builder = rx.BlockBuilder()
    with outer_builder.function("func", [], private=True):
        with outer_builder.dataflow():
            # "d" is bound but intentionally never used again.
            outer_builder.emit(rx.const(3, "int64"), "d")
            # Same split of the DataflowBlock as in the ill-formed variant.
            inlined = outer_builder.emit(nested_body, "e")
            # Only "e" is consumed, so no DataflowVar escapes its block.
            result = outer_builder.emit_output(inlined, "f")
        outer_builder.emit_func_output(result)

    actual = outer_builder.finalize()["func"]

    assert rx.analysis.well_formed(tvm.ir.IRModule.from_expr(actual))

    @R.function(private=True)
    def expected() -> R.Tensor((), dtype="int64"):
        with R.dataflow():
            d = R.const(3, "int64")
            R.output()
        a = R.const(1, "int64")
        b = R.const(2, "int64")
        c = R.add(a, b)
        with R.dataflow():
            e = c
            f = e
            R.output(f)
        return f

    tvm.ir.assert_structural_equal(expected, actual)


def test_error_when_unwrapping_dataflowvar():
    """Unwrapping a SeqExpr that reads a DataflowVar is rejected eagerly.

    When a non-dataflow SeqExpr is inlined into a DataflowBlock and that
    SeqExpr refers to a DataflowVar, finishing the enclosing function must
    raise a TVMError mentioning "Malformed AST".
    """
    param_a = rx.Var("a", rx.TensorStructInfo(shape=[], dtype="int64"))

    # Source function: returns a + 2, with "a" as its only parameter.
    source_builder = rx.BlockBuilder()
    with source_builder.function("func", [param_a]):
        summed = source_builder.emit(rx.op.add(param_a, rx.const(2, "int64")))
        source_builder.emit_func_output(summed)
    source_func = source_builder.finalize()["func"]

    target_builder = rx.BlockBuilder()
    with target_builder.function("func", [], private=True):
        with target_builder.dataflow():
            # Emitted inside the dataflow block, so "local_a" is a
            # DataflowVar.  Substituting it for the parameter makes the
            # inlined non-dataflow body reference that DataflowVar.
            local_a = target_builder.emit(rx.const(3, "int64"), "local_a")
            inlined_body = source_func.bind_params({param_a: local_a}).body
            inlined = target_builder.emit(inlined_body, "f")
            result = target_builder.emit_output(inlined, "f")

        with pytest.raises(tvm.TVMError, match="Malformed AST"):
            target_builder.emit_func_output(result)


def test_deduplication_when_input_contains_duplicates():
    """De-duplication of IRModules

    A well-formed IRModule may contain duplicate function definitions.
    This is rare, as most functions can be disambiguated by the
    function attribute `tvm::attr::kGlobalSymbol`.  However, private
    functions do not have this attribute, and a well-formed IRModule
    may contain multiple copies of the same function.

    This is a regression test.  The previous implementation
    de-duplicated using a `Dict[Function, GlobalVar]`, which has the
    failure mode shown below.  This was resolved by de-duplicating
    using a `Dict[Function, Set[GlobalVar]]` instead.

    """

    @I.ir_module
    class Module:
        @R.function
        def main(A: R.Tensor):
            B = Module.subroutine_a(A)
            C = Module.subroutine_b(B)
            return C

        @R.function(private=True)
        def subroutine_a(arg: R.Tensor) -> R.Tensor:
            return R.add(arg, arg)

        @R.function(private=True)
        def subroutine_b(arg: R.Tensor) -> R.Tensor:
            return R.add(arg, arg)

        @R.function(private=True)
        def subroutine_c(arg: R.Tensor) -> R.Tensor:
            return R.multiply(arg, arg)

    # This test case is only valid when the two subroutines are
    # structurally equal, and therefore allowed to be de-duplicated by
    # the BlockBuilder.
    tvm.ir.assert_structural_equal(Module["subroutine_a"], Module["subroutine_b"])

    gvar_a = Module.get_global_var("subroutine_a")
    gvar_b = Module.get_global_var("subroutine_b")
    subroutine_c = Module["subroutine_c"]

    bb = rx.BlockBuilder(Module)

    # Add a function to the module.  What we add doesn't matter, as
    # this is only to initialize the de-duplication map.
    bb.add_func(subroutine_c, "_unused")
    # The shared definition of `subroutine_a` / `subroutine_b`
    # (called `subroutine_ab` below) is now a key of the
    # de-duplication map.  The old implementation mapped it to only
    # one of `gvar_a` or `gvar_b`; the fixed one maps it to both.

    # Update gvar_a.
    bb.update_func(gvar_a, subroutine_c)
    # With the old implementation, this removed the entry for
    # `subroutine_ab` entirely, even though `gvar_b` still uses it.
    # Now only `gvar_a` is removed from that entry.

    # Update gvar_b.  With the old implementation, the de-duplication
    # map no longer had an entry for `subroutine_ab` at this point, so
    # this call raised an error.  It must now succeed.
    bb.update_func(gvar_b, subroutine_c)

    # Both updates must be reflected in the builder's module.
    updated = bb.get()
    tvm.ir.assert_structural_equal(updated[gvar_a], subroutine_c)
    tvm.ir.assert_structural_equal(updated[gvar_b], subroutine_c)
# The tests here depend on tvmscript
import tvm
from tvm import te, tir
from tvm import relax as rx
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T


def test_emit_te_with_symbolic_arg():
    """A symbolic TIR scalar passed to emit_te is forwarded via `tir_vars`.

    The generated PrimFunc gains an extra `m: T.int64` parameter, and the
    Relax call site supplies it through `tir_vars=R.shape([m])`.
    """
    builder = rx.BlockBuilder()
    sym_m = tir.Var("m", "int64")
    data = rx.Var("x", R.Tensor([10], "float32"))
    shape_arg = rx.Var("y", R.Shape([sym_m]))

    # The function name becomes the PrimFunc name, so it must stay "te_func".
    def te_func(tensor, shift):
        return te.compute(tensor.shape, lambda i: tensor[i + shift], name="B")

    with builder.function("main", [data, shape_arg]):
        builder.emit_func_output(builder.emit_te(te_func, data, sym_m))

    built = builder.get()

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def te_func(
            A: T.Buffer((T.int64(10),), "float32"),
            B: T.Buffer((T.int64(10),), "float32"),
            m: T.int64,
        ):
            T.func_attr({"tir.noalias": True})
            for i in range(T.int64(10)):
                with T.block("B"):
                    v_i = T.axis.spatial(T.int64(10), i)
                    T.writes(B[v_i])
                    B[v_i] = A[v_i + m]

        @R.function
        def main(
            x: R.Tensor((10,), dtype="float32"), y: R.Shape(["m"])
        ) -> R.Tensor((10,), dtype="float32"):
            m = T.int64()
            cls = Expected
            gv = R.call_tir(
                cls.te_func,
                (x,),
                out_sinfo=R.Tensor((10,), dtype="float32"),
                tir_vars=R.shape([m]),
            )
            return gv

    assert_structural_equal(built, Expected)


def test_symbolic_shape_in_prim_value():
    """A symbolic var carried by an R.Prim argument is usable inside TE.

    The TIR var bound to the `R.Prim` parameter becomes a scalar
    parameter of the generated PrimFunc and is passed via `tir_vars`.
    """

    # The function name becomes the PrimFunc name, so it must stay "te_slice".
    def te_slice(source, row):
        return tvm.te.compute([source.shape[1]], lambda j: source[row, j], name="slice")

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def te_slice(
            A: T.Buffer([T.int64(16), T.int64(16)], "float32"),
            Output: T.Buffer(T.int64(16), "float32"),
            row_index: T.int64,
        ):
            T.func_attr({"tir.noalias": True})

            for i in range(A.shape[1]):
                with T.block("slice"):
                    vi = T.axis.remap("S", [i])
                    Output[vi] = A[row_index, vi]

        @R.function
        def main(
            A: R.Tensor([16, 16], "float32"),
            arg_row_index: R.Prim(value="row_index"),
        ):
            cls = Expected

            row_index = T.int64()

            gv = R.call_tir(
                cls.te_slice,
                A,
                tir_vars=[row_index],
                out_sinfo=R.Tensor([16], "float32"),
            )
            return gv

    # Build the module through the BlockBuilder, after Expected is parsed.
    builder = rx.BlockBuilder()
    matrix = rx.Var("A", R.Tensor([16, 16], "float32"))
    tir_i = tvm.tir.Var("tir_i", "int64")
    prim_i = rx.Var("relax_i", R.Prim(value=tir_i))

    with builder.function("main", params=[matrix, prim_i]):
        sliced = builder.emit_te(te_slice, matrix, prim_i)
        builder.emit_func_output(sliced)

    tvm.ir.assert_structural_equal(builder.get(), Expected)