# `LayerNorm` 前端支持

## {func}`~tvm.relax.frontend.torch.from_exported_program`

In [None]:
import torch
from torch.export import export
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend.torch import from_exported_program
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.RMSNorm((32))

    def forward(self, x):
        return self.ln(x)

In [None]:
from pathlib import Path

# Export M to ONNX with the dynamo-based exporter (opset 23).
# The `.temp/` output directory is not created anywhere in this notebook, so
# create it here instead of relying on it already existing.
onnx_path = ".temp/test.onnx"
Path(onnx_path).parent.mkdir(parents=True, exist_ok=True)

input_shape = [1, 3, 32, 32]
model = M().eval()
input_data = torch.randn(input_shape)
# NOTE: `dump_exported_program=True` additionally writes a timestamped
# `onnx_export_*.pt2` file into the working directory (see the log output).
torch.onnx.export(
    model,
    (input_data,),  # example positional inputs, passed as a tuple
    onnx_path,
    input_names=["x"],
    dynamo=True,
    opset_version=23,
    do_constant_folding=True,
    export_params=True,
    dump_exported_program=True,
    keep_initializers_as_inputs=False,
)

[torch.onnx] Obtain model graph for `M([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `M([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Dumping ExportedProgram because `dump_exported_program=True`...
[torch.onnx] ExportedProgram has been saved to 'onnx_export_2025-09-12_17-11-16-671771.pt2'.
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅


ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 23},
            producer_name='pytorch',
            producer_version='2.8.0+cu128',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,3,32,32]>
            ),
            outputs=(
                %"type_as"<FLOAT,[1,3,32,32]>
            ),
            initializers=(
                %"ln.weight"<FLOAT,[32]>{TorchTensor(...)},
                %"val_0"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='val_0')},
                %"val_3"<INT64,[1]>{Tensor<INT64,[1]>(array([3]), name='val_3')},
                %"val_4"<FLOAT,[]>{Tensor<FLOAT,[]>(array(1.1920929e-07, dtype=float32), name='val_4')}
            ),
        ) {
            0 |  # node_pow_1
                 %"pow_1"<FLOAT,[1,3,32,32]> ⬅️ ::Pow(%"x", %"val_0"{2.0})
            1 |  # node_mean
                 %"mean"

In [None]:
# Export a LayerNorm model with torch.export and import it into Relax.
# The model is defined here because the `LayerNorm` class in the from_fx
# section further down does not exist yet when cells run top-to-bottom.
class LayerNorm(torch.nn.Module):
    """Apply ``torch.nn.LayerNorm`` over the trailing (10, 10) dimensions."""

    def __init__(self):
        super().__init__()
        self.ln = torch.nn.LayerNorm((10, 10))

    def forward(self, x):
        return self.ln(x)


example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

model = LayerNorm().eval()
exported_program = export(model, args=example_args)
mod = from_exported_program(exported_program)
mod.show()

## {class}`~tvm.relax.frontend.nn.modules.LayerNorm`

In [None]:
from tvm.relax.frontend.nn import core, modules, spec
# Build Relax's own nn.LayerNorm and export its `forward` to an IRModule.
# The 8 matches the last dim of the (2, 4, 8) float32 input spec below.
mod = modules.LayerNorm(8)
tvm_mod, params = mod.export_tvm(
    spec={
        "forward": {
            "x": spec.Tensor((2, 4, 8), "float32"),
        },
    },
    debug=True,
)
tvm_mod.show()

## {func}`~tvm.relax.frontend.torch.from_fx`

In [None]:
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import from_fx

In [None]:
class LayerNorm(torch.nn.Module):
    """Thin wrapper applying ``torch.nn.LayerNorm`` over the last two dims (10, 10)."""

    def __init__(self):
        super().__init__()
        normalized_shape = (10, 10)
        self.ln = torch.nn.LayerNorm(normalized_shape)

    def forward(self, x):
        out = self.ln(x)
        return out

# A single float32 graph input of shape [1, 3, 10, 10]; the from_fx cells
# that follow reuse it.
input_info = [([1, 3, 10, 10], "float32")]

torch_model = LayerNorm()
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
mod.show()

In [None]:
class LayerNorm(torch.nn.Module):
    """Functional layer norm with learnable affine parameters.

    The normalized shape is taken from ``self.weight.shape``. The affine
    parameters start as an identity transform (weight = 1, bias = 0).
    """

    def __init__(self, shape):
        super().__init__()
        init_weight, init_bias = torch.ones(shape), torch.zeros(shape)
        self.weight = torch.nn.Parameter(init_weight)
        self.bias = torch.nn.Parameter(init_bias)

    def forward(self, x):
        # Argument order and positional form kept as-is so the traced FX
        # call_function node is unchanged.
        return torch.nn.functional.layer_norm(
            x,
            self.weight.shape,
            self.weight,
            self.bias,
            1e-5,
        )

# Trace the functional layer_norm variant (shape taken from weight) and
# convert it to Relax.
torch_model = LayerNorm((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
mod.show()

In [None]:
class LayerNorm2(torch.nn.Module):
    """Functional layer norm with no affine transform (weight and bias are None)."""

    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        # Plain attributes, not Parameters: the module has no learnable state.
        self.weight = self.bias = None

    def forward(self, input):
        normalized_shape = self.shape
        return torch.nn.functional.layer_norm(
            input, normalized_shape, self.weight, self.bias, 1e-5
        )

# Trace the affine-free variant (weight/bias are None) and convert it to Relax.
torch_model = LayerNorm2((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
mod.show()

In [None]:
class LayerNorm3(torch.nn.Module):
    """Functional layer norm whose normalized shape is stored explicitly.

    The affine parameters start as an identity transform (weight = 1,
    bias = 0). ``self.shape`` is passed to ``layer_norm`` rather than
    ``self.weight.shape``.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        init_weight = torch.ones(shape)
        init_bias = torch.zeros(shape)
        self.weight = torch.nn.Parameter(init_weight)
        self.bias = torch.nn.Parameter(init_bias)

    def forward(self, x):
        eps = 1e-5
        return torch.nn.functional.layer_norm(x, self.shape, self.weight, self.bias, eps)

# Trace the explicit-shape variant and convert it to Relax.
torch_model = LayerNorm3((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
mod.show()