量化 QPartitionExpr#
下面以表达式 \(f(x, y) = (x + y)(x - y)\) 为例展示。
import set_env
import tvm
from tvm import relay
# Build the running example f(x, y) = (x + y) * (x - y), wrapped in exp,
# and lift it into an IRModule for the passes below.
x = relay.var("x", dtype="float32", shape=(10,))
y = relay.var("y", dtype="float32", shape=(10,))
sum_xy = x + y
diff_xy = x - y
product = sum_xy * diff_xy
result = relay.exp(product)
mod = tvm.IRModule.from_expr(result)
mod.show()
def @main(%x: Tensor[(10), float32], %y: Tensor[(10), float32]) {
%0 = add(%x, %y);
%1 = subtract(%x, %y);
%2 = multiply(%0, %1);
exp(%2)
}
自定义分区#
from tvm.relay.quantize._partition import (
register_partition_function,
QPartitionExpr,
partition_expr_check
)
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard, is_var
from tvm.relay import Call
from tvm.relay.function import Function, FunctionWithFields
@tvm.relay.transform.function_pass(opt_level=1)
class MergeGraphTransform:
    """Function-level Relay pass that wraps each visited function body in a
    ``QPartitionExpr`` and realizes it.

    Judging by the printed IR below, realization appends
    ``annotation.cast_hint`` / ``annotation.stop_fusion`` at the partition
    boundary. Every (possibly rewritten) function is also collected into
    ``self.nodes`` for later inspection.
    """
    def __init__(self):
        # Start with an empty record of visited functions.
        self.reset()
    def reset(self):
        # Clear functions accumulated by previous invocations of the pass.
        self.nodes = []
    def transform_function(self, func, mod, ctx):
        # Capture the pass instance so the inner mutator can append to
        # self.nodes (the inner class has its own `self`).
        obj = self
        class Replace(tvm.relay.ExprMutator):
            def visit_function(self, fn):
                # Rewrite parameters and body bottom-up first.
                new_params = [self.visit(x) for x in fn.params]
                new_body = self.visit(fn.body)
                # Mark the whole body as one quantization partition and
                # realize it (inserts the boundary annotations).
                new_body = QPartitionExpr(new_body).realize()
                # Reuse the original function object when nothing changed;
                # NOTE(review): realize() presumably always builds a fresh
                # expression, so the rebuild branch is expected in practice
                # — confirm against tvm.relay.quantize._partition.
                if new_params == list(fn.params) and new_body == fn.body:
                    new_fn = fn
                else:
                    new_fn = FunctionWithFields(fn, list(new_params), new_body)
                obj.nodes.append(new_fn)
                return new_fn
        return Replace().visit(func)
def make_add_subtract_multiply_pattern():
    """Return a dataflow pattern matching ``(x + y) * (x - y)``.

    The root is a ``multiply`` whose two operands are an ``add`` and a
    ``subtract`` applied to the same pair of variables.
    """
    lhs = is_var()
    rhs = is_var()
    added = is_op("add")(lhs, rhs)
    subtracted = is_op("subtract")(lhs, rhs)
    return is_op("multiply")(added, subtracted)
# Tag used to name the composite function produced by MergeComposite.
compiler_name = "ccompiler"

# (composite name, pattern) pairs consumed by relay.transform.MergeComposite.
pattern_table = [
    ("{}.add_subtract_multiply".format(compiler_name), make_add_subtract_multiply_pattern()),
]
# Pipeline: fuse the matched (x + y)(x - y) pattern into a composite
# function, then partition the graph around it. The commented passes show
# optional steps that are deliberately left out of this example.
merge_passes = tvm.transform.Sequential([
    relay.transform.MergeComposite(pattern_table),
    # relay.transform.AnnotateTarget([compiler_name]),
    relay.transform.PartitionGraph(),
    # relay.transform.ToANormalForm()
])
run_mod = merge_passes(mod)
run_mod.show()
def @main(%x: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %y: Tensor[(10), float32] /* ty=Tensor[(10), float32] */) -> Tensor[(10), float32] {
%2 = fn (%FunctionVar_0_0: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %FunctionVar_0_1: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="ccompiler.add_subtract_multiply") -> Tensor[(10), float32] {
%0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
%1 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
multiply(%0, %1) /* ty=Tensor[(10), float32] */
} /* ty=fn (Tensor[(10), float32], Tensor[(10), float32]) -> Tensor[(10), float32] */;
%3 = %2(%x, %y) /* ty=Tensor[(10), float32] */;
exp(%3) /* ty=Tensor[(10), float32] */
}
# Apply the custom partition pass: each function body gets wrapped as a
# QPartitionExpr and realized (cast_hint / stop_fusion appear in the IR).
transform = MergeGraphTransform()
run_mod = transform(run_mod)
run_mod.show()
def @main(%x: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %y: Tensor[(10), float32] /* ty=Tensor[(10), float32] */) -> Tensor[(10), float32] {
%4 = fn (%FunctionVar_0_0: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %FunctionVar_0_1: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="ccompiler.add_subtract_multiply") -> Tensor[(10), float32] {
%0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
%1 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
%2 = multiply(%0, %1) /* ty=Tensor[(10), float32] */;
%3 = annotation.cast_hint(%2, dtype="int8") /* ty=Tensor[(10), float32] */;
annotation.stop_fusion(%3) /* ty=Tensor[(10), float32] */
} /* ty=fn (Tensor[(10), float32], Tensor[(10), float32]) -> Tensor[(10), float32] */;
%5 = %4(%x, %y) /* ty=Tensor[(10), float32] */;
%6 = exp(%5) /* ty=Tensor[(10), float32] */;
%7 = annotation.cast_hint(%6, dtype="int8") /* ty=Tensor[(10), float32] */;
annotation.stop_fusion(%7) /* ty=Tensor[(10), float32] */
}
从数学角度来看,上述问题可以化简为 \(f(x, y) = x^2 - y^2\):
from tvm.relay.dataflow_pattern import DFPatternCallback
class MergeGraphCallback(DFPatternCallback):
    # Rewrite the matched pattern (x + y)(x - y) into x*x - y*y.
    # (The earlier comment mentioning batch_norm was a copy-paste leftover
    # from the TVM pattern-rewrite tutorial and did not describe this code.)
    def __init__(self, require_type=False):
        super().__init__(require_type)
        # Pattern to match: multiply(add(x, y), subtract(x, y)).
        self.pattern = make_add_subtract_multiply_pattern()
    def callback(self, pre, post, node_map):
        # post is multiply(add(x, y), subtract(x, y)), so post.args[0] is
        # the add node and its .args[0]/.args[1] are x and y respectively.
        x = post.args[0].args[0] * post.args[0].args[0]
        y = post.args[0].args[1] * post.args[0].args[1]
        # (x + y)(x - y) == x^2 - y^2
        return x - y
from tvm.relay.dataflow_pattern import rewrite
# Undo the composite-function partitioning (DefuseOps) so the raw
# add/subtract/multiply ops are visible again, then rewrite @main with the
# callback; the rewritten function is this cell's displayed value.
rewrite(MergeGraphCallback(), relay.transform.DefuseOps()(run_mod)["main"])
fn (%x: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %y: Tensor[(10), float32] /* ty=Tensor[(10), float32] */) -> Tensor[(10), float32] {
%0 = multiply(%x, %x);
%1 = multiply(%y, %y);
%2 = subtract(%0, %1);
%3 = annotation.cast_hint(%2, dtype="int8") /* ty=Tensor[(10), float32] */;
%4 = annotation.stop_fusion(%3) /* ty=Tensor[(10), float32] */;
%5 = exp(%4) /* ty=Tensor[(10), float32] */;
%6 = annotation.cast_hint(%5, dtype="int8") /* ty=Tensor[(10), float32] */;
annotation.stop_fusion(%6) /* ty=Tensor[(10), float32] */
} /* ty=fn (Tensor[(10), float32], Tensor[(10), float32]) -> Tensor[(10), float32] */