ONNX Script 模式重写进阶
本节展示了如何基于模式定义重写规则。
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import onnxscript
from onnxscript import ir
from onnxscript.rewriter import generic_pattern
定义一个简单的模型：
def get_rotary_model(bad_model=False):
    """Build a small ONNX model that contains a rotary-embedding-like subgraph.

    The graph chains ``Unsqueeze -> Cast -> MatMul -> Transpose`` into a
    ``com.microsoft`` concat node. Its first output feeds two branches,
    ``Sin -> Cast`` and ``Cos -> Cast``, and the two final casts are the
    graph outputs.

    :param bad_model: when True, the ``com.microsoft`` node is named
        ``ConcatTrainingBad`` instead of ``ConcatTraining``, so a pattern that
        expects ``ConcatTraining`` will not match it.
    :return: a model built with ``onnx.helper.make_model`` that imports both
        the default domain and ``com.microsoft`` at version 18.
    """
    # TensorProto.FLOAT is the enum value 1; it is the target type of every Cast.
    float_type = onnx.TensorProto.FLOAT
    concat_op_type = "ConcatTrainingBad" if bad_model else "ConcatTraining"

    # All inputs are declared with an empty shape list.
    graph_inputs = [
        oh.make_tensor_value_info(input_name, input_type, shape=[])
        for input_name, input_type in (
            ("x", onnx.TensorProto.INT64),
            ("pos_ids", onnx.TensorProto.FLOAT),
            ("axis", onnx.TensorProto.INT64),
        )
    ]

    # Node order is kept stable: the verbose matcher log walks nodes in this order.
    graph_nodes = [
        oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
        oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=float_type),
        oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
        oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
        oh.make_node(
            concat_op_type,
            ["_onx_transpose0", "_onx_transpose0"],
            ["_onx_concattraining0", "_onx_concattraining1"],
            domain="com.microsoft",
        ),
        # Both branches consume the first output of the concat node.
        oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
        oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=float_type),
        oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
        oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=float_type),
    ]

    # Outputs carry no element type (UNDEFINED) and an empty shape list.
    graph_outputs = [
        oh.make_tensor_value_info(output_name, onnx.TensorProto.UNDEFINED, [])
        for output_name in ("_onx_cast02", "_onx_cast03")
    ]

    graph = oh.make_graph(graph_nodes, "experiment", graph_inputs, graph_outputs)
    opsets = [oh.make_opsetid(domain, 18) for domain in ("", "com.microsoft")]
    return oh.make_model(graph, opset_imports=opsets)
# Build the variant whose contrib node is ConcatTraining, then load it as IR.
model = get_rotary_model(bad_model=False)
ir_model = ir.serde.deserialize_model(model)
重写模式（ONNX）
# Opset handles used by the pattern functions: the default ONNX domain at
# opset 18, and the com.microsoft contrib domain at version 1.
op = onnxscript.opset18
msft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)
def rotary_match_pattern(x, pos_ids, axis):
    """The pattern to match.

    Describes the subgraph
    ``Unsqueeze -> Cast(FLOAT) -> MatMul -> Transpose -> com.microsoft::ConcatTraining``,
    where the first ConcatTraining output feeds both ``Sin -> Cast(FLOAT)``
    and ``Cos -> Cast(FLOAT)``. The two final casts are the pattern outputs.

    Local variable names (``cos``, ``output``, ``transpose``, ...) appear as
    value names in the verbose matcher log, so renaming them changes that log.
    """
    unsqueeze = op.Unsqueeze(x, axis)
    cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)
    matmul = op.MatMul(pos_ids, cast)
    transpose = op.Transpose(matmul)
    # ConcatTraining is unpacked into two outputs; only the first one is used below.
    output, length = msft_op.ConcatTraining(transpose, transpose)
    sin = op.Sin(output)
    cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
    cos = op.Cos(output)
    cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)
    return cast1, cast2
def validate_rotary_mapping(g, match_result) -> bool:
    """Decide, after a successful match, whether the replacement is applied.

    This demonstration validator accepts every match: returning False would
    veto the rewrite, but it never does.

    :param g: model
    :param match_result: matched nodes
    :return: always True
    """
    del g, match_result  # intentionally unused
    return True
def rotary_apply_pattern(x, pos_ids, axis):
    """The replacement pattern.

    Emits two Constant caches and a single com.microsoft ``RotaryEmbedding``
    node that consumes ``x``, ``pos_ids`` and the two caches. Its two outputs
    are returned in place of the two outputs of the matched subgraph.

    ``axis`` is accepted but unused. Presumably the signature has to mirror
    ``rotary_match_pattern`` (TODO confirm).
    """
    # NOTE(review): the caches are random 256x256 float16 placeholders drawn
    # with np.random (a fresh draw each time this line runs), not real cos/sin
    # tables. Also, the matched graph casts to FLOAT, not float16; confirm
    # against the RotaryEmbedding contract before reusing this for real.
    cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
    return part1, part2
构建规则
# Rule with an explicit post-match validator (which accepts every match).
rule_with_validation_function = generic_pattern.make_pattern_rule(
    rotary_match_pattern,
    rotary_apply_pattern,
    validate_mapping=validate_rotary_mapping,
)
'RotaryEmbedding' is not a known op in 'com.microsoft'
'ConcatTraining' is not a known op in 'com.microsoft'
`validate_rotary_mapping` 函数总是返回 `True`，因此在这种情况下可以省略这个参数。
# Same rewrite, built without a validation function.
rule = generic_pattern.make_pattern_rule(
    rotary_match_pattern,
    rotary_apply_pattern,
)
'RotaryEmbedding' is not a known op in 'com.microsoft'
rotary_apply_pattern: Already defined.
'ConcatTraining' is not a known op in 'com.microsoft'
rotary_match_pattern: Already defined.
应用规则：
# Apply the rule to ir_model. This is left as a bare expression so its return
# value is displayed (1 here; 0 for the bad model later) -- presumably the
# number of rewrites applied.
rule.apply_to_model(ir_model)
1
将重写后的 IR 模型序列化回 ONNX 模型：
# Serialize the (now rewritten) IR model back to an ONNX model so its
# graph.node list can be inspected.
rewritten_model = ir.serde.serialize_model(ir_model)
查看重写结果（逐个打印节点）：
# One line per node, formatted as: op_type(inputs) -> outputs
for node in rewritten_model.graph.node:
    joined_inputs = ", ".join(node.input)
    joined_outputs = ", ".join(node.output)
    print(f"{node.op_type}({joined_inputs}) -> {joined_outputs}")
Constant() -> val_0
Constant() -> val_1
RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) -> val_2, val_3
如果匹配失败会怎样？
# Repeat the experiment on the variant whose contrib node is ConcatTrainingBad.
model = get_rotary_model(bad_model=True)
ir_model = ir.serde.deserialize_model(model)
rule.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)
op_types = [graph_node.op_type for graph_node in rewritten_model.graph.node]
print(op_types)
['Unsqueeze', 'Cast', 'MatMul', 'Transpose', 'ConcatTrainingBad', 'Sin', 'Cast', 'Cos', 'Cast']
匹配没有发生。我们可以提高日志级别（`verbose=10`）来输出更多细节。
# Rebuild the rule with verbose=10 so the matcher logs every rejection, then
# retry on the same (bad) model. The trailing bare expression keeps the
# return value of apply_to_model displayed.
rule = generic_pattern.make_pattern_rule(
    rotary_match_pattern,
    rotary_apply_pattern,
    verbose=10,
)
rule.apply_to_model(ir_model)
[GenericPattern.match] starts with %"_onx_cast0"<?,?> ⬅️ ::Cast(%"_onx_unsqueeze0") {to=1}
[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>
[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})
[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast
--hint--: BACKWARD: different node types
%"cos"<?,?> ⬅️ ::Cos(%"output")
%"_onx_unsqueeze0"<?,?> ⬅️ ::Unsqueeze(%"x", %"axis")
iteration=0
--matched-- #1
Cast(%"cos"<?,?>) ~ Cast(%"_onx_unsqueeze0"<?,?>) [140454351111376-140454352338048]
len(stack)=0:[]
[GenericPattern.match] done. backward failed.
[GenericPattern.match] starts with %"_onx_cast02"<UNDEFINED,[]> ⬅️ ::Cast(%"_onx_sin0") {to=1}
[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>
[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})
[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast
--hint--: BACKWARD: different node types
%"cos"<?,?> ⬅️ ::Cos(%"output")
%"_onx_sin0"<?,?> ⬅️ ::Sin(%"_onx_concattraining0")
iteration=0
--matched-- #1
Cast(%"cos"<?,?>) ~ Cast(%"_onx_sin0"<?,?>) [140454351111376-140454352338768]
len(stack)=0:[]
[GenericPattern.match] done. backward failed.
[GenericPattern.match] starts with %"_onx_cast03"<UNDEFINED,[]> ⬅️ ::Cast(%"_onx_cos0") {to=1}
[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>
[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})
[GenericPattern._match_backward] match Cos((Value('_onx_concattraining0', type=None, shape=None, producer=anonymous_node:140454352338480, index=0),)) with Cos((Value('output', type=None, shape=None, producer=n4, index=0),)) (pattern)
[GenericPattern._match_backward] add 1 nodes
[GenericPattern.match] iteration=2 n_matched=2, n_stack=1, matched_types=Counter({'Cast': 1, 'Cos': 1})
[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast
--hint--: BACKWARD: different node types
%"output"<?,?>, %"length"<?,?> ⬅️ com.microsoft::ConcatTraining(%"transpose", %"transpose")
%"_onx_concattraining0"<?,?>, %"_onx_concattraining1"<?,?> ⬅️ com.microsoft::ConcatTrainingBad(%"_onx_transpose0", %"_onx_transpose0")
iteration=1
--matched-- #2
Cast(%"cos"<?,?>) ~ Cast(%"_onx_cos0"<?,?>) [140454351111376-140454351110224]
Cos(%"output"<?,?>) ~ Cos(%"_onx_concattraining0"<?,?>) [140454351111088-140455237148048]
len(stack)=0:[]
[GenericPattern.match] done. backward failed.
'RotaryEmbedding' is not a known op in 'com.microsoft'
rotary_apply_pattern: Already defined.
'ConcatTraining' is not a known op in 'com.microsoft'
rotary_match_pattern: Already defined.
0
日志记录了算法每一次拒绝模式的位置和原因。
可能输出的信息如下（此示例与上面的实际输出在类名和行号上略有不同）：
[OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
--hint--: BACKWARD: different node types
--pattern
ConcatTraining(transpose, transpose) -> (output, length)
-- model
ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
iteration=1
--marked-- #2
Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
len(stacked)=0:[]
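匹配在文件 `generic_pattern.py` 的第 673 行被拒绝（行号因版本而异，上面的实际运行中为第 460 行）。提示 `BACKWARD: different node types` 表示：在反向比较两个节点时，节点类型不一致（模式中为 `ConcatTraining`，模型中为 `ConcatTrainingBad`）。日志还表明，在此之前已经有两个节点成功匹配（`Cast` 和 `Cos`）。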