基于模式的使用规则进行重写#

参考:rewrite_patterns

ONNX 重写工具为用户提供了一个功能,可以根据用户提供的重写规则,将 ONNX 计算图中的某些模式替换为另一种模式。

使用方法#

在计算图重写模式时,需要三个主要部分:

  • target_pattern:要匹配的原始模式。这个模式使用类似 ONNXScript 的算子编写函数。

  • replacement_pattern:用于替换原始模式的模式。这个模式也使用类似 ONNXScript 的算子编写函数。

  • match_condition(可选):只有满足匹配条件时,才会进行模式重写。

简单的例子#

一个简单示例,演示了如何使用 GELU 激活函数的此功能:

可以使用给定公式中的高斯误差函数来计算 GELU 激活函数:

\[ \text{GELU} = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})] \]
from onnxscript.rewriter import pattern
from onnxscript import ir

使用 onnxscript 算子创建需要替换的目标模式。

import math
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))

之后,创建由 GELU onnxscript 算子组成的替换模式。

def gelu(op, x: ir.Value):
    return op.Gelu(x, _domain="com.microsoft")

备注

替换模式的输入是 ir.Value 类型。

在这个例子中,我们不需要 match_condition,所以暂时跳过这个选项。然后使用 RewriteRule 函数创建重写规则。

rule = pattern.RewriteRule(
    erf_gelu_pattern,  # Target Pattern
    gelu,  # Replacement Pattern
)

现在重写规则已经创建,下一步是应用这些基于模式的重写规则。rewriter.rewrite 调用包含三个主要部分:

  • model:要应用模式重写规则的原始模型。这是 onnx.ModelProto 类型。

  • function_rewrite_rules:(可选)此参数用于传递基于函数名称的重写规则。如何使用此参数的步骤将在另一个教程中介绍。此参数是 Sequence[type[FunctionRewriteRule]] 类型。

  • pattern_rewrite_rules:(可选)此参数用于传递基于提供的替换模式的重写规则。在本教程中,我们将仅使用此参数与model结合。此参数是以下类型之一:

    • Sequence[PatternRewriteRule]

    • RewriteRuleSet

备注

pattern_rewrite_rules 接受 PatternRewriteRule 类型的序列,或者 RewriteRuleSet,后者本质上也是使用 PatternRewriteRule 类型的序列创建的规则集。因此,如果要传递单个重写规则,需要将其作为序列的一部分传递。有关如何创建和使用规则集的步骤,请参阅“使用不同模式创建规则集”部分中的示例。

下面的代码片段演示了如何使用 rewriter.rewrite 调用上述创建的重写规则:

def apply_rewrite(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=[rule],
    )
    return model_with_rewrite_applied

利用 commute 参数进行模式匹配#

使用不同模式创建规则集#

此方法需要创建两个单独的规则,并将它们打包成 PatternRewriteRules 的序列或 RewriteRuleSet。创建 RewriteRuleSet 是首选选项,但两者都可以使用。为了创建一个包含多个规则(例如 rule1rule2)的 RewriteRuleSet

from onnxscript.rewriter import pattern
rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])

为了将此方法应用于上述示例,首先创建两个单独的目标模式,如下所示:

def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))

def erf_gelu_pattern_2(op, x):
    return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5

然后,为每个目标模式创建两个单独的 PatternRewriteRules。将这些规则打包成一个 RewriteRuleSet 对象,并通过传递创建的 RewriteRuleSet 作为 pattern_rewrite_rules 参数来应用重写。

def apply_rewrite_with_ruleset(model):
    # Create multiple rules
    rule1 = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    rule2 = pattern.RewriteRule(
        erf_gelu_pattern_2,  # Target Pattern
        gelu,  # Replacement
    )
    # Create a Rewrite Rule Set with multiple rules.
    rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
        # pattern_rewrite_rules=[rule1, rule2], # Alternative method of passing multiple rules
    )
    return model_with_rewrite_applied

在创建规则时使用 commute 参数#

为相似模式创建多个目标模式可能会很繁琐。为了避免这种情况,可以在创建 RewriteRuleSet 时利用 commute 参数。只需设置 commute=True,即可避免为因交换性而不同的模式创建多个目标模式。满足交换性属性的不同模式的多个规则会自动打包成 RewriteRuleSet 对象。然后通过传递创建的 RewriteRuleSet 作为 pattern_rewrite_rules 参数来应用重写。

def apply_rewrite_with_commute(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    # Create a Rewrite Rule Set with commute=True
    rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True)
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite_applied

使用 match_condition 参数进行模式匹配#

本节将讨论如何利用 match_condition 参数。match_condition 参数检查模式是否在考虑某些约束的情况下与目标模式匹配。

基于 ONNX Matmul 规范,onnx Matmul的行为类似于 numpy.matmul,并且也遵循 numpy 广播。因此,在这个特定模式中,如果 matmul 广播满足,那么我们不需要 reshapes。为了验证这一点,我们需要检查以下内容:

  • 输入形状检查:input_ainput_b 应该是可广播的

  • 输出形状检查:shape_c 应该与从 matmul(input_a, input_b) 得到的输出形状相同

