exporter
import pytest

import tvm
import tvm.testing
from tvm import relax, tir
from tvm.ir import assert_structural_equal
from tvm.relax.frontend import nn
from tvm.script import ir as I, relax as R, tir as T
"""The nn.modules.* may be exported from nn.Module to Relax"""
slm_mod = nn.modules.ReLU()
exported_mod, _ = slm_mod.export_tvm(
spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
debug=False,
)
exported_mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
relu: R.Tensor((3, 3), dtype="float32") = R.nn.relu(x)
gv: R.Tensor((3, 3), dtype="float32") = relu
R.output(gv)
return gv
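The exported IRModule can be compiled and run like any other Relax module. The following is a minimal sketch rather than part of the exporter itself; it assumes an "llvm" target is available and uses the standard "zero" Relax pipeline to legalize the high-level operators before building.
import numpy as np

lowered = relax.get_pipeline("zero")(exported_mod)  # legalize R.nn.relu, etc.
executable = relax.build(lowered, target="llvm")
vm = relax.VirtualMachine(executable, tvm.cpu())

x = tvm.nd.array(np.random.rand(3, 3).astype("float32"))
y = vm["forward"](x)  # runs the exported forward function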
A user can define their own nn.Module subclasses.
Like the built-in subclasses, these can be exported from nn.Module to Relax.
class Before(nn.Module):
def forward(self, x: R.Tensor):
return nn.op.relu(x)
slm_mod = Before()
exported_mod, _ = slm_mod.export_tvm(
spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
debug=False,
)
exported_mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
relu: R.Tensor((3, 3), dtype="float32") = R.nn.relu(x)
gv: R.Tensor((3, 3), dtype="float32") = relu
R.output(gv)
return gv
Passing debug=True threads an extra _io effect argument through each exported function and adds an _initialize_effect function that constructs the initial effect state.
slm_mod = nn.modules.ReLU()
exported_mod, _ = slm_mod.export_tvm(
spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}},
debug=True,
)
exported_mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
lv: R.Tuple(R.Object) = (_io,)
gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv
@R.function
def forward(x: R.Tensor((3, 3), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
relu: R.Tensor((3, 3), dtype="float32") = R.nn.relu(x)
gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = relu, (_io,)
R.output(gv1)
return gv1
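With debug=True, the caller is expected to obtain the effect state once from _initialize_effect and then thread it through every call to forward, which returns the result together with the updated effect state. A sketch of that calling convention, building the debug export the same way as in the sketch above:
import numpy as np

lowered = relax.get_pipeline("zero")(exported_mod)
vm = relax.VirtualMachine(relax.build(lowered, target="llvm"), tvm.cpu())

effects = vm["_initialize_effect"]()  # the (_io,) effect state
x = tvm.nd.array(np.random.rand(3, 3).astype("float32"))
out, effects = vm["forward"](x, *effects)  # forward returns (result, new effect state)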
An argument may have a dynamic shape, specified in the spec with a tir.Var.
slm_mod = nn.modules.ReLU()
exported_mod, _ = slm_mod.export_tvm(
spec={"forward": {"x": nn.spec.Tensor([tir.Var("batch_size", "int64"), 8], "float32")}},
debug=False,
)
exported_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor(("batch_size", 8), dtype="float32")) -> R.Tensor(("batch_size", 8), dtype="float32"):
batch_size = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
relu: R.Tensor((batch_size, 8), dtype="float32") = R.nn.relu(x)
gv: R.Tensor((batch_size, 8), dtype="float32") = relu
R.output(gv)
return gv
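Because batch_size is a symbolic variable, the compiled function accepts any leading dimension at runtime. A sketch, reusing the build steps from the earlier sketch:
import numpy as np

lowered = relax.get_pipeline("zero")(exported_mod)
vm = relax.VirtualMachine(relax.build(lowered, target="llvm"), tvm.cpu())

for batch_size in (1, 7, 32):
    x = tvm.nd.array(np.random.rand(batch_size, 8).astype("float32"))
    assert vm["forward"](x).shape == (batch_size, 8)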
A dynamic shape may be used in multiple functions.
class Before(nn.Module):
def forward_relu(self, x: nn.Tensor):
        return nn.op.relu(x)
def forward_silu(self, x: nn.Tensor):
        return nn.op.silu(x)
slm_mod = Before()
exported_mod, _ = slm_mod.export_tvm(
spec={
"forward_relu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")},
"forward_silu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")},
},
debug=False,
)
exported_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward_relu(x: R.Tensor(("batch_size", 8), dtype="float32")) -> R.Tensor(("batch_size", 8), dtype="float32"):
batch_size = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
relu: R.Tensor((batch_size, 8), dtype="float32") = R.nn.relu(x)
gv: R.Tensor((batch_size, 8), dtype="float32") = relu
R.output(gv)
return gv
@R.function
def forward_silu(x: R.Tensor(("batch_size", 8), dtype="float32")) -> R.Tensor(("batch_size", 8), dtype="float32"):
batch_size = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
silu: R.Tensor((batch_size, 8), dtype="float32") = R.nn.silu(x)
gv1: R.Tensor((batch_size, 8), dtype="float32") = silu
R.output(gv1)
return gv1
nn.Module instances may contain other nn.Module instances.
When exporting to a Relax IRModule, all nn.Parameter instances
within the nn.Module become Relax function parameters.
class LlamaMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = nn.Linear(
in_features=hidden_size,
out_features=intermediate_size,
dtype="float16",
bias=False,
)
self.up_proj = nn.Linear(
in_features=hidden_size,
out_features=intermediate_size,
dtype="float16",
bias=False,
)
self.down_proj = nn.Linear(
intermediate_size,
hidden_size,
dtype="float16",
bias=False,
)
def forward(self, x: nn.Tensor):
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(nn.op.silu(gate) * up)
hidden_size = 4096
intermediate_size = 11008
slm_mod = LlamaMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
exported_mod, _ = slm_mod.export_tvm(
spec={
"forward": {
"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16")
},
},
debug=False,
)
exported_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor(("batch_size", 4096), dtype="float16"), gate_proj_weight: R.Tensor((11008, 4096), dtype="float16"), up_proj_weight: R.Tensor((11008, 4096), dtype="float16"), down_proj_weight: R.Tensor((4096, 11008), dtype="float16")) -> R.Tensor(("batch_size", 4096), dtype="float16"):
batch_size = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((4096, 11008), dtype="float16") = R.permute_dims(gate_proj_weight, axes=None)
matmul: R.Tensor((batch_size, 11008), dtype="float16") = R.matmul(x, permute_dims, out_dtype="void")
permute_dims1: R.Tensor((4096, 11008), dtype="float16") = R.permute_dims(up_proj_weight, axes=None)
matmul1: R.Tensor((batch_size, 11008), dtype="float16") = R.matmul(x, permute_dims1, out_dtype="void")
silu: R.Tensor((batch_size, 11008), dtype="float16") = R.nn.silu(matmul)
mul: R.Tensor((batch_size, 11008), dtype="float16") = R.multiply(silu, matmul1)
permute_dims2: R.Tensor((11008, 4096), dtype="float16") = R.permute_dims(down_proj_weight, axes=None)
matmul2: R.Tensor((batch_size, 4096), dtype="float16") = R.matmul(mul, permute_dims2, out_dtype="void")
gv: R.Tensor((batch_size, 4096), dtype="float16") = matmul2
R.output(gv)
return gv
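The second value returned by export_tvm (discarded as _ above) pairs each nn.Parameter with its dotted name in the nn.Module hierarchy, in the same order as the Relax function parameters shown above. This is what keeps the exported function compatible with named weights in a checkpoint; a sketch of inspecting it:
_, named_params = slm_mod.export_tvm(
    spec={
        "forward": {
            "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16")
        },
    },
)
for name, param in named_params:
    # Expected names: gate_proj.weight, up_proj.weight, down_proj.weight
    print(name, param.shape, param.dtype)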
@pytest.mark.xfail(reason="Not yet supported. See revert https://github.com/apache/tvm/pull/16777")
def test_generate_parameters():
"""Weights may be expressions in terms of other parameters
Optimizations often require preprocessing of the model weights.
1. Declare the `nn.Module` members that contain the original model
weights. These are used to define the parameter names when
reading from a Pytorch or Safetensors file.
    2. Declare the `nn.Module` members whose `weight` field is defined
       in terms of the un-optimized weights. These `nn.Module` members
       do not generate any parameters in the Relax function.
3. Define the `forward` function in terms of the `nn.Module`
members for the updated weight tensors.
The exported Relax function accepts the original model parameters,
computes the pre-processed weights, and then performs computations
using the pre-processed weights.
In this example, the `LiftTransformParams` transform is applied
immediately, splitting the Relax function into a pre-processing
step and an execution step. In practice, this transform would be
applied much later in an optimization pipeline, to allow optimized
compute kernels to be recognized. For example, in some cases
`R.matmul(x, R.permute_dims(weight))` may be computed more
efficiently than `R.matmul(x, weight_transpose)`. For this
reason, we do *not* apply `LiftTransformParams` as part of the
export from `nn.Module` to Relax.
"""
class LlamaMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
            # The nn.Linear layers holding the original parameters are
            # present in the model definition, and are still found when
            # collecting a function's parameters.
self.gate_proj = nn.Linear(
in_features=hidden_size,
out_features=intermediate_size,
dtype="float16",
bias=False,
)
self.up_proj = nn.Linear(
in_features=hidden_size,
out_features=intermediate_size,
dtype="float16",
bias=False,
)
self.down_proj = nn.Linear(
intermediate_size,
hidden_size,
dtype="float16",
bias=False,
)
# At runtime, we'd like to have a single concatenated
# tensor containing both the gate and up projection
# weights. We also want to use it in the `forward`
# function as if it owned its own weights.
self.gate_up_proj = nn.Linear(
in_features=hidden_size,
out_features=intermediate_size,
dtype="float16",
bias=False,
)
            # The weight tensor of `gate_up_proj` can be overwritten with
            # an expression in terms of the original `gate_proj` and
            # `up_proj` tensors.
self.gate_up_proj.weight = nn.op.concat(
[self.gate_proj.weight, self.up_proj.weight], dim=0, name="gate_up_proj_weights"
)
def forward(self, x: nn.Tensor):
# Even though the `gate_up_proj` weights are defined as an
# expression rather than a `nn.Parameter`, the `forward`
# function does not require any special handling for it.
concat_gate_up = self.gate_up_proj(x)
gate, up = nn.op.split(concat_gate_up, 2, axis=-1)
return self.down_proj(nn.op.silu(gate) * up)
hidden_size = 4096
intermediate_size = 11008
slm_mod = LlamaMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
exported_mod, _ = slm_mod.export_tvm(
spec={
"forward": {
"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16")
},
},
debug=False,
)
@I.ir_module
class Expected:
@R.function
def forward(
x: R.Tensor(["batch_size", hidden_size], "float16"),
# The function's parameters are defined by the
# `nn.Parameter` instances, and still reference the
# original `gate_proj` and `up_proj` weights. This
# maintains compatibility with named model weights in a
# Pytorch or Safetensors file.
gate_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"),
up_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"),
down_proj_weights: R.Tensor([hidden_size, intermediate_size], "float16"),
):
R.func_attr({"num_input": 1})
batch_size = T.int64()
with R.dataflow():
# At this stage of compilation, the concatenation is
# written within the body of the function. This will
# later be extracted into a pre-processing step using
# `relax.transform.LiftTransformParams`.
gate_up_proj_weights: R.Tensor(
[intermediate_size * 2, hidden_size], "float16"
) = R.concat([gate_proj_weights, up_proj_weights], axis=0)
gate_up: R.Tensor([batch_size, intermediate_size * 2], "float16") = R.matmul(
x, R.permute_dims(gate_up_proj_weights)
)
gate_up_split = R.split(gate_up, 2, axis=-1)
gate = gate_up_split[0]
up = gate_up_split[1]
down: R.Tensor([batch_size, hidden_size], "float16") = R.matmul(
R.nn.silu(gate) * up, R.permute_dims(down_proj_weights)
)
R.output(down)
return down
assert_structural_equal(exported_mod, Expected)
@I.ir_module
class ExpectedAfterLift:
@R.function
def forward(
x: R.Tensor(["batch_size", hidden_size], "float16"),
# After `relax.transform.LiftTransformParams`, the
# `gate_proj` and `up_proj` weights have been concatenated
# together.
gate_up_proj_weights_transpose: R.Tensor(
[hidden_size, intermediate_size * 2], "float16"
),
down_proj_weights_transpose: R.Tensor([intermediate_size, hidden_size], "float16"),
):
R.func_attr({"num_input": 1})
batch_size = T.int64()
with R.dataflow():
gate_up: R.Tensor([batch_size, intermediate_size * 2], "float16") = R.matmul(
x, gate_up_proj_weights_transpose
)
gate_up_split = R.split(gate_up, 2, axis=-1)
gate = gate_up_split[0]
up = gate_up_split[1]
down: R.Tensor([batch_size, hidden_size], "float16") = R.matmul(
R.nn.silu(gate) * up, down_proj_weights_transpose
)
R.output(down)
return down
@R.function
def transform_params(
model_params: R.Tuple(
R.Tensor([intermediate_size, hidden_size], "float16"),
R.Tensor([intermediate_size, hidden_size], "float16"),
R.Tensor([hidden_size, intermediate_size], "float16"),
)
):
R.func_attr({"num_input": 0})
with R.dataflow():
gate_proj_weights: R.Tensor(
[intermediate_size, hidden_size], "float16"
) = model_params[0]
up_proj_weights: R.Tensor(
[intermediate_size, hidden_size], "float16"
) = model_params[1]
gate_up_proj_weights: R.Tensor(
[intermediate_size * 2, hidden_size], "float16"
) = R.concat([gate_proj_weights, up_proj_weights], axis=0)
gate_up_proj_weights_transpose: R.Tensor(
[hidden_size, intermediate_size * 2], "float16"
) = R.permute_dims(gate_up_proj_weights)
down_proj_weights: R.Tensor(
[hidden_size, intermediate_size], "float16"
) = model_params[2]
down_proj_weights_transpose: R.Tensor(
[intermediate_size, hidden_size], "float16"
) = R.permute_dims(down_proj_weights)
output = (gate_up_proj_weights_transpose, down_proj_weights_transpose)
R.output(output)
return output
lifted_mod = relax.transform.LiftTransformParams(shared_transform=True)(exported_mod)
assert_structural_equal(lifted_mod, ExpectedAfterLift)
def test_linear_dynamic_shape():
"""The weight and bias of nn.Linear have the same out_features
Even if dynamic, the weight/bias must be the same value.
"""
@R.function
def forward(
x: R.Tensor((1, 4), dtype="float32"),
_io: R.Object,
weight: R.Tensor(("n", 4), dtype="float32"),
bias: R.Tensor(("n",), dtype="float32"),
) -> R.Tuple(R.Tensor((1, "n"), dtype="float32"), R.Tuple(R.Object)):
n = T.int64()
R.func_attr({"num_input": 2})
with R.dataflow():
permute_dims: R.Tensor((4, n), dtype="float32") = R.permute_dims(weight, axes=None)
matmul: R.Tensor((1, n), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, n), dtype="float32") = R.add(matmul, bias)
gv1: R.Tuple(R.Tensor((1, n), dtype="float32"), R.Tuple(R.Object)) = add, (_io,)
R.output(gv1)
return gv1
mod = nn.modules.Linear(in_features=4, out_features="n", bias=True)
tvm_mod, _ = mod.export_tvm(
spec={"forward": {"x": nn.spec.Tensor((1, 4), "float32")}}, debug=True
)
assert_structural_equal(tvm_mod["forward"], forward, True)
@pytest.mark.parametrize(
"dynamic_type",
[
"same_python_string",
"different_python_string",
"same_tir_var",
"distinct_tir_vars_with_distinct_names",
pytest.param(
"distinct_tir_vars_with_same_name",
marks=pytest.mark.xfail(
reason="Not yet supported. See revert https://github.com/apache/tvm/pull/16777"
),
),
],
)
def test_duplicate_names(dynamic_type):
class Linear(nn.Module):
def __init__(self, input_size, output_size):
self.weights = nn.Parameter([output_size, input_size], dtype="float32")
def forward(self, state: nn.Tensor):
matmul_weights = nn.op.permute_dims(self.weights)
return nn.op.matmul(state, matmul_weights)
class Model(nn.Module):
def __init__(self, hidden_size, intermediate_size):
self.embedding = Linear(1024, hidden_size)
self.up = Linear(hidden_size, intermediate_size)
self.down = Linear(intermediate_size, hidden_size)
def forward(self, state: nn.Tensor):
state = self.embedding(state)
state = self.up(state)
state = nn.op.silu(state)
assert state.dtype == "float32"
state = self.down(state)
return state
if dynamic_type == "same_python_string":
# Python strings have value equality. Providing the same name
# for two different shape parameters results in a single
# symbolic variable.
args = ["hidden_size", "hidden_size"]
expected_num_symbolic_vars = 1
elif dynamic_type == "different_python_string":
# Providing two distinct variable names for the two different
# shape parameters results in two distinct symbolic variables.
args = ["hidden_size", "intermediate_size"]
expected_num_symbolic_vars = 2
elif dynamic_type == "same_tir_var":
# Symbolic variables can be specified as tir.Var instances.
# Providing the same variable for the two different shape
# parameters uses the symbolic variable in both locations.
dim = tir.Var("hidden_size", "int64")
args = [dim, dim]
expected_num_symbolic_vars = 1
elif dynamic_type == "distinct_tir_vars_with_distinct_names":
# Providing distinct TIR variables for the two different shape
# parameters uses each TIR variable in the specified location.
args = [tir.Var("hidden_size", "int64"), tir.Var("intermediate_size", "int64")]
expected_num_symbolic_vars = 2
elif dynamic_type == "distinct_tir_vars_with_same_name":
        # TIR variables have reference equality. Even if two different
# TIR variables have the same name, providing two distinct TIR
# variables still results in two distinct symbolic variables.
args = [tir.Var("hidden_size", "int64"), tir.Var("hidden_size", "int64")]
expected_num_symbolic_vars = 2
else:
raise ValueError(f"Unexpected dynamic_type: {dynamic_type}")
slm_mod = Model(*args)
exported_mod, _ = slm_mod.export_tvm(
spec={
"forward": {"state": nn.spec.Tensor(["batch_size", 1024], dtype="float32")},
},
debug=False,
)
def get_expected_with_intermediate_size():
@I.ir_module
class Expected:
@R.function
def forward(
state: R.Tensor(["batch_size", 1024], "float32"),
embedding_weights: R.Tensor(["hidden_size", 1024], "float32"),
up_weights: R.Tensor(["intermediate_size", "hidden_size"], "float32"),
down_weights: R.Tensor(["hidden_size", "intermediate_size"], "float32"),
):
R.func_attr({"num_input": 1})
batch_size = T.int64()
hidden_size = T.int64()
intermediate_size = T.int64()
with R.dataflow():
state: R.Tensor([batch_size, hidden_size], "float32") = R.matmul(
state, R.permute_dims(embedding_weights)
)
state: R.Tensor([batch_size, intermediate_size], "float32") = R.matmul(
state, R.permute_dims(up_weights)
)
state: R.Tensor([batch_size, intermediate_size], "float32") = R.nn.silu(state)
state: R.Tensor([batch_size, hidden_size], "float32") = R.matmul(
state, R.permute_dims(down_weights)
)
state = state
R.output(state)
return state
return Expected
def get_expected_without_intermediate_size():
@I.ir_module
class Expected:
@R.function
def forward(
state: R.Tensor(["batch_size", 1024], "float32"),
embedding_weights: R.Tensor(["hidden_size", 1024], "float32"),
up_weights: R.Tensor(["hidden_size", "hidden_size"], "float32"),
down_weights: R.Tensor(["hidden_size", "hidden_size"], "float32"),
):
R.func_attr({"num_input": 1})
batch_size = T.int64()
hidden_size = T.int64()
with R.dataflow():
state: R.Tensor([batch_size, hidden_size], "float32") = R.matmul(
state, R.permute_dims(embedding_weights)
)
state: R.Tensor([batch_size, hidden_size], "float32") = R.matmul(
state, R.permute_dims(up_weights)
)
state: R.Tensor([batch_size, hidden_size], "float32") = R.nn.silu(state)
state: R.Tensor([batch_size, hidden_size], "float32") = R.matmul(
state, R.permute_dims(down_weights)
)
state = state
R.output(state)
return state
return Expected
if expected_num_symbolic_vars == 1:
expected = get_expected_without_intermediate_size()
elif expected_num_symbolic_vars == 2:
expected = get_expected_with_intermediate_size()
else:
raise ValueError(f"Unexpected number of symbolic vars: {expected_num_symbolic_vars}")
assert_structural_equal(exported_mod["forward"], expected["forward"], True)