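# Parameter packing: export an `nn.Module` with `"param_mode": "packed"` so
# that all of its weights are passed to `forward` as one tuple argument.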

import tvm
from tvm.relax.frontend import nn
from tvm.script import ir as I
from tvm.script import relax as R

def _iter_binding_names(mod):
    """Helper function to compare the names of relax variables"""
    for block in mod["forward"].body.blocks:
        for binding in block.bindings:
            yield binding.var.name_hint
class TestModule(nn.Module):
    """Two bias-free linear layers applied to the same input, then summed.

    Args:
        in_features: Size of the last dimension of the input tensor.
        out_features: Size of the last dimension of each layer's output.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Store the feature sizes so callers can read them back, e.g. when
        # building an input spec for export.
        self.in_features, self.out_features = in_features, out_features
        # Declaration order matters: linear_1 is registered before linear_2.
        self.linear_1 = nn.Linear(in_features, out_features, bias=False)
        self.linear_2 = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: nn.Tensor):
        """Return ``linear_1(x) + linear_2(x)``."""
        return self.linear_1(x) + self.linear_2(x)

model = TestModule(10, 20)

# Export spec for `forward`: the input `x` is a single float32 row of
# `in_features` values. The "$" entry holds export options (param_mode /
# effect_mode) rather than a tensor argument.
forward_spec = {
    "x": nn.spec.Tensor([1, model.in_features], "float32"),
    "$": {"param_mode": "packed", "effect_mode": "none"},
}
mod, _ = model.export_tvm(spec={"forward": forward_spec})
mod.show()
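# The TVMScript below appears to be the printed output of `mod.show()`. That
# output begins with these imports, which already appear at the top of this file:
# from tvm.script import ir as I
# from tvm.script import relax as R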

# Reference TVMScript for the exported module. It appears to be the output of
# `mod.show()` pasted back in (some variable names may have been hand-edited).
# NOTE(review): verify it still matches what `export_tvm` produces.
# Only `#` comments are used here: the TVMScript parser re-reads this source,
# so docstrings or extra statements could change or break the parsed IR.
@I.ir_module
class Module:
    # forward(x, packed_params): x is a (1, 10) float32 input and packed_params
    # is a single tuple carrying both (20, 10) float32 weights. Returns (1, 20).
    @R.function
    def forward(x: R.Tensor((1, 10), dtype="float32"), packed_params: R.Tuple(R.Tensor((20, 10), dtype="float32"), R.Tensor((20, 10), dtype="float32"))) -> R.Tensor((1, 20), dtype="float32"):
        # Presumably marks only the first parameter (x) as a runtime input, with
        # the remaining parameters treated as weights -- confirm in Relax docs.
        R.func_attr({"num_input": 1})
        with R.dataflow():
            # Unpack the two weights from the packed tuple by position.
            linear_1_weight: R.Tensor((20, 10), dtype="float32") = packed_params[0]
            linear_2_weight: R.Tensor((20, 10), dtype="float32") = packed_params[1]
            # Each branch computes x @ W.T: permute_dims turns (20, 10) into
            # (10, 20), and the matmul with x (1, 10) yields (1, 20).
            # out_dtype="void" presumably means "infer from inputs" -- TODO confirm.
            matmul_1_weight: R.Tensor((10, 20), dtype="float32") = R.permute_dims(linear_1_weight, axes=None)
            matmul: R.Tensor((1, 20), dtype="float32") = R.matmul(x, matmul_1_weight, out_dtype="void")
            matmul_2_weight: R.Tensor((10, 20), dtype="float32") = R.permute_dims(linear_2_weight, axes=None)
            matmul1: R.Tensor((1, 20), dtype="float32") = R.matmul(x, matmul_2_weight, out_dtype="void")
            # Sum the two branch outputs; gv is emitted as the dataflow output.
            add: R.Tensor((1, 20), dtype="float32") = R.add(matmul, matmul1)
            gv: R.Tensor((1, 20), dtype="float32") = add
            R.output(gv)
        return gv