自定义 relay 算子(python)#

from testing import viz_expr # 可视化 relay
from tvm.relay.testing import run_infer_type
from tvm.relay.dataflow_pattern import (
    wildcard, is_op,
    # FunctionPattern,
    DFPatternCallback,
    rewrite
)
import tvm
from tvm.ir.attrs import DictAttrs
from tvm import relay, te, topi
from tvm.relay.op import op as _op
from tvm.target import generic_func

@generic_func
def schedule_special_op(attrs, outs, target):
    """Create a default TE schedule for the first output tensor.

    Parameters
    ----------
    attrs
        Operator attributes; not used by this schedule.
    outs
        Either a single ``te.tensor.Tensor`` or a sequence of tensors.
    target
        Object entered as a context manager while the schedule is built.

    Returns
    -------
    The schedule produced by ``te.create_schedule`` for ``outs[0].op``.
    """
    with target:
        # Normalise a lone tensor into a one-element list.
        if isinstance(outs, te.tensor.Tensor):
            tensors = [outs]
        else:
            tensors = outs
        return te.create_schedule(tensors[0].op)

构建 reshape4d_softmax_reshape2d ONNX 算子#

from d2py.utils.file import mkdir
root_dir = ".temp"
mkdir(f"{root_dir}/logs")
import torch
from torch.nn import functional as F
from torch import nn
from torch.onnx import OperatorExportTypes, utils

