HardSwish 简介#
Hard Swish 是一种计算效率更高的 Swish 函数变体,适用于要求高效率的深度学习场景。Swish 激活函数由谷歌在 2017 年提出,其数学表达式为 \(f(x) = x \cdot \sigma (\beta x)\),其中 \(\beta\) 是一个可学习的参数。该函数因其能提升神经网络性能而备受关注,尤其在缓解梯度消失问题和提高模型训练效率方面表现优异。然而,标准的 Swish 函数包含计算成本较高的 sigmoid 函数,这限制了其在资源受限环境下的应用。
为了解决标准 Swish 函数在计算上的负担,Hard Swish 被开发出来。它采用分段线性函数来近似替代计算密集的 sigmoid 部分,从而显著降低了计算成本。具体来说,Hard Swish 的公式为:
其中 ReLU6 是对 ReLU 函数的改进,其将输出值限制在 \(0\) 到 \(6\) 之间。这种简化不仅保持了 Swish 函数的基本特性,还显著提高了计算速度,特别适用于移动设备和嵌入式系统等资源受限的环境。
从性能角度看,Hard Swish 尽管在形式上更加简单,但在实际使用中几乎能够与标准Swish媲美。实验结果显示,它在多种深度学习任务中的表现与标准 Swish 相差无几,但在计算效率上有明显优势。这意味着在需要快速响应或低功耗设备的场合,如移动电话或小型传感器,Hard Swish 能够提供更好的体验。
总结来说,Hard Swish 通过简化计算过程,在保留 Swish 函数优点的同时,提升了计算效率,使其更适合于资源受限的应用场景。这些特性使得 Hard Swish 成为深度学习领域内一个值得关注的激活函数变体。
真正的编程实现如下:
NumPy/TensorFlow/PyTorch 实现 HardSwish#
参考:tf.keras hard_swish & torch.nn.Hardswish & onnx__HardSwish & Searching for MobileNetV3
import numpy as np
def hard_swish(x):
return x * np.clip(x/6 + 1/2, 0, 1)
import plotly.graph_objects as go
x = np.linspace(-7, 7, 100)
y = hard_swish(x)
fig = go.Figure(data=go.Scatter(x=x, y=y, mode='lines+markers'))
fig.update_layout(
xaxis_title='x',
yaxis_title='hard_swish(x)'
)