Pattern-based Rewrite Using Rules
The ONNX Rewriter tool lets users replace certain patterns in an ONNX graph with other patterns, according to rewrite rules that the users provide.
Usage
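Rewriting patterns in a graph requires three main components:

1. `target_pattern`: the original pattern to match against. It is written as a function using ONNXScript-like operators.
2. `replacement_pattern`: the pattern that replaces the original pattern. It is also written as a function using ONNXScript-like operators.
3. `match_condition` (optional): the pattern is rewritten only if this condition is satisfied.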
Simple Example
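The following simple example demonstrates this feature using the GELU activation function. GELU can be computed with the Gauss error function using the following formula:

$$\mathrm{GELU}(x) = x\,\Phi(x) = x \cdot \frac{1}{2}\left[1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$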
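First, import the rewriter-related modules. `onnxscript.rewriter` is imported as a module because the later snippets call `onnxscript.rewriter.rewrite`:

```python
import onnxscript.rewriter
from onnxscript import ir
from onnxscript.rewriter import pattern
```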
Then create the target pattern to be replaced, using onnxscript operators:
```python
import math


def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
```
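Because a target pattern is an ordinary Python function of `op` and its inputs, its arithmetic can be sanity-checked outside ONNX. The sketch below is an illustration only: it assumes a hypothetical stand-in `FloatOps` class whose `Erf` is `math.erf`, so that `erf_gelu_pattern` evaluates on plain floats, and compares the result with the definition x·Φ(x), where Φ is the standard normal CDF. In the real rewriter, `op` is supplied by onnxscript and builds graph nodes rather than numbers.

```python
import math


def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))


class FloatOps:
    """Hypothetical stand-in for the rewriter's `op`: evaluates Erf on floats."""

    @staticmethod
    def Erf(value):
        return math.erf(value)


def standard_normal_cdf(x):
    # Phi(x) written with erfc, an equivalent form of 0.5 * (1 + erf(x / sqrt(2))).
    return 0.5 * math.erfc(-x / math.sqrt(2))


for x in (-2.0, -0.5, 0.0, 1.0, 3.0):
    from_pattern = erf_gelu_pattern(FloatOps, x)
    from_definition = x * standard_normal_cdf(x)
    # Both forms agree up to floating-point rounding.
    assert math.isclose(from_pattern, from_definition, rel_tol=1e-12, abs_tol=1e-12)
    print(f"x={x:+.1f}  GELU={from_pattern:.6f}")
```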
Next, create the replacement pattern, which consists of the GELU onnxscript operator from the `com.microsoft` domain:
```python
def gelu(op, x: ir.Value):
    return op.Gelu(x, _domain="com.microsoft")
```
Note: The inputs to the replacement pattern are of type `ir.Value`.
This example does not need a `match_condition`, so that option is skipped for now. The rewrite rule is then created with `pattern.RewriteRule`:
```python
rule = pattern.RewriteRule(
    erf_gelu_pattern,  # Target Pattern
    gelu,  # Replacement Pattern
)
```
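Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call takes three main arguments:

1. `model`: the original model to which the pattern rewrite rules are applied. It is of type `onnx.ModelProto`.
2. `function_rewrite_rules` (optional): passes rewrite rules based on function names. How to use this parameter is covered in a separate tutorial. It is of type `Sequence[type[FunctionRewriteRule]]`.
3. `pattern_rewrite_rules` (optional): passes rewrite rules based on a provided replacement pattern. This tutorial uses only this parameter, together with `model`. It accepts either of these types:
   - `Sequence[PatternRewriteRule]`
   - `RewriteRuleSet`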
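Note: `pattern_rewrite_rules` accepts either a sequence of `PatternRewriteRule` objects or a `RewriteRuleSet`, which is itself a rule set built from a sequence of `PatternRewriteRule` objects. A single rewrite rule therefore still has to be passed inside a sequence. For how to create and use rule sets, see the example in the section "Creating a rule set with different patterns".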
The snippet below shows how to call `rewriter.rewrite` with the rewrite rule created above:
```python
def apply_rewrite(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=[rule],
    )
    return model_with_rewrite_applied
```
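Utilizing the `commute` parameter for pattern matching

The same GELU computation can appear in a graph with the operands of the final multiplication swapped, as in `0.5 * (...)` versus `(...) * 0.5`. A target pattern written for one operand order does not match the other. There are two ways to handle this: create a rule set with one pattern per variant, or use the `commute` parameter.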
Creating a rule set with different patterns
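This approach requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferred option, but either works. For example, to create a `RewriteRuleSet` containing multiple rules `rule1` and `rule2`: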
```python
from onnxscript.rewriter import pattern

rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
```
To apply this approach to the example above, first create the two separate target patterns:
```python
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))


def erf_gelu_pattern_2(op, x):
    return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
```
Then create a separate `PatternRewriteRule` for each target pattern, pack these rules into a `RewriteRuleSet` object, and apply the rewrite by passing that `RewriteRuleSet` as the `pattern_rewrite_rules` argument:
```python
def apply_rewrite_with_ruleset(model):
    # Create multiple rules
    rule1 = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    rule2 = pattern.RewriteRule(
        erf_gelu_pattern_2,  # Target Pattern
        gelu,  # Replacement
    )
    # Create a Rewrite Rule Set with multiple rules.
    rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
        # pattern_rewrite_rules=[rule1, rule2],  # Alternative method of passing multiple rules
    )
    return model_with_rewrite_applied
```
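To make the role of a rule set concrete, here is a toy, pure-Python model of the idea. It is not onnxscript's implementation of `RewriteRuleSet`: expressions are hypothetical nested tuples rather than ONNX nodes, and `match`, `rewrite`, and the `*_tree` helpers are illustrative names. It shows a rule set trying each of its rules at every node, so both operand orders of the GELU computation are rewritten only when both rules are present.

```python
import math

SQRT2 = math.sqrt(2)


def erf_gelu_tree(var):
    """Tuple form of 0.5 * (x * (Erf(x / sqrt(2)) + 1.0))."""
    return ("Mul", 0.5, ("Mul", var, ("Add", ("Erf", ("Div", var, SQRT2)), 1.0)))


def erf_gelu_tree_2(var):
    """Tuple form of (x * (Erf(x / sqrt(2)) + 1.0)) * 0.5."""
    return ("Mul", ("Mul", var, ("Add", ("Erf", ("Div", var, SQRT2)), 1.0)), 0.5)


def match(pattern, expr, bindings):
    """Structural match; strings starting with '?' are pattern variables."""
    if isinstance(pattern, str) and pattern.startswith("?"):
        if pattern in bindings:  # a variable used twice must bind the same value
            return bindings[pattern] == expr
        bindings[pattern] = expr
        return True
    if isinstance(pattern, tuple):
        return (
            isinstance(expr, tuple)
            and len(pattern) == len(expr)
            and pattern[0] == expr[0]
            and all(match(p, e, bindings) for p, e in zip(pattern[1:], expr[1:]))
        )
    return pattern == expr  # constants must be equal


def rewrite(expr, rules):
    """Rewrite bottom-up; at each node, the first matching rule of the set wins."""
    if isinstance(expr, tuple):
        expr = (expr[0], *(rewrite(child, rules) for child in expr[1:]))
    for target, build_replacement in rules:
        bindings = {}  # fresh bindings for every attempt
        if match(target, expr, bindings):
            return build_replacement(bindings)
    return expr


def gelu_replacement(bindings):
    return ("Gelu", bindings["?x"])


rule1 = (erf_gelu_tree("?x"), gelu_replacement)
rule2 = (erf_gelu_tree_2("?x"), gelu_replacement)

graph = ("Add", erf_gelu_tree("a"), erf_gelu_tree_2("b"))
print(rewrite(graph, [rule1]))         # only the first operand order is replaced
print(rewrite(graph, [rule1, rule2]))  # ('Add', ('Gelu', 'a'), ('Gelu', 'b'))
```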
Using the `commute` parameter when creating a rule
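Writing several target patterns for similar cases can be tedious. To avoid this, use the `commute` parameter when creating the `RewriteRuleSet`. Setting `commute=True` removes the need to write separate target patterns for patterns that differ only by commutativity: the rules for those commuted variants are generated automatically and packed into the `RewriteRuleSet` object. Then apply the rewrite by passing that `RewriteRuleSet` as the `pattern_rewrite_rules` argument: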
```python
def apply_rewrite_with_commute(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    # Create a Rewrite Rule Set with commute=True
    rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True)
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite_applied
```
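Conceptually, `commute=True` spares you from writing `erf_gelu_pattern_2` by hand, because the commuted variants of the target pattern are generated for you. The sketch below illustrates that idea on hypothetical nested-tuple expressions. It assumes that only `Add` and `Mul` are treated as commutative and is not onnxscript's implementation. For the GELU pattern, which contains two `Mul` nodes and one `Add` node, it yields 2³ = 8 variants, one of which is exactly the second target pattern.

```python
import itertools
import math

SQRT2 = math.sqrt(2)

# Assumption for this sketch: only Add and Mul are treated as commutative.
COMMUTATIVE_OPS = {"Add", "Mul"}


def erf_gelu_tree(var):
    """Tuple form of erf_gelu_pattern: 0.5 * (x * (Erf(x / sqrt(2)) + 1.0))."""
    return ("Mul", 0.5, ("Mul", var, ("Add", ("Erf", ("Div", var, SQRT2)), 1.0)))


def erf_gelu_tree_2(var):
    """Tuple form of erf_gelu_pattern_2: (x * (Erf(x / sqrt(2)) + 1.0)) * 0.5."""
    return ("Mul", ("Mul", var, ("Add", ("Erf", ("Div", var, SQRT2)), 1.0)), 0.5)


def commuted_variants(expr):
    """All trees reachable by optionally swapping the operands of commutative ops."""
    if not isinstance(expr, tuple):
        return [expr]  # leaf: a variable name or a constant
    op, *args = expr
    variants = []
    # Combine every variant of every child, then optionally swap the two operands.
    for children in itertools.product(*(commuted_variants(arg) for arg in args)):
        variants.append((op, *children))
        if op in COMMUTATIVE_OPS and len(children) == 2:
            variants.append((op, children[1], children[0]))
    return list(dict.fromkeys(variants))  # drop duplicates, keep order


variants = commuted_variants(erf_gelu_tree("x"))
print(len(variants))                     # 8
print(erf_gelu_tree_2("x") in variants)  # True
```

In this model the number of variants doubles with every commutative node, so a pattern with many `Add`/`Mul` nodes expands into correspondingly many rules.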
Using the `match_condition` parameter for pattern matching
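This section shows how to use the `match_condition` parameter. The `match_condition` parameter checks whether a match of the target pattern also satisfies certain additional constraints.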
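According to the ONNX MatMul specification, ONNX `MatMul` behaves like `numpy.matmul` and also follows numpy broadcasting. Consider a pattern in which each `MatMul` input first passes through a `Reshape` and the `MatMul` output is reshaped again. If `MatMul` broadcasting alone is enough, these `Reshape` ops are unnecessary. To verify this, we need to check the following: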
1. Input shape check: `input_a` and `input_b` should be broadcastable.
2. Output shape check: `shape_c` should be the same as the output shape of `matmul(input_a, input_b)`.
If both checks pass, the `Reshape` ops are unnecessary and can be eliminated with a pattern-based rewrite.
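The two checks amount to comparing `shape_c` with the shape that `numpy.matmul` broadcasting would produce. As a standalone sketch, independent of the onnxscript API, the helper below computes that shape from two static input shapes using numpy's documented `matmul` rules. `matmul_output_shape` is a hypothetical name, and the function assumes fully known integer dimensions.

```python
from itertools import zip_longest


def matmul_output_shape(shape_a, shape_b):
    """Output shape of numpy.matmul(a, b), or None if the shapes are incompatible.

    A 1-D operand is promoted to a matrix and the added dimension is removed
    afterwards; all leading (batch) dimensions are broadcast.
    """
    a, b = list(shape_a), list(shape_b)
    if not a or not b:
        return None  # matmul does not accept scalars
    a_was_1d, b_was_1d = len(a) == 1, len(b) == 1
    if a_was_1d:
        a = [1, *a]  # (K,) -> (1, K)
    if b_was_1d:
        b = [*b, 1]  # (K,) -> (K, 1)
    if a[-1] != b[-2]:
        return None  # contraction dimensions must match exactly
    batch = []
    # Broadcast batch dims right-to-left; a missing dim counts as 1.
    for dim_a, dim_b in zip_longest(reversed(a[:-2]), reversed(b[:-2]), fillvalue=1):
        if dim_a != dim_b and 1 not in (dim_a, dim_b):
            return None
        batch.append(dim_b if dim_a == 1 else dim_a)
    out = [*reversed(batch), a[-2], b[-1]]
    if a_was_1d:
        out.pop(-2)  # drop the prepended row dimension
    if b_was_1d:
        out.pop(-1)  # drop the appended column dimension
    return tuple(out)


# Worked case: a (2, 3, 4) input reshaped to (6, 4), multiplied by a (4, 5) input
# and reshaped back to (2, 3, 5) has the same shape as plain broadcasting gives.
shape_c = (2, 3, 5)
print(matmul_output_shape((2, 3, 4), (4, 5)) == shape_c)  # True
```

With a helper like this, the condition reduces to checking that the result is not `None` and equals `tuple(shape_c)`.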
First, write the target pattern and the replacement pattern in the same way as in the first example:
```python
def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
    reshape_a = op.Reshape(input_a, shape_a)
    reshape_b = op.Reshape(input_b, shape_b)
    matmul = op.MatMul(reshape_a, reshape_b)
    return op.Reshape(matmul, shape_c)


def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
    return op.MatMul(input_a, input_b)
```
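Note: In this case the target pattern has five inputs: `input_a`, `input_b`, `shape_a`, `shape_b`, and `shape_c`. The replacement pattern, however, uses only `input_a` and `input_b`. To avoid listing every unused parameter in the replacement pattern's signature, declare only `input_a` and `input_b` and use `**_` to represent all the unused parameters.

Similarly, the condition-checking function needs only `input_a`, `input_b`, and `shape_c`. Use `**_` in its signature to represent all the unused parameters.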
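The `**_` idiom is plain Python: a `**` parameter named `_` collects any extra keyword arguments into a dictionary that the function never reads. The minimal sketch below shows this outside the rewriter. It assumes, for illustration only, that matched values arrive as keyword arguments, and it uses `None` and hypothetical string placeholders instead of a real `op` and real `ir.Value` objects.

```python
def matmul_replacement(op, input_a, input_b, **_):
    # `**_` absorbs shape_a, shape_b and shape_c, which this function never uses.
    return ("MatMul", input_a, input_b)


# Passing every matched input by keyword still works.
result = matmul_replacement(
    None,
    input_a="A",
    input_b="B",
    shape_a="sa",
    shape_b="sb",
    shape_c="sc",
)
print(result)  # ('MatMul', 'A', 'B')
```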
To verify whether `MatMul` broadcasting is sufficient, we write a condition-checking function as follows:
```python
import logging

from onnxscript import ir
from onnxscript.rewriter import _ir_utils

logger = logging.getLogger(__name__)


def check_if_not_need_reshape(
    context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
    """Condition to check if we need to replace the pattern.

    If matmul broadcasting is enough, then we don't need the reshapes.

    To validate this, we need to check the following:
    1. Input shapes check: input_a and input_b should be broadcastable
    2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)

    If the above are true, then we don't need the reshapes.

    Returns:
        True if we need to replace the pattern, False otherwise.
    """
    del context  # Reserved for future extensions

    input_a_shape = input_a.shape
    input_b_shape = input_b.shape
    # TODO: Get a helper func to get const_value
    _ir_utils.propagate_const_value(shape_c)
    shape_c_tensor = shape_c.const_value
    if shape_c_tensor is None:
        logger.info("The value 'shape_c' is not statically known.")
        return False

    if len(shape_c_tensor.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c_tensor.shape,
        )
        return False

    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
    # information. So, we need to check if the shape is None and return False.
    if input_a_shape is None or input_b_shape is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False
    input_a_shape = input_a_shape.numpy()
    input_b_shape = input_b_shape.numpy()
    shape_c = shape_c_tensor.numpy().tolist()

    a_rank = len(input_a_shape)
    b_rank = len(input_b_shape)

    # TODO(justinchuby): Check shape size

    # 1. Check if input shapes are broadcastable
    # 1.a. If the first input is 1-D, check whether
    # the dim matches the second-to-last dim of the second input.
    mimic_matmul_broadcast_behavior = False
    if a_rank < 2:
        if b_rank < 2:
            logger.info("Optimization of dot product is not supported yet.")
            return False
        if input_a_shape[-1] != input_b_shape[-2]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_a_shape = [1, *input_a_shape]
            a_rank = len(input_a_shape)
            mimic_matmul_broadcast_behavior = True

    # 1.b. If the second input is 1-D, check whether
    # the dim matches the last dim of the first input.
    if b_rank < 2:
        if input_b_shape[-1] != input_a_shape[-1]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_b_shape = [*input_b_shape, 1]
            b_rank = len(input_b_shape)
            mimic_matmul_broadcast_behavior = True

    # 1.c. If both inputs are at least 2-D, check whether
    # the last dimension of the first input matches the second
    # last dimension of the second input, and shape[:-2] are
    # broadcastable.
    input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]
    input_b_shape_except_last_dim = input_b_shape[:-1]
    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
    for idx, (dim_from_a, dim_from_b) in enumerate(
        zip(
            reversed(input_a_shape_except_second_last_dim),
            reversed(input_b_shape_except_last_dim),
        )
    ):
        if idx == 0:
            # The contraction dimension must match exactly; broadcasting does not apply.
            if dim_from_a != dim_from_b:
                logger.info("Original shape is not MatMul compatible.")
                return False
        elif dim_from_a not in {1, dim_from_b}:
            logger.info("Original shape is not broadcastable.")
            return False
        else:
            broadcast_matmul_output_shape = [
                max(dim_from_a, dim_from_b),
                *broadcast_matmul_output_shape,
            ]

    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
    # Prepend the broadcast_matmul_output_shape with the longer shape of input
    if a_rank > b_rank:
        longer_shape = input_a_shape
        shorter_shape = input_b_shape
    else:
        longer_shape = input_b_shape
        shorter_shape = input_a_shape
    broadcast_matmul_output_shape = [
        *longer_shape[: -len(shorter_shape)],
        *broadcast_matmul_output_shape,
    ]
    if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
        # If input_b is expanded to 2-D, then we need to remove the last dimension
        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
    if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
        # If input_a is expanded to 2-D, then we need to remove the first dimension
        # of input_a, which would be the -2nd dimension of the output shape.
        broadcast_matmul_output_shape.pop(-2)
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False

    return True
```
With all the necessary components in place, create the pattern rewrite rule with the `match_condition` function, then call `rewriter.rewrite` to apply the rewrite:
```python
def apply_rewrite(model):
    # Create rewrite rules
    two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
        two_reshapes_matmul_reshape_pattern,  # target pattern
        matmul_pattern,  # replacement pattern
        check_if_not_need_reshape,  # match_condition function
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule])
    # Apply the rewrite; the rule's match_condition decides whether each match is replaced
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite
```