# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
I.module_attrs({"input_num": 1, "optim_state": [metadata["runtime.NDArray"][0], metadata["runtime.NDArray"][1]], "param_num": 1, "state_num": 1})
@R.function
def backbone(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((2, 2), dtype="float64"), R.Tensor((2, 2), dtype="float64")):
with R.dataflow():
x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1.0, "float64"))
R.output(x1, z1)
return (x1, z1)
@R.function
def backbone_loss(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((), dtype="float64"), R.Tensor((2, 2), dtype="float64")):
with R.dataflow():
x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1.0, "float64"))
lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
R.output(z1, gv)
return (gv, z1)
@R.function
def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tuple(R.Tensor((), dtype="float64"), R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((2, 2), dtype="float64"))):
with R.dataflow():
x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1.0, "float64"))
lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64")
lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2]))
lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1)
x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1
y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint
y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint
R.output(z1, gv, y_adjoint_out)
return ((gv, z1), (y_adjoint_out,))
@R.function
def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))):
with R.dataflow():
num_steps: R.Tensor((), dtype="int64") = optim_states[0]
num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
y: R.Tensor((2, 2), dtype="float64") = params[0]
y_grad: R.Tensor((2, 2), dtype="float64") = gradients[0]
y_v: R.Tensor((2, 2), dtype="float64") = optim_states[1]
lv: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_v)
y_v_new: R.Tensor((2, 2), dtype="float64") = R.add(lv, y_grad)
lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_v_new)
y_new: R.Tensor((2, 2), dtype="float64") = R.subtract(y, lv1)
params_new: R.Tuple(R.Tensor((2, 2), dtype="float64")) = (y_new,)
optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64")) = num_steps_new, y_v_new
R.output(params_new, optim_states_new)
return (params_new, optim_states_new)
# Metadata omitted. Use show_meta=True in script() method to show it.