ONNX Script 模式重写进阶

ONNX Script 模式重写进阶

参考:pattern_rewriting

本节展示了如何基于模式定义重写规则。

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import generic_pattern

定义简单模型:

def get_rotary_model(bad_model=False):
    """Build the small ONNX model used as the rewriting target.

    The graph chains Unsqueeze -> Cast -> MatMul -> Transpose into a
    two-output ``com.microsoft`` concat node, then applies Sin and Cos to
    the first concat output and casts each result.

    :param bad_model: when true, the concat node is created with op type
        ``ConcatTrainingBad`` instead of ``ConcatTraining``
    :return: the model produced by ``onnx.helper.make_model``
    """
    concat_op_type = "ConcatTrainingBad" if bad_model else "ConcatTraining"
    # TensorProto.FLOAT is the enum value 1 used as the Cast target type.
    float_type = onnx.TensorProto.FLOAT

    graph_inputs = [
        oh.make_tensor_value_info(name, elem_type, shape=[])
        for name, elem_type in (
            ("x", onnx.TensorProto.INT64),
            ("pos_ids", float_type),
            ("axis", onnx.TensorProto.INT64),
        )
    ]
    graph_outputs = [
        oh.make_tensor_value_info(name, onnx.TensorProto.UNDEFINED, [])
        for name in ("_onx_cast02", "_onx_cast03")
    ]

    graph_nodes = [
        oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
        oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=float_type),
        oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
        oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
        # The same tensor is fed twice to the custom-domain concat node.
        oh.make_node(
            concat_op_type,
            ["_onx_transpose0", "_onx_transpose0"],
            ["_onx_concattraining0", "_onx_concattraining1"],
            domain="com.microsoft",
        ),
        # Both trigonometric branches read the first concat output.
        oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
        oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=float_type),
        oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
        oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=float_type),
    ]

    graph = oh.make_graph(graph_nodes, "experiment", graph_inputs, graph_outputs)
    opset_imports = [
        oh.make_opsetid("", 18),
        oh.make_opsetid("com.microsoft", 18),
    ]
    return oh.make_model(graph, opset_imports=opset_imports)
# Build the model with the default ``ConcatTraining`` node and convert the
# ONNX proto into the onnxscript IR representation.
model = get_rotary_model()
ir_model = ir.serde.deserialize_model(model)

重写模式(ONNX)

# Opset handles used by the pattern functions to spell out operators:
# ``op`` is the standard ONNX opset 18, ``msft_op`` is the custom
# ``com.microsoft`` domain.
op = onnxscript.opset18
msft_op = onnxscript.values.Opset("com.microsoft", 1)


def rotary_match_pattern(x, pos_ids, axis):
    """Describe the subgraph to search for.

    The operator sequence is Unsqueeze, Cast, MatMul, Transpose,
    ``com.microsoft::ConcatTraining``, then Sin and Cos of the first
    concat output, each followed by a Cast to FLOAT.

    :param x: first argument of the Unsqueeze
    :param pos_ids: left operand of the MatMul
    :param axis: second argument of the Unsqueeze
    :return: the two Cast results (sin branch first, cos branch second)
    """
    # NOTE(review): the verbose matcher log prints these local names
    # (e.g. %"output", %"length", %"cos"), so renaming them appears to
    # change that output.
    unsqueeze = op.Unsqueeze(x, axis)
    cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)

    matmul = op.MatMul(pos_ids, cast)
    transpose = op.Transpose(matmul)
    # The same value is passed twice; ``length`` (second output) is not
    # used by the rest of the pattern.
    output, length = msft_op.ConcatTraining(transpose, transpose)

    # Two branches share ``output``: Sin -> Cast and Cos -> Cast.
    sin = op.Sin(output)
    cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
    cos = op.Cos(output)
    cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)
    return cast1, cast2


def validate_rotary_mapping(g, match_result) -> bool:
    """Post-match validation hook that accepts every match.

    Returning True validates the replacement; returning False would
    prevent it from being applied.

    :param g: model
    :param match_result: matched nodes
    :return: always True
    """
    # Neither argument is inspected; drop both local references at once.
    del g, match_result
    return True


def rotary_apply_pattern(x, pos_ids, axis):
    """Describe the replacement subgraph.

    Emits two Constant nodes (a cos cache and a sin cache) and one
    ``com.microsoft::RotaryEmbedding`` node that consumes ``x``,
    ``pos_ids`` and both caches.

    :param x: forwarded as the first RotaryEmbedding input
    :param pos_ids: forwarded as the second RotaryEmbedding input
    :param axis: not used by the replacement; the signature has the same
        parameters as ``rotary_match_pattern``
    :return: the two outputs of the RotaryEmbedding node
    """
    # Each cache is a (256, 256) float16 tensor of random values.
    # NOTE(review): presumably placeholder data for the example — confirm
    # before reusing this replacement on a real model.
    cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
    return part1, part2

构建规则

# Build a rewrite rule from the match pattern, the replacement pattern and
# the validation callback.
rule_with_validation_function = generic_pattern.make_pattern_rule(
    rotary_match_pattern,
    rotary_apply_pattern,
    validate_rotary_mapping,
)
'RotaryEmbedding' is not a known op in 'com.microsoft'
'ConcatTraining' is not a known op in 'com.microsoft'

由于 validate_rotary_mapping 函数总是返回 True,这种情况下可以省略该参数。

