HardSigmoid Relay 实现

# Notebook setup cell (IPython). Statement order matters: the working directory
# is changed BEFORE the project-local imports below.
# NOTE(review): presumably the parent directory is what makes `set_env` importable — confirm.
%cd ..
# Project-local module imported for its side effects; it appears to print the
# project ROOT (see the captured output right after this cell) — TODO confirm what else it configures.
import set_env
from d2py.utils.file import mkdir
# Scratch directory; the ONNX files exported by later cells are written directly under it.
root_dir = ".temp"
# Creates `.temp/logs`. Assumes d2py's `mkdir` also creates missing parent
# directories, so `.temp/` exists for the exports below — confirm.
mkdir(f"{root_dir}/logs")
/media/pc/data/lxw/ai/tvm-book/doc/dev/ops
ROOT: /media/pc/data/lxw/ai/tvm-book

F.hardtanh(x + 3, 0., 6.) / 6 或者 F.relu6(x + 3) / 6

import torch
from torch.nn import functional as F
from torch import nn
from torch.onnx import OperatorExportTypes, utils

class M(nn.Module):
    """Compute hard-sigmoid twice, from two different clamp primitives.

    ``x1`` clamps ``x + 3`` to ``[0, 6]`` with ``F.hardtanh`` and divides by 6;
    ``x2`` does the same with ``F.relu6``.  Both equal ``F.hardsigmoid(x)``;
    keeping them as two separate expressions lets the notebook compare how each
    pattern is lowered by the PyTorch and ONNX frontends.

    The module deliberately holds no submodules or parameters: ``forward``
    only uses functional ops, so nothing else needs to be constructed.
    """

    def forward(self, x):
        """Return ``(x1, x2)``; both have the same shape as ``x``."""
        # The expressions are kept exactly as written so the traced / exported
        # graphs shown below stay the same.
        x1 = F.hardtanh(x + 3, 0., 6.) / 6.
        x2 = F.relu6(x + 3) / 6
        return x1, x2

# Fresh module in inference mode; `eval()` returns the module itself.
model = M().eval()

# Input description, reused by the TVM cells further down.
shape = (1, 3, 8, 8)
input_name = "data"
xx = torch.rand(shape, dtype=torch.float32, requires_grad=False)

# Export the eager (untraced) module to ONNX; verbose=True prints the graph.
output_name = "hard-sigmoid-v1"
utils.export(
    model,                                  # eager torch module
    xx,                                     # example input (use a tuple for several inputs)
    f"{root_dir}/{output_name}.onnx",       # destination (a file-like object also works)
    export_params=True,                     # embed parameter values in the file
    opset_version=9,                        # target ONNX opset
    do_constant_folding=True,               # fold constant sub-expressions during export
    input_names=[input_name],
    output_names=['output'],                # only the first of the two outputs is named
    keep_initializers_as_inputs=True,
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    # To make the batch axis dynamic, add:
    # dynamic_axes={"data": {0: "batch_size"}, "output": {0: "batch_size"}},
)
Exported graph: graph(%data : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu)):
  %/Constant_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={3}, onnx_name="/Constant"](), scope: __main__.M:: # /tmp/ipykernel_2933796/1400082185.py:13:0
  %/Add_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Add[onnx_name="/Add"](%data, %/Constant_output_0), scope: __main__.M:: # /tmp/ipykernel_2933796/1400082185.py:13:0
  %/Clip_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Clip[max=6., min=0., onnx_name="/Clip"](%/Add_output_0), scope: __main__.M:: # /media/pc/data/lxw/envs/anaconda3x/envs/py312/lib/python3.12/site-packages/torch/nn/functional.py:1551:0
  %/Constant_1_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={6}, onnx_name="/Constant_1"](), scope: __main__.M:: # /tmp/ipykernel_2933796/1400082185.py:13:0
  %output : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Div[onnx_name="/Div"](%/Clip_output_0, %/Constant_1_output_0), scope: __main__.M:: # /tmp/ipykernel_2933796/1400082185.py:13:0
  %/Clip_1_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Clip[max=6., min=0., onnx_name="/Clip_1"](%/Add_output_0), scope: __main__.M:: # /media/pc/data/lxw/envs/anaconda3x/envs/py312/lib/python3.12/site-packages/torch/nn/functional.py:1577:0
  %/Constant_2_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={6}, onnx_name="/Constant_2"](), scope: __main__.M:: # /tmp/ipykernel_2933796/1400082185.py:14:0
  %9 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Div[onnx_name="/Div_1"](%/Clip_1_output_0, %/Constant_2_output_0), scope: __main__.M:: # /tmp/ipykernel_2933796/1400082185.py:14:0
  return (%output, %9)
