ONNX Script GELU

ONNX Script GELU

import math
from onnxscript import (
    script, opset18 as op, FLOAT
)

# sqrt(1/2), i.e. 1/sqrt(2): the factor applied to the input before erf in gelu().
M_SQRT1_2: float = math.sqrt(1 / 2)

# GELU in its exact, erf-based form: gelu(X) = X * Phi(X), where
# Phi(X) = 0.5 * (1 + erf(X / sqrt(2))) is the standard normal CDF.
# The function is handed to the onnxscript `script()` decorator, so the
# statements below are kept to plain arithmetic plus opset-18 operator calls.
# NOTE(review): FLOAT[...] presumably denotes a float32 tensor of any rank --
# confirm against the onnxscript type-annotation docs.
@script()
def gelu(X: FLOAT[...]):
    # M_SQRT1_2 is sqrt(0.5) == 1/sqrt(2), so this line evaluates Phi(X).
    phiX = 0.5 * (op.Erf(M_SQRT1_2 * X) + 1.0)
    # Weight the input by its CDF value.
    return X * phiX

import onnx

# Turn the scripted gelu function into a model proto, then write it to
# "gelu.onnx" (path is relative to the current working directory).
model = gelu.to_model_proto()
onnx.save_model(model, "gelu.onnx")