LayerNorm Frontend Support#

This page walks through LayerNorm support in TVM's PyTorch frontends: from_exported_program(), the relax.frontend.nn module library, and from_fx().

from_exported_program()#

import torch
from torch.export import export
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend.torch import from_exported_program
The first example pushes a small module through torch.onnx.export with dump_exported_program=True, which also prints the decomposed torch.export graph. Note that it uses torch.nn.RMSNorm, the root-mean-square variant of LayerNorm (no mean subtraction, no bias):

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.RMSNorm(32)

    def forward(self, x):
        return self.ln(x)
input_shape = [1, 3, 32, 32]
model = M().eval()
input_data = torch.randn(input_shape)
torch.onnx.export(
    model,
    input_data,
    ".temp/test.onnx",
    input_names=["x"],
    dynamo=True,
    opset_version=23,
    do_constant_folding=True,
    export_params=True,
    dump_exported_program=True,  # also dump the decomposed ExportedProgram to disk
    keep_initializers_as_inputs=False,
)
[torch.onnx] Obtain model graph for `M([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `M([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Dumping ExportedProgram because `dump_exported_program=True`...
[torch.onnx] ExportedProgram has been saved to 'onnx_export_2025-09-12_17-11-16-671771.pt2'.
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 23},
            producer_name='pytorch',
            producer_version='2.8.0+cu128',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,3,32,32]>
            ),
            outputs=(
                %"type_as"<FLOAT,[1,3,32,32]>
            ),
            initializers=(
                %"ln.weight"<FLOAT,[32]>{TorchTensor(...)},
                %"val_0"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='val_0')},
                %"val_3"<INT64,[1]>{Tensor<INT64,[1]>(array([3]), name='val_3')},
                %"val_4"<FLOAT,[]>{Tensor<FLOAT,[]>(array(1.1920929e-07, dtype=float32), name='val_4')}
            ),
        ) {
            0 |  # node_pow_1
                 %"pow_1"<FLOAT,[1,3,32,32]> ⬅️ ::Pow(%"x", %"val_0"{2.0})
            1 |  # node_mean
                 %"mean"<FLOAT,[1,3,32,1]> ⬅️ ::ReduceMean(%"pow_1", %"val_3"{[3]}) {keepdims=True, noop_with_empty_axes=0}
            2 |  # node_add
                 %"add"<FLOAT,[1,3,32,1]> ⬅️ ::Add(%"mean", %"val_4"{1.1920928955078125e-07})
            3 |  # node_Sqrt_5
                 %"val_5"<FLOAT,[1,3,32,1]> ⬅️ ::Sqrt(%"add")
            4 |  # node_rsqrt
                 %"rsqrt"<FLOAT,[1,3,32,1]> ⬅️ ::Reciprocal(%"val_5")
            5 |  # node_mul
                 %"mul"<FLOAT,[1,3,32,32]> ⬅️ ::Mul(%"x", %"rsqrt")
            6 |  # node_type_as
                 %"type_as"<FLOAT,[1,3,32,32]> ⬅️ ::Mul(%"mul", %"ln.weight"{...})
            return %"type_as"<FLOAT,[1,3,32,32]>
        }


    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, p_ln_weight: "f32[32]", x: "f32[1, 3, 32, 32]"):
                     # File: /media/pc/data/lxw/envs/anaconda3a/envs/py313/lib/python3.13/site-packages/torch/nn/modules/normalization.py:402 in forward, code: return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
                    pow_1: "f32[1, 3, 32, 32]" = torch.ops.aten.pow.Tensor_Scalar(x, 2)
                    mean: "f32[1, 3, 32, 1]" = torch.ops.aten.mean.dim(pow_1, [3], True);  pow_1 = None
                    add: "f32[1, 3, 32, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07);  mean = None
                    rsqrt: "f32[1, 3, 32, 1]" = torch.ops.aten.rsqrt.default(add);  add = None
                    mul: "f32[1, 3, 32, 32]" = torch.ops.aten.mul.Tensor(x, rsqrt);  rsqrt = None
                    mul_1: "f32[1, 3, 32, 32]" = torch.ops.aten.mul.Tensor(mul, p_ln_weight);  mul = p_ln_weight = None
                    type_as: "f32[1, 3, 32, 32]" = torch.ops.aten.type_as.default(mul_1, x);  mul_1 = x = None
                    return (type_as,)
            
        Graph signature: 
            # inputs
            p_ln_weight: PARAMETER target='ln.weight'
            x: USER_INPUT
    
            # outputs
            type_as: USER_OUTPUT
    
        Range constraints: {}

)
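The dumped graph is the expected RMSNorm decomposition: Pow, ReduceMean over the last axis, Add of eps (here the float32 machine epsilon, since eps was left at its default), Sqrt, Reciprocal, and two Muls. A minimal numpy transcription of those nodes, for reference only (not TVM or PyTorch API):

import numpy as np

def rms_norm_ref(x, weight, eps=1.1920928955078125e-07):
    # Mirrors the nodes above: Pow -> ReduceMean(keepdims) -> Add(eps)
    # -> Sqrt -> Reciprocal -> Mul -> Mul
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * weight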
from_exported_program() converts the ExportedProgram produced by torch.export into a Relax IRModule:

class LayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.LayerNorm((10, 10))

    def forward(self, x):
        return self.ln(x)

example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
model = LayerNorm()
exported_program = export(model, args=example_args)
mod = from_exported_program(exported_program)
mod.show()
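The resulting IRModule can be compiled and executed with the Relax VM. A minimal sketch, with a few assumptions: "llvm" and tvm.cpu() are illustrative choices, and the explicit LegalizeOps pass is redundant on newer builds where relax.build already legalizes in its default pipeline:

import numpy as np

mod = relax.transform.LegalizeOps()(mod)  # no-op if relax.build legalizes anyway
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

data = np.random.randn(1, 3, 10, 10).astype("float32")
# The exported program returns a one-element tuple, so index into the result.
tvm_out = vm["main"](tvm.nd.array(data))
np.testing.assert_allclose(
    tvm_out[0].numpy(),
    model(torch.from_numpy(data)).detach().numpy(),
    rtol=1e-5, atol=1e-5,
)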

relax.frontend.nn.LayerNorm#

The relax.frontend.nn module library also ships a ready-made LayerNorm that exports to TVM directly through a spec:

from tvm.relax.frontend.nn import modules, spec
mod = modules.LayerNorm(8)
tvm_mod, params = mod.export_tvm(
    spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
)
tvm_mod.show()
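For reference, every LayerNorm variant on this page computes the same arithmetic: normalize over the trailing normalized_shape axes, then scale and shift. An illustrative numpy re-derivation (not a TVM API):

import numpy as np

def layer_norm_ref(x, gamma, beta, normalized_shape, eps=1e-5):
    # Normalize over the trailing len(normalized_shape) axes.
    axes = tuple(range(x.ndim - len(normalized_shape), x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # biased variance, as in torch
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

x = np.random.randn(2, 4, 8).astype("float32")
y = layer_norm_ref(x, np.ones(8, "float32"), np.zeros(8, "float32"), (8,))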

from_fx()#

from_fx() converts a torch.fx GraphModule obtained from symbolic_trace. The module form first:

import torch
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import from_fx
class LayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.LayerNorm((10, 10))

    def forward(self, x):
        return self.ln(x)

input_info = [([1, 3, 10, 10], "float32")]

torch_model = LayerNorm()
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
    mod.show()
The functional form torch.nn.functional.layer_norm is supported as well, here with explicit weight and bias parameters:

class LayerNorm(torch.nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(shape))
        self.bias = torch.nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        return torch.nn.functional.layer_norm(
            x, self.weight.shape, self.weight, self.bias, 1e-5
        )

torch_model = LayerNorm((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
    mod.show()
The weight and bias arguments may also be passed as None:

class LayerNorm2(torch.nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        self.weight = None
        self.bias = None

    def forward(self, input):
        return torch.nn.functional.layer_norm(input, self.shape, self.weight, self.bias, 1e-5)

torch_model = LayerNorm2((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
    mod.show()
And the same functional call once more, with the normalized shape stored on the module instead of derived from the weight:

class LayerNorm3(torch.nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        self.weight = torch.nn.Parameter(torch.ones(shape))
        self.bias = torch.nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        return torch.nn.functional.layer_norm(x, self.shape, self.weight, self.bias, 1e-5)

torch_model = LayerNorm3((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
    mod.show()
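When from_fx is called with keep_params_as_input=True, detach_params can split the weights out of the module, and the detached tensors are passed to the VM alongside the input. A minimal end-to-end sketch under the same illustrative assumptions as before (llvm target, explicit legalization):

import numpy as np

graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info, keep_params_as_input=True)
mod, params = detach_params(mod)  # params: {"main": [weight, bias]}

mod = relax.transform.LegalizeOps()(mod)
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

data = np.random.randn(1, 3, 10, 10).astype("float32")
tvm_out = vm["main"](tvm.nd.array(data), *params["main"])

# Cross-check against PyTorch eager mode.
np.testing.assert_allclose(
    tvm_out.numpy(),
    torch_model(torch.from_numpy(data)).detach().numpy(),
    rtol=1e-5, atol=1e-5,
)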