# subroutines

In [1]:
import tvm
import tvm.testing
from tvm import relax
from tvm.ir import assert_structural_equal
from tvm.relax.frontend import nn
from tvm.script import ir as I
from tvm.script import relax as R

In [2]:
class Activation(nn.Module):
    """Module whose forward pass applies ``nn.op.silu`` to its input."""

    # Class-level flag read by the nn.Module machinery.
    # NOTE(review): presumably asks the exporter to emit this module's
    # ``forward`` as a separate Relax function -- confirm against TVM docs.
    define_subroutine = True

    def forward(self, state: relax.Expr) -> relax.Var:
        """Return the result of ``nn.op.silu`` applied to ``state``."""
        activated = nn.op.silu(state)
        return activated

class Layer(nn.Module):
    """Matmul against a weight parameter, followed by an ``Activation``.

    Parameters
    ----------
    in_features
        First dimension of the weight parameter's shape.
    out_features
        Second dimension of the weight parameter's shape.
    """

    # Class-level flag read by the nn.Module machinery.
    # NOTE(review): presumably asks the exporter to emit this module's
    # ``forward`` as a separate Relax function -- confirm against TVM docs.
    define_subroutine = True

    def __init__(self, in_features, out_features):
        # Run the base-class initialiser chain before attaching parameters
        # and submodules, so any setup a parent class performs is not skipped.
        super().__init__()
        self.weights = nn.Parameter((in_features, out_features), dtype="float32")
        self.activation = Activation()

    # The parameter is deliberately still called ``input`` (shadowing the
    # builtin): the export spec refers to it by that exact name.
    def forward(self, input: relax.Expr) -> relax.Var:
        """Multiply ``input`` by ``self.weights`` and apply the activation."""
        state = nn.op.matmul(input, self.weights)
        return self.activation(state)

In [3]:
# Instantiate the layer with a (64, 32) weight parameter.
mod = Layer(64, 32)

# Symbolic int64 variable used as the leading dimension of the input tensor.
batch_size = tvm.tir.Var("batch_size", "int64")

# Spec for the arguments of ``forward``: one float32 tensor named ``input``
# of shape (batch_size, 64).
input_spec = {"input": nn.spec.Tensor((batch_size, 64), "float32")}

# export_tvm returns a pair; only the first element is kept here.
tvm_mod, _ = mod.export_tvm(spec={"forward": input_spec}, debug=True)

# Print the exported module.
tvm_mod.show()