import numpy as np
import tvm
from tvm import relay
# Random integers 0..255 rescaled into [0, 1] as float32, wrapped as a torch tensor.
data_np = np.divide(np.random.randint(0, 256, shape), 255).astype(np.float32)
data_torch = torch.from_numpy(data_np)

# Trace a fresh eval-mode instance to TorchScript for the Relay PyTorch frontend.
model = M().eval()
scripted_model = torch.jit.trace(model, data_torch).eval()

# Bind the graph input by (name, shape) and print only the converted @main.
shape_list = [(input_name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
tvm.IRModule.from_expr(mod["main"]).show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* span=aten::add_0.data:0:0 */) {
  %0 = add(%data, 3f /* span=aten::add_0:0:0 */) /* span=aten::add_0:0:0 */;
  %1 = clip(%0, a_min=0f, a_max=6f) /* span=aten::hardtanh_0:0:0 */;
  %2 = add(%data, 3f /* span=aten::add_1:0:0 */) /* span=aten::add_1:0:0 */;
  %3 = clip(%2, a_min=0f, a_max=6f) /* span=aten::relu6_0:0:0 */;
  %4 = divide(%1, 6f /* span=aten::div_0:0:0 */) /* span=aten::div_0:0:0 */;
  %5 = divide(%3, 6f /* span=aten::div_1:0:0 */) /* span=aten::div_1:0:0 */;
  (%4, %5)
}
import tvm
from tvm import relay
import onnx
# Re-load the exported ONNX file and convert it with the Relay ONNX frontend,
# pinning the input shape by name; then print only the converted @main.
onnx_model = onnx.load(f"{root_dir}/{output_name}.onnx")
mod, params = relay.frontend.onnx.from_onnx(onnx_model, shape={input_name: shape})
tvm.IRModule.from_expr(mod["main"]).show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/Add.data:0:0 */) -> (Tensor[(1, 3, 8, 8), float32], Tensor[(1, 3, 8, 8), float32]) {
  %0 = add(%data, 3f /* ty=float32 span=/Constant:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Add:0:0 */;
  %1 = clip(%0, a_min=0f, a_max=6f) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Clip:0:0 */;
  %2 = clip(%0, a_min=0f, a_max=6f) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Clip_1:0:0 */;
  %3 = divide(%1, 6f /* ty=float32 span=/Constant_1:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Div:0:0 */;
  %4 = divide(%2, 6f /* ty=float32 span=/Constant_2:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Div_1:0:0 */;
  (%3, %4) /* ty=(Tensor[(1, 3, 8, 8), float32], Tensor[(1, 3, 8, 8), float32]) */
}

F.hardtanh(x*0.2+0.5, 0., 1.)（斜率 0.2）或者 F.hardtanh(x*(1/6)+0.5, 0., 1.)（斜率 1/6，与 nn.Hardsigmoid 相同）

import torch
from torch.nn import functional as F
from torch import nn
from torch.onnx import OperatorExportTypes, utils

class M(nn.Module):
    """Compare two hardtanh-based "hard sigmoid" variants with different slopes.

    ``x1 = clamp(0.2 * x + 0.5, 0, 1)`` uses slope 0.2, the default ``alpha``
    of the ONNX ``HardSigmoid`` operator.  ``x2 = clamp(x / 6 + 0.5, 0, 1)``
    uses slope 1/6, which is what ``nn.Hardsigmoid`` computes (it is exported
    as ``HardSigmoid`` with ``alpha=1/6``).  The two outputs are therefore NOT
    numerically equal; only ``x2`` matches ``F.hardsigmoid``.

    The module deliberately holds no submodules or parameters: ``forward``
    only uses functional ops.
    """

    def forward(self, x):
        """Return ``(x1, x2)``; both have the same shape as ``x``."""
        x1 = F.hardtanh(x * 0.2 + 0.5, 0., 1.)
        # Multiply by the constant 1/6 (not divide by 6) so the exported graph
        # contains a Mul node, mirroring x1.
        x2 = F.hardtanh(x * (1 / 6) + 0.5, 0., 1.)
        return x1, x2

# Fresh module in inference mode; `eval()` returns the module itself.
model = M().eval()

# Input description, reused by the TVM cells further down.
shape = (1, 3, 8, 8)
input_name = "data"
xx = torch.rand(shape, dtype=torch.float32, requires_grad=False)

# Export the eager (untraced) module to ONNX; verbose=True prints the graph.
output_name = "hard-sigmoid-v2"
utils.export(
    model,                                  # eager torch module
    xx,                                     # example input (use a tuple for several inputs)
    f"{root_dir}/{output_name}.onnx",       # destination (a file-like object also works)
    export_params=True,                     # embed parameter values in the file
    opset_version=9,                        # target ONNX opset
    do_constant_folding=True,               # fold constant sub-expressions during export
    input_names=[input_name],
    output_names=['output'],                # only the first of the two outputs is named
    keep_initializers_as_inputs=True,
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    # To make the batch axis dynamic, add:
    # dynamic_axes={"data": {0: "batch_size"}, "output": {0: "batch_size"}},
)
Exported graph: graph(%data : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu)):
  %/Constant_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0.2}, onnx_name="/Constant"](), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:13:0
  %/Mul_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Mul[onnx_name="/Mul"](%data, %/Constant_output_0), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:13:0
  %/Constant_1_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0.5}, onnx_name="/Constant_1"](), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:13:0
  %/Add_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Add[onnx_name="/Add"](%/Mul_output_0, %/Constant_1_output_0), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:13:0
  %output : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Clip[max=1., min=0., onnx_name="/Clip"](%/Add_output_0), scope: __main__.M:: # /media/pc/data/lxw/envs/anaconda3x/envs/py312/lib/python3.12/site-packages/torch/nn/functional.py:1551:0
  %/Constant_2_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0.166667}, onnx_name="/Constant_2"](), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:14:0
  %/Mul_1_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Mul[onnx_name="/Mul_1"](%data, %/Constant_2_output_0), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:14:0
  %/Constant_3_output_0 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={0.5}, onnx_name="/Constant_3"](), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:14:0
  %/Add_1_output_0 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Add[onnx_name="/Add_1"](%/Mul_1_output_0, %/Constant_3_output_0), scope: __main__.M:: # /tmp/ipykernel_2933796/4238986861.py:14:0
  %11 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::Clip[max=1., min=0., onnx_name="/Clip_1"](%/Add_1_output_0), scope: __main__.M:: # /media/pc/data/lxw/envs/anaconda3x/envs/py312/lib/python3.12/site-packages/torch/nn/functional.py:1551:0
  return (%output, %11)
