ONNX Script optimizer

This example runs the ONNX Script optimizer over every Dynamo-exported model in the end-to-end test data set, saves each optimized model with its tensors stored as external data, and verifies the optimized outputs against the expected test outputs with ONNX Runtime.

from pathlib import Path
import tempfile
import numpy as np
import onnx
import onnxruntime

from onnxscript import optimizer
from onnxscript.utils import evaluation_utils

model_folder_path = Path(".").resolve().parents[3] / "tests/testdata/e2e_models"
# List all entries in the directory and filter for directories
model_names = [entry.name for entry in model_folder_path.iterdir() if entry.is_dir()]
for model_name in model_names:
    # Only the Dynamo-exported models are covered here
    if model_name == "torchscript_model":
        continue
    model_dir = model_folder_path / model_name / "dynamo"
    model_path = model_dir / f"{model_name}_dynamo.onnx"
    model = onnx.load(model_path)
    # Apply the ONNX Script optimizer (constant folding and general pattern rewrites)
    model = optimizer.optimize(model, onnx_shape_inference=False)
    with tempfile.TemporaryDirectory() as tmp_folder:
        tmp_folder = Path(tmp_folder)
        optimized_model_path = tmp_folder / f"{model_name}_opt.onnx"
        # Save with external data so large initializers are kept outside the .onnx protobuf
        onnx.save(
            model,
            optimized_model_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
        )

        session = onnxruntime.InferenceSession(
            optimized_model_path, providers=("CPUExecutionProvider",)
        )

        # Load the recorded test inputs and expected outputs for this model
        inputs, expected_outputs = evaluation_utils.load_test_data(
            model_dir, [i.name for i in model.graph.input]
        )

        input_names = [i.name for i in session.get_inputs()]
        assert set(input_names) == set(inputs.keys())

        outputs = session.run(None, inputs)
        # Free the session so the model file is no longer used
        del session

        for output, expected_output in zip(outputs, expected_outputs):
            np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3)

Running the example produces output like the following:

len(value_info): 7144
Applied 44 of general pattern rewrite rules.
len(value_info): 860
len(value_info): 4218
Applied 0 of general pattern rewrite rules.
len(value_info): 768
len(value_info): 23992
Applied 0 of general pattern rewrite rules.
len(value_info): 2069
len(value_info): 10935
Applied 0 of general pattern rewrite rules.
len(value_info): 1755
len(value_info): 5182
Applied 0 of general pattern rewrite rules.
len(value_info): 480
len(value_info): 2584
Applied 0 of general pattern rewrite rules.
len(value_info): 217
Skip storing constant folded nvalue self_4 due to large size 1050624.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_4 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_5 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_10 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_11 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_16 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_17 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_22 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_23 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_28 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_29 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_34 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 2097152.
Skip storing constant folded nvalue t_35 due to large size 2097152.
Skip storing constant folded nvalue result_1 due to large size 10240000.
Skip storing constant folded nvalue t_36 due to large size 10240000.
Skip storing constant folded nvalue t_4 due to large size 2097152.
Skip storing constant folded nvalue t_5 due to large size 2097152.
Skip storing constant folded nvalue t_10 due to large size 2097152.
Skip storing constant folded nvalue t_11 due to large size 2097152.
Skip storing constant folded nvalue t_16 due to large size 2097152.
Skip storing constant folded nvalue t_17 due to large size 2097152.
Skip storing constant folded nvalue t_22 due to large size 2097152.
Skip storing constant folded nvalue t_23 due to large size 2097152.
Skip storing constant folded nvalue t_28 due to large size 2097152.
Skip storing constant folded nvalue t_29 due to large size 2097152.
Skip storing constant folded nvalue t_34 due to large size 2097152.
Skip storing constant folded nvalue t_35 due to large size 2097152.
Skip storing constant folded nvalue torch_nn_modules_linear_Linear_lm_head_1_8_t_36 due to large size 10240000.
Skip storing constant folded nvalue result_1 due to large size 5120000.
Skip storing constant folded nvalue torch_nn_modules_linear_Linear_classifier_1_6_t due to large size 5120000.
Skip storing constant folded nvalue torch_nn_modules_linear_Linear_classifier_1_6_t due to large size 5120000.
Skip storing constant folded nvalue result_1 due to large size 2048000.
Skip storing constant folded nvalue torch_nn_modules_linear_Linear_fc_1_12_t due to large size 2048000.
Skip storing constant folded nvalue torch_nn_modules_linear_Linear_fc_1_12_t due to large size 2048000.
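
In this output, the "len(value_info)" lines are diagnostic counts of value_info entries printed while each model is optimized, the "Applied N of general pattern rewrite rules." lines report how many rewrite rules matched, and the "Skip storing constant folded nvalue ... due to large size ..." messages indicate constant-folded results that were not stored because of their large size.

If you only want to optimize a single model rather than the whole test-data set, the flow reduces to load, optimize, save. The sketch below is a minimal illustration; the model paths are hypothetical placeholders:

from pathlib import Path

import onnx
from onnxscript import optimizer

# Hypothetical paths; replace them with your own model files.
model_path = Path("my_model.onnx")
optimized_path = Path("my_model_opt.onnx")

model = onnx.load(model_path)
# Run constant folding and the general pattern rewrite rules
# (shape inference disabled, as in the example above).
model = optimizer.optimize(model, onnx_shape_inference=False)

# Store tensors as external data so large initializers do not
# push the model past the 2 GB protobuf limit.
onnx.save(
    model,
    optimized_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
)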