如果上述情况为真,那么我们不需要 reshapes,可以使用基于模式的重写来消除它们。

首先,以类似于第一个示例的方式编写目标模式和替换模式。

def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
    reshape_a = op.Reshape(input_a, shape_a)
    reshape_b = op.Reshape(input_b, shape_b)
    matmul = op.MatMul(reshape_a, reshape_b)
    return op.Reshape(matmul, shape_c)

def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
    return op.MatMul(input_a, input_b)

备注

在这种情况下,目标模式有5个输入:input_ainput_bshape_ashape_bshape_c。然而,替换模式仅利用了 input_ainput_b。为了避免在替换模式签名中引用所有未使用的参数,只传递 input_ainput_b,并使用 **_ 来表示所有未使用的参数。

同样,在编写条件检查函数时,我们只需要 input_ainput_bshape_c。在条件匹配函数签名中使用 **_ 来表示所有未使用的参数。

为了验证 matmul 广播是否满足,我们编写一个条件检查函数,如下所示:

def check_if_not_need_reshape(
    context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
    """Condition to check if we need to replace the pattern.

    If matmul broadcasting is enough, then we don't need the reshapes.

    To validate this, we need to check the following:
    1. Input shapes check: input_a and input_b should be broadcastable
    2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)

    If the above are true, then we don't need the reshapes.

    Returns:
        True if we need to replace the pattern, False otherwise.
    """
    del context  # Reserved for future extensions

    input_a_shape = input_a.shape
    input_b_shape = input_b.shape
    # TODO: Get a helper func to get const_value
    _ir_utils.propagate_const_value(shape_c)
    shape_c_tensor = shape_c.const_value
    if shape_c_tensor is None:
        logger.info("The value 'shape_c' is not statically known.")
        return False

    if len(shape_c_tensor.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c_tensor.shape,
        )
        return False

    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
    # information. So, we need to check if the shape is None and return False.
    if input_a_shape is None or input_b_shape is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False
    input_a_shape = input_a_shape.numpy()
    input_b_shape = input_b_shape.numpy()
    shape_c = shape_c_tensor.numpy().tolist()

    a_rank = len(input_a_shape)
    b_rank = len(input_b_shape)

    # TODO(justinchuby): Check shape size

    # 1. Check if input shapes are broadcastable
    # 1.a. If the first input is 1-D, check whether
    # the dim matches the last second dim of the second input.
    mimic_matmul_broadcast_behavior = False
    if a_rank < 2:
        if b_rank < 2:
            logger.info("Optimization of dot product is not supported yet.")
            return False
        if input_a_shape[-1] != input_b_shape[-2]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_a_shape = [1, *input_a_shape]
            a_rank = len(input_a_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.b. If the second input is 1-D, check whether
    # the dim matches the last dim of the first input.
    if b_rank < 2:
        if input_b_shape[-1] != input_a_shape[-1]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_b_shape = [*input_b_shape, 1]
            b_rank = len(input_b_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.c. If both inputs are at least 2-D, check whether
    # the last dimension of the first input matches the second
    # last dimension of the second input, and shape[:-2] are
    # broadcastable.
    input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]
    input_b_shape_except_last_dim = input_b_shape[:-1]
    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
    for idx, (dim_from_a, dim_from_b) in enumerate(
        zip(
            reversed(input_a_shape_except_second_last_dim),
            reversed(input_b_shape_except_last_dim),
        )
    ):
        if dim_from_a not in {1, dim_from_b}:
            logger.info("Original shape is not broadcastable.")
            return False
        elif idx > 0:
            broadcast_matmul_output_shape = [
                max(dim_from_a, dim_from_b),
                *broadcast_matmul_output_shape,
            ]

    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
    # Prepend the broadcast_matmul_output_shape with the longer shape of input
    if a_rank > b_rank:
        longer_shape = input_a_shape
        shorter_shape = input_b_shape
    else:
        longer_shape = input_b_shape
        shorter_shape = input_a_shape
    broadcast_matmul_output_shape = [
        *longer_shape[: -len(shorter_shape)],
        *broadcast_matmul_output_shape,
    ]
    if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
        # If input_b is expanded to 2-D, then we need to remove the last dimension
        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
    if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
        # If input_a is expanded to 2-D, then we need to remove the first dimension
        # of input_a, which would be the -2nd dimension of the output shape.
        broadcast_matmul_output_shape.pop(-2)
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False

    return True

所有必要的组件都准备好了,现在创建带有 match_condition 函数的模式重写规则,然后调用 rewriter.rewrite 来应用重写。

def apply_rewrite(model):
    # Create rewrite rules
    two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
        two_reshapes_matmul_reshape_pattern,  # target pattern
        matmul_pattern,  # replacement pattern
        check_if_not_need_reshape,  # match_condition function
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule])
    # Apply rewrite while passing match_condition
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite