tvm_book.relay 源代码

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import (
    wildcard, is_op, 
    is_constant,
    DFPatternCallback,
    rewrite
)

[文档] class L2NormalizeONNX(DFPatternCallback): def __init__(self): super().__init__()
[文档] self.x = wildcard()
[文档] self.multiply = is_op("multiply")(self.x, self.x)
[文档] self.sum = is_op("sum")(self.multiply)
[文档] self.sqrt = is_op("sqrt")(self.sum)
[文档] self.broadcast_to = is_op("broadcast_to")(self.sqrt)
[文档] self.divide = is_op("divide")(self.x, self.broadcast_to)
[文档] self.pattern = self.divide
[文档] def callback(self, pre, post, node_map): x = node_map[self.x][0] sum = node_map[self.sum][0] ret = relay.nn.l2_normalize(x, eps=1e-12, axis=sum.attrs.axis) relay.transform.InferTypeLocal(ret) return ret
[文档] class L2NormalizeTorch(DFPatternCallback): def __init__(self): super().__init__()
[文档] self.x = wildcard()
[文档] self.abs = is_op("abs")(self.x)
[文档] self.n1 = is_constant()
[文档] self.power = is_op("power")(self.abs, self.n1)
[文档] self.sum = is_op("sum")(self.power)
[文档] self.n2 = is_constant()
[文档] self.sqrt = is_op("power")(self.sum, self.n2)
[文档] self.clip = is_op("clip")(self.sqrt)
[文档] self.broadcast_to_like = is_op("broadcast_to_like")(self.clip, self.x)
[文档] self.divide = is_op("divide")(self.x, self.broadcast_to_like)
[文档] self.pattern = self.divide
[文档] def callback(self, pre, post, node_map): x = node_map[self.x][0] n1 = node_map[self.n1][0] n2 = node_map[self.n2][0] clip = node_map[self.clip][0] dtype = relay.transform.InferTypeLocal(x).dtype if (n1.data.numpy() == 2) and (n2.data.numpy()==0.5) and clip.attrs.a_max==np.finfo(dtype).max: sum = node_map[self.sum][0] ret = relay.nn.l2_normalize(x, eps=clip.attrs.a_min, axis=sum.attrs.axis) relay.transform.InferTypeLocal(ret) return ret return post
@tvm.transform.module_pass(opt_level=1)
[文档] class FuseL2Normalize: """融合 torch.nn.functional.normalize"""
[文档] def transform_module(self, mod, ctx): # 融合 torch.nn.functional.normalize 的 TorchScript 版本 mod["main"] = rewrite(L2NormalizeTorch(), mod["main"]) # 融合 torch.nn.functional.normalize 的 ONNX 版本 mod["main"] = rewrite(L2NormalizeONNX(), mod["main"]) return mod