# Subroutines

import tvm
import tvm.testing
from tvm import relax
from tvm.ir import assert_structural_equal
from tvm.relax.frontend import nn
from tvm.script import ir as I
from tvm.script import relax as R
class Activation(nn.Module):
    """SiLU activation wrapped as an ``nn.Module``.

    With ``define_subroutine`` set, ``export_tvm`` emits this module's
    ``forward`` as its own private Relax function instead of inlining it.
    """

    # Ask the exporter to lower this module as a separate subroutine.
    define_subroutine = True

    def forward(self, state: relax.Expr) -> relax.Var:
        """Return ``silu(state)``."""
        activated = nn.op.silu(state)
        return activated

class Layer(nn.Module):
    """Linear projection followed by a SiLU activation.

    Also exported as a private Relax subroutine (``define_subroutine``),
    which in turn calls the ``Activation`` subroutine.
    """

    # Lower this module as its own private function in the exported IRModule.
    define_subroutine = True

    def __init__(self, in_features, out_features):
        # Learnable (in_features, out_features) float32 weight matrix.
        self.weights = nn.Parameter((in_features, out_features), dtype="float32")
        self.activation = Activation()

    def forward(self, input: relax.Expr) -> relax.Var:
        """Return ``silu(input @ weights)``."""
        projected = nn.op.matmul(input, self.weights)
        return self.activation(projected)
# Build a 64 -> 32 layer and export it with a symbolic batch dimension.
mod = Layer(64, 32)
batch_size = tvm.tir.Var("batch_size", "int64")

# The spec maps each exported method to the shapes/dtypes of its inputs;
# debug=True additionally threads the `_io` effect object through `forward`.
forward_spec = {"forward": {"input": nn.spec.Tensor((batch_size, 64), "float32")}}
tvm_mod, _ = mod.export_tvm(spec=forward_spec, debug=True)

# Print the resulting IRModule as TVMScript.
tvm_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

# Expected TVMScript output of `tvm_mod.show()` above: because both Layer and
# Activation set `define_subroutine = True`, their forwards appear as the
# private functions `layer` and `activation` rather than being inlined.
@I.ir_module
class Module:
    # Created by debug=True: sets up the effect state (`_io`) tuple.
    @R.function
    def _initialize_effect() -> R.Tuple(R.Object):
        with R.dataflow():
            _io: R.Object = R.null_value()
            lv: R.Tuple(R.Object) = (_io,)
            gv: R.Tuple(R.Object) = lv
            R.output(gv)
        return gv

    # Private subroutine emitted for the Activation module's forward.
    @R.function(private=True)
    def activation(state: R.Tensor(("batch_size", 32), dtype="float32")) -> R.Tensor(("batch_size", 32), dtype="float32"):
        batch_size = T.int64()
        with R.dataflow():
            silu: R.Tensor((batch_size, 32), dtype="float32") = R.nn.silu(state)
            gv1: R.Tensor((batch_size, 32), dtype="float32") = silu
            R.output(gv1)
        return gv1

    # Public entry point; num_input=2 marks (input, _io) as runtime inputs,
    # with `weights` supplied as a bound parameter.
    @R.function
    def forward(input: R.Tensor(("batch_size", 64), dtype="float32"), _io: R.Object, weights: R.Tensor((64, 32), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", 32), dtype="float32"), R.Tuple(R.Object)):
        batch_size = T.int64()
        R.func_attr({"num_input": 2})
        cls = Module
        with R.dataflow():
            layer_output: R.Tensor((batch_size, 32), dtype="float32") = cls.layer(input, weights)
            gv3: R.Tuple(R.Tensor((batch_size, 32), dtype="float32"), R.Tuple(R.Object)) = layer_output, (_io,)
            R.output(gv3)
        return gv3

    # Private subroutine for Layer's forward: matmul then a call into
    # the `activation` subroutine.
    @R.function(private=True)
    def layer(input: R.Tensor(("batch_size", 64), dtype="float32"), weights: R.Tensor((64, 32), dtype="float32")) -> R.Tensor(("batch_size", 32), dtype="float32"):
        batch_size = T.int64()
        cls = Module
        with R.dataflow():
            matmul: R.Tensor((batch_size, 32), dtype="float32") = R.matmul(input, weights, out_dtype="void")
            activation_output: R.Tensor((batch_size, 32), dtype="float32") = cls.activation(matmul)
            gv2: R.Tensor((batch_size, 32), dtype="float32") = activation_output
            R.output(gv2)
        return gv2