LayerNorm frontend support#
Notes on importing LayerNorm (and RMSNorm) models into Relax through TVM's PyTorch frontends: from_exported_program(), the relax.frontend.nn modules, and from_fx().
from_exported_program()#
import torch
from torch.export import export
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend.torch import from_exported_program
To see how torch.export decomposes a normalization layer, first export a small RMSNorm module through the dynamo-based ONNX exporter with dump_exported_program=True:
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.RMSNorm(32)

    def forward(self, x):
        return self.ln(x)
input_shape = [1, 3, 32, 32]
model = M().eval()
input_data = torch.randn(input_shape)
torch.onnx.export(
model,
input_data,
".temp/test.onnx",
input_names=["x"],
dynamo=True,
opset_version=23,
do_constant_folding=True,
export_params=True,
dump_exported_program=True,
keep_initializers_as_inputs=False,
)
[torch.onnx] Obtain model graph for `M([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `M([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Dumping ExportedProgram because `dump_exported_program=True`...
[torch.onnx] ExportedProgram has been saved to 'onnx_export_2025-09-12_17-11-16-671771.pt2'.
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNXProgram(
model=
<
ir_version=10,
opset_imports={'': 23},
producer_name='pytorch',
producer_version='2.8.0+cu128',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[1,3,32,32]>
),
outputs=(
%"type_as"<FLOAT,[1,3,32,32]>
),
initializers=(
%"ln.weight"<FLOAT,[32]>{TorchTensor(...)},
%"val_0"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='val_0')},
%"val_3"<INT64,[1]>{Tensor<INT64,[1]>(array([3]), name='val_3')},
%"val_4"<FLOAT,[]>{Tensor<FLOAT,[]>(array(1.1920929e-07, dtype=float32), name='val_4')}
),
) {
0 | # node_pow_1
%"pow_1"<FLOAT,[1,3,32,32]> ⬅️ ::Pow(%"x", %"val_0"{2.0})
1 | # node_mean
%"mean"<FLOAT,[1,3,32,1]> ⬅️ ::ReduceMean(%"pow_1", %"val_3"{[3]}) {keepdims=True, noop_with_empty_axes=0}
2 | # node_add
%"add"<FLOAT,[1,3,32,1]> ⬅️ ::Add(%"mean", %"val_4"{1.1920928955078125e-07})
3 | # node_Sqrt_5
%"val_5"<FLOAT,[1,3,32,1]> ⬅️ ::Sqrt(%"add")
4 | # node_rsqrt
%"rsqrt"<FLOAT,[1,3,32,1]> ⬅️ ::Reciprocal(%"val_5")
5 | # node_mul
%"mul"<FLOAT,[1,3,32,32]> ⬅️ ::Mul(%"x", %"rsqrt")
6 | # node_type_as
%"type_as"<FLOAT,[1,3,32,32]> ⬅️ ::Mul(%"mul", %"ln.weight"{...})
return %"type_as"<FLOAT,[1,3,32,32]>
}
,
exported_program=
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_ln_weight: "f32[32]", x: "f32[1, 3, 32, 32]"):
# File: /media/pc/data/lxw/envs/anaconda3a/envs/py313/lib/python3.13/site-packages/torch/nn/modules/normalization.py:402 in forward, code: return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
pow_1: "f32[1, 3, 32, 32]" = torch.ops.aten.pow.Tensor_Scalar(x, 2)
mean: "f32[1, 3, 32, 1]" = torch.ops.aten.mean.dim(pow_1, [3], True); pow_1 = None
add: "f32[1, 3, 32, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07); mean = None
rsqrt: "f32[1, 3, 32, 1]" = torch.ops.aten.rsqrt.default(add); add = None
mul: "f32[1, 3, 32, 32]" = torch.ops.aten.mul.Tensor(x, rsqrt); rsqrt = None
mul_1: "f32[1, 3, 32, 32]" = torch.ops.aten.mul.Tensor(mul, p_ln_weight); mul = p_ln_weight = None
type_as: "f32[1, 3, 32, 32]" = torch.ops.aten.type_as.default(mul_1, x); mul_1 = x = None
return (type_as,)
Graph signature:
# inputs
p_ln_weight: PARAMETER target='ln.weight'
x: USER_INPUT
# outputs
type_as: USER_OUTPUT
Range constraints: {}
)
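The dump shows F.rms_norm decomposed into a Pow → ReduceMean → Add(eps) → Sqrt → Reciprocal → Mul → Mul(weight) chain. As a sanity check, the chain can be reproduced by hand and compared against the eager module; a minimal sketch (the eps is the float32 machine epsilon visible as val_4 above):
x = torch.randn(1, 3, 32, 32)
eps = torch.finfo(torch.float32).eps  # 1.1920928955078125e-07, matches val_4

mean = x.pow(2).mean(dim=-1, keepdim=True)  # Pow + ReduceMean over the last axis
rsqrt = torch.rsqrt(mean + eps)             # Add(eps) + Sqrt + Reciprocal
manual = x * rsqrt * model.ln.weight        # Mul + Mul(ln.weight)

torch.testing.assert_close(manual, model(x))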
from_exported_program() consumes the ExportedProgram produced by torch.export.export directly. Using a module that wraps torch.nn.LayerNorm:
class LayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.LayerNorm((10, 10))

    def forward(self, x):
        return self.ln(x)

example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
model = LayerNorm()
exported_program = export(model, args=example_args)
mod = from_exported_program(exported_program)
mod.show()
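From here the converted module can be lowered and executed; a minimal sketch, assuming an LLVM-enabled TVM build (the main function returns a tuple mirroring the ExportedProgram outputs):
# Legalize high-level Relax ops to TIR, then build for CPU
lowered = relax.transform.LegalizeOps()(mod)
ex = relax.build(lowered, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

tvm_out = vm["main"](tvm.nd.array(example_args[0].numpy()))[0]
torch_out = model(*example_args).detach().numpy()
tvm.testing.assert_allclose(tvm_out.numpy(), torch_out, rtol=1e-5, atol=1e-5)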
LayerNorm#
The relax.frontend.nn API also provides a built-in LayerNorm module that exports to an IRModule directly:
from tvm.relax.frontend.nn import modules, spec
mod = modules.LayerNorm(8)
tvm_mod, params = mod.export_tvm(
spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
)
tvm_mod.show()
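export_tvm returns the IRModule together with the module's named parameters as (name, Parameter) pairs; these become explicit inputs of the exported forward and must be bound at call time:
# Inspect the parameters the exported function expects
for name, param in params:
    print(name, param.shape, param.dtype)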
from_fx()#
import torch
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import from_fx
from_fx() converts a torch.fx.GraphModule obtained via symbolic tracing. First, a module that wraps torch.nn.LayerNorm:
class LayerNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.ln = torch.nn.LayerNorm((10, 10))
def forward(self, x):
return self.ln(x)
input_info = [([1, 3, 10, 10], "float32")]
torch_model = LayerNorm()
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
mod = from_fx(graph_model, input_info)
mod.show()
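The trace above bakes weight and bias into the IRModule as constants. To keep them as explicit inputs instead, convert with keep_params_as_input=True and split them out with the detach_params helper imported earlier; a minimal sketch:
with torch.no_grad():
    mod_with_params = from_fx(graph_model, input_info, keep_params_as_input=True)
# detach_params returns the cleaned module plus a dict mapping
# function name ("main") to its list of tvm.nd.NDArray parameters
mod_detached, params = detach_params(mod_with_params)
mod_detached.show()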
The conversion also handles torch.nn.functional.layer_norm called with explicit parameter tensors:
class LayerNorm(torch.nn.Module):
def __init__(self, shape):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(shape))
self.bias = torch.nn.Parameter(torch.zeros(shape))
def forward(self, x):
return torch.nn.functional.layer_norm(
x, self.weight.shape, self.weight, self.bias, 1e-5
)
torch_model = LayerNorm((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
mod = from_fx(graph_model, input_info)
mod.show()
weight and bias may be None, in which case no affine transform is applied:
class LayerNorm2(torch.nn.Module):
def __init__(self, shape):
super().__init__()
self.shape = shape
self.weight = None
self.bias = None
def forward(self, input):
return torch.nn.functional.layer_norm(input, self.shape, self.weight, self.bias, 1e-5)
torch_model = LayerNorm2((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
mod = from_fx(graph_model, input_info)
mod.show()
Finally, the normalized shape can be stored separately from the parameters:
class LayerNorm3(torch.nn.Module):
def __init__(self, shape):
super().__init__()
self.shape = shape
self.weight = torch.nn.Parameter(torch.ones(shape))
self.bias = torch.nn.Parameter(torch.zeros(shape))
def forward(self, x):
return torch.nn.functional.layer_norm(x, self.shape, self.weight, self.bias, 1e-5)
torch_model = LayerNorm3((10, 10))
graph_model = torch.fx.symbolic_trace(torch_model)
with torch.no_grad():
mod = from_fx(graph_model, input_info)
mod.show()
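To close the loop, run the converted module on the Relax VM and check it against PyTorch; a minimal sketch, again assuming an LLVM-enabled TVM build (from_fx emits a main that returns the tensor directly):
import numpy as np

inp = np.random.rand(1, 3, 10, 10).astype("float32")

lowered = relax.transform.LegalizeOps()(mod)
ex = relax.build(lowered, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
tvm_out = vm["main"](tvm.nd.array(inp))

with torch.no_grad():
    torch_out = torch_model(torch.from_numpy(inp))
tvm.testing.assert_allclose(tvm_out.numpy(), torch_out.numpy(), rtol=1e-5, atol=1e-5)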