# Same rule, built without passing a validation callback.
rule = generic_pattern.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)
'RotaryEmbedding' is not a known op in 'com.microsoft'
rotary_apply_pattern: Already defined.
'ConcatTraining' is not a known op in 'com.microsoft'
rotary_match_pattern: Already defined.

应用规则:

# Apply the rule to the IR model; the value printed below is what the call
# returned.
rule.apply_to_model(ir_model)
1

将重写后的模型序列化回 ONNX:

# Convert the (rewritten) IR model back into an ONNX proto.
rewritten_model = ir.serde.serialize_model(ir_model)

查看重写后模型中的节点:

# Print one line per node: op type, comma-separated inputs, then outputs.
for node in rewritten_model.graph.node:
    input_names = ", ".join(node.input)
    output_names = ", ".join(node.output)
    print(f"{node.op_type}({input_names}) -> {output_names}")
Constant() -> val_0
Constant() -> val_1
RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) -> val_2, val_3

如果匹配失败了呢?

# Rebuild the model with the concat node renamed to ``ConcatTrainingBad``.
model = get_rotary_model(True)
ir_model = ir.serde.deserialize_model(model)

# Apply the same rule, then serialize the result.
rule.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)

# List the op types present in the graph after the attempt.
print([n.op_type for n in rewritten_model.graph.node])
['Unsqueeze', 'Cast', 'MatMul', 'Transpose', 'ConcatTrainingBad', 'Sin', 'Cast', 'Cos', 'Cast']

匹配没有发生。我们可以提高日志的详细程度(verbose)来查看原因。

# Recreate the rule with verbose=10 so that the matcher prints a trace of
# each matching attempt.
rule = generic_pattern.make_pattern_rule(
    rotary_match_pattern, rotary_apply_pattern, verbose=10
)

# Run it again on the ``ConcatTrainingBad`` model; the trace follows.
rule.apply_to_model(ir_model)
[GenericPattern.match] starts with %"_onx_cast0"<?,?> ⬅️ ::Cast(%"_onx_unsqueeze0") {to=1}
[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>
[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})
[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast
    --hint--: BACKWARD: different node types
      %"cos"<?,?> ⬅️ ::Cos(%"output")
      %"_onx_unsqueeze0"<?,?> ⬅️ ::Unsqueeze(%"x", %"axis")
    iteration=0
    --matched-- #1
      Cast(%"cos"<?,?>) ~ Cast(%"_onx_unsqueeze0"<?,?>) [140454351111376-140454352338048]
    len(stack)=0:[]
[GenericPattern.match] done. backward failed.
[GenericPattern.match] starts with %"_onx_cast02"<UNDEFINED,[]> ⬅️ ::Cast(%"_onx_sin0") {to=1}
[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>
[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})
[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast
    --hint--: BACKWARD: different node types
      %"cos"<?,?> ⬅️ ::Cos(%"output")
      %"_onx_sin0"<?,?> ⬅️ ::Sin(%"_onx_concattraining0")
    iteration=0
    --matched-- #1
      Cast(%"cos"<?,?>) ~ Cast(%"_onx_sin0"<?,?>) [140454351111376-140454352338768]
    len(stack)=0:[]
[GenericPattern.match] done. backward failed.
[GenericPattern.match] starts with %"_onx_cast03"<UNDEFINED,[]> ⬅️ ::Cast(%"_onx_cos0") {to=1}
[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>
[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})
[GenericPattern._match_backward] match Cos((Value('_onx_concattraining0', type=None, shape=None, producer=anonymous_node:140454352338480, index=0),)) with Cos((Value('output', type=None, shape=None, producer=n4, index=0),)) (pattern)
[GenericPattern._match_backward] add 1 nodes
[GenericPattern.match] iteration=2 n_matched=2, n_stack=1, matched_types=Counter({'Cast': 1, 'Cos': 1})
[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast
    --hint--: BACKWARD: different node types
      %"output"<?,?>, %"length"<?,?> ⬅️ com.microsoft::ConcatTraining(%"transpose", %"transpose")
      %"_onx_concattraining0"<?,?>, %"_onx_concattraining1"<?,?> ⬅️ com.microsoft::ConcatTrainingBad(%"_onx_transpose0", %"_onx_transpose0")
    iteration=1
    --matched-- #2
      Cast(%"cos"<?,?>) ~ Cast(%"_onx_cos0"<?,?>) [140454351111376-140454351110224]
      Cos(%"output"<?,?>) ~ Cos(%"_onx_concattraining0"<?,?>) [140454351111088-140455237148048]
    len(stack)=0:[]
[GenericPattern.match] done. backward failed.
'RotaryEmbedding' is not a known op in 'com.microsoft'
rotary_apply_pattern: Already defined.
'ConcatTraining' is not a known op in 'com.microsoft'
rotary_match_pattern: Already defined.
0

日志显示了每次算法拒绝模式的情况。

例如,可以看到如下信息:

::

    [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
        --hint--: BACKWARD: different node types
          --pattern
          ConcatTraining(transpose, transpose) -> (output, length)
          -- model
          ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
        iteration=1
        --marked-- #2
          Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
          Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
        len(stacked)=0:[]

这段日志表明,匹配在文件 generic_pattern.py 的第 673 行被拒绝:在向后(BACKWARD)方向比较两个节点时,节点类型不一致(模式中是 ConcatTraining,模型中是 ConcatTrainingBad)。日志还表明,在此之前已有两个节点(Cast 和 Cos)匹配成功。