import numpy as np
import tvm
from tvm import relay
# Random integers 0..255 rescaled into [0, 1] as float32, wrapped as a torch tensor.
data_np = np.divide(np.random.randint(0, 256, shape), 255).astype(np.float32)
data_torch = torch.from_numpy(data_np)

# Trace a fresh eval-mode instance to TorchScript for the Relay PyTorch frontend.
model = M().eval()
scripted_model = torch.jit.trace(model, data_torch).eval()

# Bind the graph input by (name, shape) and print only the converted @main.
shape_list = [(input_name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
tvm.IRModule.from_expr(mod["main"]).show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* span=aten::mul_0.data:0:0 */) {
  %0 = multiply(%data, 0.2f /* span=aten::mul_0:0:0 */) /* span=aten::mul_0:0:0 */;
  %1 = add(%0, 0.5f /* span=aten::add_0:0:0 */) /* span=aten::add_0:0:0 */;
  %2 = multiply(%data, 0.166667f /* span=aten::mul_1:0:0 */) /* span=aten::mul_1:0:0 */;
  %3 = add(%2, 0.5f /* span=aten::add_1:0:0 */) /* span=aten::add_1:0:0 */;
  %4 = clip(%1, a_min=0f, a_max=1f) /* span=aten::hardtanh_0:0:0 */;
  %5 = clip(%3, a_min=0f, a_max=1f) /* span=aten::hardtanh_1:0:0 */;
  (%4, %5)
}
import tvm
from tvm import relay
import onnx
# Re-load the exported ONNX file and convert it with the Relay ONNX frontend,
# pinning the input shape by name; then print only the converted @main.
onnx_model = onnx.load(f"{root_dir}/{output_name}.onnx")
mod, params = relay.frontend.onnx.from_onnx(onnx_model, shape={input_name: shape})
tvm.IRModule.from_expr(mod["main"]).show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/Mul.data:0:0 */) -> (Tensor[(1, 3, 8, 8), float32], Tensor[(1, 3, 8, 8), float32]) {
  %0 = multiply(%data, 0.2f /* ty=float32 span=/Constant:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Mul:0:0 */;
  %1 = add(%0, 0.5f /* ty=float32 span=/Constant_1:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Add:0:0 */;
  %2 = multiply(%data, 0.166667f /* ty=float32 span=/Constant_2:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Mul_1:0:0 */;
  %3 = add(%2, 0.5f /* ty=float32 span=/Constant_3:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Add_1:0:0 */;
  %4 = clip(%1, a_min=0f, a_max=1f) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Clip:0:0 */;
  %5 = clip(%3, a_min=0f, a_max=1f) /* ty=Tensor[(1, 3, 8, 8), float32] span=/Clip_1:0:0 */;
  (%4, %5) /* ty=(Tensor[(1, 3, 8, 8), float32], Tensor[(1, 3, 8, 8), float32]) */
}

nn.Hardsigmoid 或者 F.hardsigmoid

import torch
from torch.nn import functional as F
from torch import nn
from torch.onnx import OperatorExportTypes, utils

class M(nn.Module):
    """Hard sigmoid through the functional API and through the module API.

    ``F.hardsigmoid`` and ``nn.Hardsigmoid`` compute the same function.  The
    ONNX export maps both to ``HardSigmoid`` with ``alpha=1/6``, while TVM's
    PyTorch frontend expands each into an add / clip / divide sequence.
    """

    def __init__(self):
        super().__init__()
        # The only submodule; forward needs nothing else (no parameters).
        self.hard_sigmoid = nn.Hardsigmoid()

    def forward(self, x):
        """Return ``(x1, x2)``; both equal hardsigmoid(x) with ``x``'s shape."""
        x1 = F.hardsigmoid(x)
        x2 = self.hard_sigmoid(x)
        return x1, x2

# Fresh module in inference mode; `eval()` returns the module itself.
model = M().eval()

# Input description, reused by the TVM cells further down.
shape = (1, 3, 8, 8)
input_name = "data"
xx = torch.rand(shape, dtype=torch.float32, requires_grad=False)

# Export the eager (untraced) module to ONNX; verbose=True prints the graph.
output_name = "hard-sigmoid-v3"
utils.export(
    model,                                  # eager torch module
    xx,                                     # example input (use a tuple for several inputs)
    f"{root_dir}/{output_name}.onnx",       # destination (a file-like object also works)
    export_params=True,                     # embed parameter values in the file
    opset_version=9,                        # target ONNX opset
    do_constant_folding=True,               # fold constant sub-expressions during export
    input_names=[input_name],
    output_names=['output'],                # only the first of the two outputs is named
    keep_initializers_as_inputs=True,
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    # To make the batch axis dynamic, add:
    # dynamic_axes={"data": {0: "batch_size"}, "output": {0: "batch_size"}},
)
Exported graph: graph(%data : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu)):
  %output : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::HardSigmoid[alpha=0.16666666666666666, onnx_name="/HardSigmoid"](%data), scope: __main__.M:: # /media/pc/data/lxw/envs/anaconda3x/envs/py312/lib/python3.12/site-packages/torch/nn/functional.py:2032:0
  %3 : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu) = onnx::HardSigmoid[alpha=0.16666666666666666, onnx_name="/hard_sigmoid/HardSigmoid"](%data), scope: __main__.M::/torch.nn.modules.activation.Hardsigmoid::hard_sigmoid # /media/pc/data/lxw/envs/anaconda3x/envs/py312/lib/python3.12/site-packages/torch/nn/functional.py:2032:0
  return (%output, %3)
import numpy as np
import tvm
from tvm import relay
# Random integers 0..255 rescaled into [0, 1] as float32, wrapped as a torch tensor.
data_np = np.divide(np.random.randint(0, 256, shape), 255).astype(np.float32)
data_torch = torch.from_numpy(data_np)

# Trace a fresh eval-mode instance to TorchScript for the Relay PyTorch frontend.
model = M().eval()
scripted_model = torch.jit.trace(model, data_torch).eval()

# Bind the graph input by (name, shape) and print only the converted @main.
shape_list = [(input_name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
tvm.IRModule.from_expr(mod["main"]).show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* span=aten::hardsigmoid_0.data:0:0 */) {
  %0 = add(%data, 3f /* span=aten::hardsigmoid_0:0:0 */) /* span=aten::hardsigmoid_0:0:0 */;
  %1 = clip(%0, a_min=0f, a_max=6f) /* span=aten::hardsigmoid_0:0:0 */;
  %2 = add(%data, 3f /* span=aten::hardsigmoid_1:0:0 */) /* span=aten::hardsigmoid_1:0:0 */;
  %3 = clip(%2, a_min=0f, a_max=6f) /* span=aten::hardsigmoid_1:0:0 */;
  %4 = divide(%1, 6f /* span=aten::hardsigmoid_0:0:0 */) /* span=aten::hardsigmoid_0:0:0 */;
  %5 = divide(%3, 6f /* span=aten::hardsigmoid_1:0:0 */) /* span=aten::hardsigmoid_1:0:0 */;
  (%4, %5)
}
import tvm
from tvm import relay
import onnx
# Re-load the exported ONNX file and convert it with the Relay ONNX frontend,
# pinning the input shape by name; then print only the converted @main.
onnx_model = onnx.load(f"{root_dir}/{output_name}.onnx")
mod, params = relay.frontend.onnx.from_onnx(onnx_model, shape={input_name: shape})
tvm.IRModule.from_expr(mod["main"]).show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/HardSigmoid.data:0:0 */) -> (Tensor[(1, 3, 8, 8), float32], Tensor[(1, 3, 8, 8), float32]) {
  %0 = multiply(%data, 0.166667f /* ty=float32 span=/HardSigmoid:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/HardSigmoid:0:0 */;
  %1 = add(%0, 0.5f /* ty=float32 span=/HardSigmoid:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/HardSigmoid:0:0 */;
  %2 = multiply(%data, 0.166667f /* ty=float32 span=/hard_sigmoid/HardSigmoid:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/hard_sigmoid/HardSigmoid:0:0 */;
  %3 = add(%2, 0.5f /* ty=float32 span=/hard_sigmoid/HardSigmoid:0:0 */) /* ty=Tensor[(1, 3, 8, 8), float32] span=/hard_sigmoid/HardSigmoid:0:0 */;
  %4 = clip(%1, a_min=0f, a_max=1f) /* ty=Tensor[(1, 3, 8, 8), float32] span=/HardSigmoid:0:0 */;
  %5 = clip(%3, a_min=0f, a_max=1f) /* ty=Tensor[(1, 3, 8, 8), float32] span=/hard_sigmoid/HardSigmoid:0:0 */;
  (%4, %5) /* ty=(Tensor[(1, 3, 8, 8), float32], Tensor[(1, 3, 8, 8), float32]) */
}