# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
# Auto-generated TVMScript (TVM Relax IR). Do not hand-edit casually: binding
# names, dataflow-block boundaries, and statement order are part of the IR.
@I.ir_module
class Module:
    @R.function
    def _initialize_effect() -> R.Tuple(R.Object):
        # Build the module's initial effect state: a 1-tuple holding the
        # opaque I/O effect object (created via R.null_value()).
        with R.dataflow():
            _io: R.Object = R.null_value()
            lv: R.Tuple(R.Object) = (_io,)
            gv: R.Tuple(R.Object) = lv
            R.output(gv)
        return gv

    @R.function(private=True)
    def activation(state: R.Tensor(("batch_size", 32), dtype="float32")) -> R.Tensor(("batch_size", 32), dtype="float32"):
        # Elementwise SiLU over a (batch_size, 32) float32 tensor; shape is
        # symbolic in the batch dimension.
        batch_size = T.int64()
        with R.dataflow():
            silu: R.Tensor((batch_size, 32), dtype="float32") = R.nn.silu(state)
            gv1: R.Tensor((batch_size, 32), dtype="float32") = silu
            R.output(gv1)
        return gv1

    @R.function
    def forward(input: R.Tensor(("batch_size", 64), dtype="float32"), _io: R.Object, weights: R.Tensor((64, 32), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", 32), dtype="float32"), R.Tuple(R.Object)):
        # Public entry point: runs `layer` on `input` and returns
        # (layer_output, (_io,)) — the result paired with the threaded
        # effect-state tuple.
        batch_size = T.int64()
        # NOTE(review): "num_input": 2 presumably marks the first two args
        # (input, _io) as runtime inputs, with `weights` a bound parameter —
        # TVM nn.Module export convention; confirm against the exporter.
        R.func_attr({"num_input": 2})
        cls = Module
        with R.dataflow():
            layer_output: R.Tensor((batch_size, 32), dtype="float32") = cls.layer(input, weights)
            gv3: R.Tuple(R.Tensor((batch_size, 32), dtype="float32"), R.Tuple(R.Object)) = layer_output, (_io,)
            R.output(gv3)
        return gv3

    @R.function(private=True)
    def layer(input: R.Tensor(("batch_size", 64), dtype="float32"), weights: R.Tensor((64, 32), dtype="float32")) -> R.Tensor(("batch_size", 32), dtype="float32"):
        # One dense layer: (batch_size, 64) @ (64, 32) matmul followed by the
        # module's `activation` (SiLU).
        batch_size = T.int64()
        cls = Module
        with R.dataflow():
            # out_dtype="void" leaves the matmul output dtype to be inferred
            # from the inputs (both float32 here).
            matmul: R.Tensor((batch_size, 32), dtype="float32") = R.matmul(input, weights, out_dtype="void")
            activation_output: R.Tensor((batch_size, 32), dtype="float32") = cls.activation(matmul)
            gv2: R.Tensor((batch_size, 32), dtype="float32") = activation_output
            R.output(gv2)
        return gv2