class M(nn.Module):
    """Toy network: 1x1 conv -> global average pool -> view -> softmax -> view.

    The two ``view`` calls only reinterpret the tensor's shape; no axes are
    permuted.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=8,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            bias=False,
        )
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, x):
        """Return a 2-D tensor of shape ``(b, h * w * c)`` with softmax applied."""
        pooled = self.pool(self.conv(x))
        batch, channels, height, width = pooled.shape
        # Re-label the pooled (b, c, h, w) tensor as (b, h, w, c) without moving data.
        as_4d = pooled.view((batch, height, width, channels))
        probs = F.softmax(as_4d, dim=3)
        return probs.view((batch, height * width * channels))

# Instantiate the model and switch it to evaluation mode before exporting.
model = M()
model.eval()

shape = 1, 3, 8, 8
input_name = "data"
xx = torch.rand(*shape, dtype=torch.float32, requires_grad=False)
# model = torch.jit.trace(model, xx)
# Export the model
output_name = "test"
utils.export(
    model,               # the torch model
    xx,                         # model input (use a tuple for multiple inputs)
    f"{root_dir}/{output_name}.onnx",               # where to save the model (a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=17,          # ONNX version to export the model to
    do_constant_folding=True,  # whether to run constant folding as an optimization
    input_names = [input_name],    # names of the model inputs
    output_names = ['output'], # names of the model outputs
    keep_initializers_as_inputs=True,
    # export_modules_as_functions=True,
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    # dynamic_axes={'data' : {0 : 'batch_size'},    # variable-length axes
    #               'output' : {0 : 'batch_size'}}
)
Exported graph: graph(%data : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu),
      %conv.weight : Float(8, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=1, device=cpu)):
  %/conv/Conv_output_0 : Float(1, 8, 8, 8, strides=[512, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv/Conv"](%data, %conv.weight), scope: __main__.M::/torch.nn.modules.conv.Conv2d::conv # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/pool/GlobalAveragePool_output_0 : Float(1, 8, 1, 1, strides=[8, 1, 1, 1], requires_grad=1, device=cpu) = onnx::GlobalAveragePool[onnx_name="/pool/GlobalAveragePool"](%/conv/Conv_output_0), scope: __main__.M::/torch.nn.modules.pooling.AdaptiveAvgPool2d::pool # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/functional.py:1260:0
  %/Constant_output_0 : Long(4, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1  1  1  8 [ CPULongType{4} ], onnx_name="/Constant"](), scope: __main__.M:: # /tmp/ipykernel_3941401/1160623730.py:16:0
  %/Reshape_output_0 : Float(1, 1, 1, 8, strides=[8, 8, 8, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%/pool/GlobalAveragePool_output_0, %/Constant_output_0), scope: __main__.M:: # /tmp/ipykernel_3941401/1160623730.py:16:0
  %/Softmax_output_0 : Float(1, 1, 1, 8, strides=[8, 8, 8, 1], requires_grad=1, device=cpu) = onnx::Softmax[axis=3, onnx_name="/Softmax"](%/Reshape_output_0), scope: __main__.M:: # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/functional.py:1885:0
  %/Constant_1_output_0 : Long(2, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1  8 [ CPULongType{2} ], onnx_name="/Constant_1"](), scope: __main__.M:: # /tmp/ipykernel_3941401/1160623730.py:18:0
  %output : Float(1, 8, strides=[8, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape_1"](%/Softmax_output_0, %/Constant_1_output_0), scope: __main__.M:: # /tmp/ipykernel_3941401/1160623730.py:18:0
  return (%output)

前端导入:

import numpy as np
import onnx
import tvm
from tvm import relay
# Load the ONNX file written by the export step above.
onnx_model = onnx.load(f"{root_dir}/{output_name}.onnx")
# Convert to a Relay module, binding the input name to its static shape.
mod, params = relay.frontend.from_onnx(onnx_model, {input_name: shape}, freeze_params=True)
# Run type inference over the module, then print it.
mod = relay.transform.InferType()(mod)
mod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=/conv/Conv:0:0 */;
  %1 = nn.global_avg_pool2d(%0) /* ty=Tensor[(1, 8, 1, 1), float32] span=/pool/GlobalAveragePool:0:0 */;
  %2 = reshape(%1, newshape=[1, 1, 1, 8]) /* ty=Tensor[(1, 1, 1, 8), float32] span=/Reshape:0:0 */;
  %3 = nn.softmax(%2, axis=3) /* ty=Tensor[(1, 1, 1, 8), float32] span=/Softmax:0:0 */;
  reshape(%3, newshape=[1, 8]) /* ty=Tensor[(1, 8), float32] span=/Reshape_1:0:0 */
}

变换 reshape4d_softmax_reshape2d 为 softmax_transpose_reshape2d#

其实当 H=W=1 时,reshape4d_softmax_reshape2d 等价于如下结构(softmax → transpose → reshape):

class Reshape4dSoftmaxReshape2dRewrite(DFPatternCallback):
    """Rewrite ``reshape -> softmax(axis=3) -> reshape`` into
    ``softmax(axis=1) -> transpose(0, 2, 3, 1) -> reshape(1, -1)``.

    The two forms are only equivalent when the matched input is a static
    4-D ``(N, C, H, W)`` tensor with ``H == W == 1``, the first reshape
    produces ``(N, H, W, C)``, and the final result has shape
    ``(1, N * C * H * W)``.  ``callback`` checks these conditions and leaves
    the matched expression untouched when any of them does not hold.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        # First reshape: re-labels NCHW as NHWC (a pure layout change only when H = W = 1).
        self.reshape4d = is_op("reshape")(self.x)
        self.softmax = is_op("nn.softmax")(self.reshape4d)
        self.softmax_axis = self.softmax.has_attr({"axis": 3})
        self.reshape2d = is_op("reshape")(self.softmax_axis)
        self.pattern = self.reshape2d

    @staticmethod
    def _static_shape(expr):
        """Return the inferred shape of ``expr`` as a tuple of ints.

        Returns ``None`` when the inferred type is not a tensor type or when
        any dimension is not a constant integer.
        """
        inferred = relay.transform.InferTypeLocal(expr)
        if not isinstance(inferred, relay.TensorType):
            return None
        dims = list(inferred.shape)
        if not all(isinstance(dim, tvm.tir.IntImm) for dim in dims):
            return None
        return tuple(int(dim.value) for dim in dims)

    def callback(self, pre, post, node_map):
        """Return the rewritten expression, or ``post`` if the rewrite is not valid."""
        x = node_map[self.x][0]

        # Guard 1: the input must be a static 4-D tensor with H = W = 1.
        in_shape = self._static_shape(x)
        if in_shape is None or len(in_shape) != 4:
            return post
        n, c, h, w = in_shape
        if h != 1 or w != 1:
            return post

        # Guard 2: the first reshape must be the NCHW -> NHWC re-labelling.
        if self._static_shape(node_map[self.reshape4d][0]) != (n, h, w, c):
            return post

        # Guard 3: the replacement always ends in reshape(1, -1), so the
        # original result must already have that shape.
        if self._static_shape(pre) != (1, n * c * h * w):
            return post

        # The InferTypeLocal calls below are kept from the original code;
        # presumably they annotate each newly built node with its type.
        x = relay.nn.softmax(x, axis=1)
        relay.transform.InferTypeLocal(x)
        x = relay.transpose(x, (0, 2, 3, 1))
        relay.transform.InferTypeLocal(x)
        x = relay.reshape(x, (1, -1))
        relay.transform.InferTypeLocal(x)
        return x
mod["main"] = rewrite(Reshape4dSoftmaxReshape2dRewrite(), mod["main"])
mod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=/conv/Conv:0:0 */;
  %1 = nn.global_avg_pool2d(%0) /* ty=Tensor[(1, 8, 1, 1), float32] span=/pool/GlobalAveragePool:0:0 */;
  %2 = nn.softmax(%1, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %3 = transpose(%2, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), float32] */;
  reshape(%3, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] */
}

声明 softmax_transpose_reshape2d 算子类型关系#

def custom_softmax_transpose_reshape2d_rel(arg_types, attrs):
    """Type relation for the ``softmax_transpose_reshape2d`` operator.

    Parameters
    ----------
    arg_types
        Sequence holding exactly one input tensor type; its shape is
        indexed at positions 0 through 3.
    attrs
        Operator attributes; when truthy they must be a ``DictAttrs``.

    Returns
    -------
    A ``relay.TensorType`` of shape ``(d0, d1 * d2 * d3)`` with the input's
    dtype, where ``d0..d3`` are the first four input dimensions.
    """
    assert len(arg_types) == 1, "type relation arg number mismatch!"
    assert not attrs or isinstance(attrs, DictAttrs)
    data_type = arg_types[0]
    dims = data_type.shape
    leading = dims[0]
    flattened = dims[1] * dims[2] * dims[3]
    return relay.TensorType((leading, flattened), data_type.dtype)

注册 softmax_transpose_reshape2d 算子#

# Register a new Relay operator under `op_name` and describe its interface.
op_name = "softmax_transpose_reshape2d"
_op.register(op_name, r"code(cal softmax_transpose_reshape2d.)code")
# The operator takes exactly one input, named "data".
_op.get(op_name).set_num_inputs(1)
_op.get(op_name).add_argument("data", "Tensor", "The input data tensor.")
_op.get(op_name).set_attrs_type_key("DictAttrs")
# Attach the Python type relation so type inference can compute the output type.
_op.get(op_name).add_type_rel(op_name, custom_softmax_transpose_reshape2d_rel)
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.COMM_REDUCE)
_op.register_stateful(op_name, False) # stateless operator
_op.register_stateful??
Signature: _op.register_stateful(op_name, stateful, level=10)
Source:   
def register_stateful(op_name, stateful, level=10):
    """Register stateful flag for an op.

    Parameters
    ----------
    op_name : str
        The name of the op.

    stateful : bool
        The stateful flag.

    level : int
        The priority level
    """
    tvm.ir.register_op_attr(op_name, "TOpIsStateful", stateful, level)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py
Type:      function

备注

TOpIsStateful(IsStateful)属性用于标记算子是否有状态,即是否包含内部状态。

目前 TVM 注册的所有 primitive ops 都是纯粹的,这个属性是为了将来可能的兼容性原因而保留的。如果需要处理有状态的算子,可以通过添加额外的句柄参数来解决这个问题。

测试 softmax_transpose_reshape2d 算子类型推断#

def softmax_transpose_reshape2d(x):
    """Build a Relay call to the operator registered as ``op_name``.

    ``x`` is passed as the call's only argument.
    """
    fused_op = _op.get(op_name)
    return relay.Call(fused_op, [x])
# Build a small function that calls the custom operator on a (1, 2, 1, 1) input.
tp = relay.TensorType((1, 2, 1, 1), "float32")
x = relay.var("x", tp)
sb = relay.ScopeBuilder()
# Bind the custom operator's result, then add a float32 constant 1 to it.
t1 = sb.let("t1", softmax_transpose_reshape2d(x))
t2 = sb.let("t2", relay.add(t1, relay.const(1, dtype="float32")))
sb.ret(t2)
f = relay.Function([x], sb.get())
# Print the function wrapped in a module (no type inference has been run yet).
print(tvm.IRModule.from_expr(f))
def @main(%x: Tensor[(1, 2, 1, 1), float32]) {
  let %t1 = softmax_transpose_reshape2d(%x);
  let %t2 = add(%t1, 1f);
  %t2
}
f_type = relay.transform.InferTypeLocal(f)
f_type
I.FuncType([], [I.TensorType([1, 2, 1, 1], "float32")], I.TensorType([1, 2], "float32"))
print(relay.transform.InferType()(tvm.IRModule.from_expr(f)))
def @main(%x: Tensor[(1, 2, 1, 1), float32] /* ty=Tensor[(1, 2, 1, 1), float32] */) -> Tensor[(1, 2), float32] {
  let %t1: Tensor[(1, 2), float32] /* ty=Tensor[(1, 2), float32] */ = softmax_transpose_reshape2d(%x) /* ty=Tensor[(1, 2), float32] */;
  let %t2: Tensor[(1, 2), float32] /* ty=Tensor[(1, 2), float32] */ = add(%t1, 1f /* ty=float32 */) /* ty=Tensor[(1, 2), float32] */;
  %t2
}

融合 mod 部分算子为 softmax_transpose_reshape2d#

可视化表达式:

viz_expr(mod["main"])
../../_images/e28096630f81c3de820119925584cd8e830a1f4b5750ab6396ae250cb6308565.svg
class SoftmaxTransposeReshape2dFuse(DFPatternCallback):
    """Replace ``softmax(axis=1) -> transpose(0, 2, 3, 1) -> reshape(1, -1)``
    with a single ``softmax_transpose_reshape2d`` call on the matched input.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.softmax = is_op("nn.softmax")(self.x)
        self.softmax_axis = self.softmax.has_attr({"axis": 1})
        # Permutation that turns NCHW into NHWC.
        transpose_call = is_op("transpose")(self.softmax_axis)
        self.transpose = transpose_call.has_attr({"axes": (0, 2, 3, 1)})
        reshape_call = is_op("reshape")(self.transpose)
        self.reshape2d = reshape_call.has_attr({"newshape": (1, -1)})
        self.pattern = self.reshape2d

    def callback(self, pre, post, node_map):
        """Return the fused operator call applied to the matched input."""
        matched_input = node_map[self.x][0]
        # Kept from the original: infer the input type and touch its shape;
        # the value itself is not used.
        relay.transform.InferTypeLocal(matched_input).shape
        return softmax_transpose_reshape2d(matched_input)
mod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=/conv/Conv:0:0 */;
  %1 = nn.global_avg_pool2d(%0) /* ty=Tensor[(1, 8, 1, 1), float32] span=/pool/GlobalAveragePool:0:0 */;
  %2 = nn.softmax(%1, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %3 = transpose(%2, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), float32] */;
  reshape(%3, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] */
}
# Apply the fuse callback to the main function and wrap the result in a new module.
expr = mod["main"]
expr = rewrite(SoftmaxTransposeReshape2dFuse(), expr)
run_mod = tvm.IRModule.from_expr(expr)
# Run type inference over the rewritten module, then print it.
run_mod = relay.transform.InferType()(run_mod)
run_mod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=/conv/Conv:0:0 */;
  %1 = nn.global_avg_pool2d(%0) /* ty=Tensor[(1, 8, 1, 1), float32] span=/pool/GlobalAveragePool:0:0 */;
  softmax_transpose_reshape2d(%1) /* ty=Tensor[(1, 8), float32] */
}
viz_expr(run_mod["main"])
../../_images/bed6a1a85b2ff533047df7a983c45a0dce4f4ca28f7e736176f50800518a9d78.svg

定义 softmax_transpose_reshape2d 计算与调度#

def topi_softmax_transpose_reshape2d(x):
    """TOPI compute for ``softmax_transpose_reshape2d``.

    Applies softmax along axis 1 of a 4-D ``(n, c, h, w)`` tensor, permutes
    the result to ``(n, h, w, c)``, and flattens it to ``(n, h * w * c)``.

    Parameters
    ----------
    x
        Input tensor; its shape must unpack into four dimensions.

    Returns
    -------
    The 2-D result tensor of shape ``(n, h * w * c)``.
    """
    n, c, h, w = x.shape
    x = topi.nn.softmax(x, axis=1)
    # The transpose step was missing. Without it the flattened element order
    # differs from softmax -> transpose(0, 2, 3, 1) -> reshape whenever
    # h * w > 1; for h == w == 1 the result is unchanged.
    x = topi.transpose(x, (0, 2, 3, 1))
    x = topi.reshape(x, (n, h * w * c))
    return x
@_op.register_compute(op_name)
def output_softmax_transpose_reshape2d_compute(attrs, inputs, out_type):
    """Relay compute function for the operator registered as ``op_name``.

    Expects exactly one input tensor and returns a one-element list holding
    the result of ``topi_softmax_transpose_reshape2d`` on that tensor.
    ``attrs`` and ``out_type`` are not used.
    """
    assert len(inputs) == 1, "输入参数数量不为 1"
    return [topi_softmax_transpose_reshape2d(inputs[0])]
_op.register_schedule(op_name, schedule_special_op) # 定义调度
GenericFunc(0x92187f0)

验证 softmax_transpose_reshape2d 数值一致性#

# Random input shared by the PyTorch reference and both TVM builds.
data = np.random.normal(0, 1, size=shape).astype("float32")
torch_out = model(torch.from_numpy(data)).detach().numpy()


target = 'llvm'
dev = tvm.device(target, 0)


# Original model
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
output1 = module.get_output(0).numpy()

# Rewritten model (uses the fused softmax_transpose_reshape2d operator)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(run_mod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
output2 = module.get_output(0).numpy()
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
torch_out
array([[0.13445556, 0.12515777, 0.11348397, 0.12063695, 0.13582838,
        0.12997034, 0.11783069, 0.12263636]], dtype=float32)
output1
array([[0.13445556, 0.12515777, 0.11348397, 0.12063695, 0.13582838,
        0.12997034, 0.1178307 , 0.12263636]], dtype=float32)
output2
array([[0.13445556, 0.12515777, 0.11348397, 0.12063695, 0.13582838,
        0.12997034, 0.1178307 , 0.12263636]], dtype=float32)

打印带有 softmax_transpose_reshape2d 算子的量化问题#

# Quantize the rewritten module under an explicit qconfig, then print the result.
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        # calibrate_mode="kl_divergence", 
        weight_scale="max",
        # round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        skip_dense_layer=False,
    ):
        qmod = relay.quantize.quantize(run_mod, params)
qmod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = multiply(%data, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 8, 8), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1192595968, shift=-7) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %11 = cast(%10, dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %12 = nn.global_avg_pool2d(%11) /* ty=Tensor[(1, 8, 1, 1), int32] */;
  %13 = cast(%12, dtype="float32") /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %14 = multiply(%13, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  softmax_transpose_reshape2d(%14) /* ty=Tensor[(1, 8), float32] */
}