jit
import torch

import tvm
from tvm import tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import spec
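The examples below exercise nn.Module.jit, which compiles a module according to a method spec and returns a model whose methods accept and return torch.Tensor values. The first example jits a single forward method over a fixed-shape input, with and without debug mode: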
class Layer(nn.Module):
    def __init__(self):
        pass

    def forward(self, x: nn.Tensor):
        y = nn.add(x, x)
        return y


forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32")}}
mod = Layer()
for debug in [False, True]:
    model = mod.jit(spec=forward_spec, debug=debug)
    x = torch.rand((10, 5), dtype=torch.float32)
    y = model["forward"](x)
    assert isinstance(y, torch.Tensor)
    assert torch.allclose(x + x, y)
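A spec entry of int declares a symbolic dimension: the argument reaches forward as a tir.Var and can appear in shape expressions such as reshape: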
class Layer(nn.Module):
    def __init__(self):
        pass

    def forward(self, x: nn.Tensor, i: tir.Var):
        y = nn.add(x, x)
        y = nn.reshape(y, (i, 5, 5))
        return y


forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32"), "i": int}}
mod = Layer()
for debug in [False, True]:
    model = mod.jit(spec=forward_spec, debug=debug)
    x = torch.rand((10, 5), dtype=torch.float32)
    y = model["forward"](x, 2)
    assert isinstance(y, torch.Tensor)
    assert torch.allclose(torch.reshape(x + x, (2, 5, 5)), y)
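Mutable state also works through jit. Here an nn.KVCache accumulates each appended tensor across calls, and cache.view(total_seq_len) returns the first total_seq_len cached entries: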
class Layer(nn.Module):
    def __init__(self):
        self.cache = nn.KVCache(10, [10, 5])

    def forward(self, x: nn.Tensor, total_seq_len: tir.Var):
        self.cache.append(x)
        y = self.cache.view(total_seq_len)
        return y


forward_spec = {
    "forward": {"x": spec.Tensor([1, 10, 5], dtype="float32"), "total_seq_len": int}
}
mod = Layer()
for debug in [False, True]:
    with tvm.transform.PassContext(opt_level=3):
        model = mod.jit(spec=forward_spec, debug=debug)
    x0 = torch.rand((1, 10, 5), dtype=torch.float32)
    y = model["forward"](x0, 1)
    assert isinstance(y, torch.Tensor)
    assert torch.allclose(x0, y)
    x1 = torch.rand((1, 10, 5), dtype=torch.float32)
    y = model["forward"](x1, 2)
    assert torch.allclose(torch.concat([x0, x1], dim=0), y)
    x2 = torch.rand((1, 10, 5), dtype=torch.float32)
    y = model["forward"](x2, 3)
    assert torch.allclose(torch.concat([x0, x1, x2], dim=0), y)
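Arguments may be structured as well. A tuple in the spec arrives in forward as a tuple of tensors, and a returned tuple comes back as a tuple of torch.Tensors: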
class Layer(nn.Module):
    def __init__(self):
        pass

    def forward(self, x: tuple[nn.Tensor, nn.Tensor]):
        assert isinstance(x, tuple)
        x0 = x[0]
        x1 = x[1]
        y0 = nn.add(x0, x1)
        y1 = nn.subtract(x0, x1)
        return (y0, y1)


forward_spec = {
    "forward": {
        "x": (
            spec.Tensor([10, 5], dtype="float32"),
            spec.Tensor([10, 5], dtype="float32"),
        )
    }
}
mod = Layer()
for debug in [False, True]:
    model = mod.jit(spec=forward_spec, debug=debug)
    x0 = torch.rand((10, 5), dtype=torch.float32)
    x1 = torch.rand((10, 5), dtype=torch.float32)
    x = (x0, x1)
    y = model["forward"](x)
    assert torch.allclose(x0 + x1, y[0])
    assert torch.allclose(x0 - x1, y[1])
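Lists behave the same way; only the spec and the annotation change: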
class Layer(nn.Module):
    def __init__(self):
        pass

    def forward(self, x: list[nn.Tensor]):
        assert isinstance(x, list)
        x0 = x[0]
        x1 = x[1]
        y0 = nn.add(x0, x1)
        y1 = nn.subtract(x0, x1)
        return (y0, y1)


forward_spec = {
    "forward": {
        "x": [
            spec.Tensor([10, 5], dtype="float32"),
            spec.Tensor([10, 5], dtype="float32"),
        ]
    }
}
mod = Layer()
for debug in [False, True]:
    model = mod.jit(spec=forward_spec, debug=debug)
    x0 = torch.rand((10, 5), dtype=torch.float32)
    x1 = torch.rand((10, 5), dtype=torch.float32)
    x = [x0, x1]  # pass a list, matching the list spec
    y = model["forward"](x)
    assert torch.allclose(x0 + x1, y[0])
    assert torch.allclose(x0 - x1, y[1])
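Finally, a tuple spec can mix tensors with symbolic integers: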
from typing import Tuple
class Layer(nn.Module):
    def __init__(self):
        pass

    def forward(self, x: Tuple[nn.Tensor, nn.Tensor, int]):
        x0 = x[0]
        x1 = x[1]
        y0 = nn.add(x0, x1)
        y1 = nn.subtract(x0, x1)
        y2 = nn.reshape(x0, (5, x[2], 5))
        return (y0, y1, y2)


forward_spec = {
    "forward": {
        "x": (
            spec.Tensor([10, 5], dtype="float32"),
            spec.Tensor([10, 5], dtype="float32"),
            int,
        )
    }
}
mod = Layer()
for debug in [False, True]:
    model = mod.jit(spec=forward_spec, debug=debug)
    x0 = torch.rand((10, 5), dtype=torch.float32)
    x1 = torch.rand((10, 5), dtype=torch.float32)
    x = (x0, x1, 2)
    y0, y1, y2 = model["forward"](x)
    assert torch.allclose(x0 + x1, y0)
    assert torch.allclose(x0 - x1, y1)
    assert torch.allclose(torch.reshape(x0, (5, 2, 5)), y2)
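Each spec above describes a single method, but a spec is simply a dict keyed by method name, so one module can expose several jitted entry points. Below is a minimal sketch of that pattern; the TwoMethods module, its spec, and the use of nn.multiply are illustrative assumptions, not taken from the examples above:

class TwoMethods(nn.Module):
    def double(self, x: nn.Tensor):
        return nn.add(x, x)

    def square(self, x: nn.Tensor):
        # nn.multiply is elementwise, like nn.add/nn.subtract above
        return nn.multiply(x, x)


# One spec entry per method; each compiled method is looked up by name,
# just as model["forward"] is above (names here are hypothetical).
two_methods_spec = {
    "double": {"x": spec.Tensor([10, 5], dtype="float32")},
    "square": {"x": spec.Tensor([10, 5], dtype="float32")},
}
model = TwoMethods().jit(spec=two_methods_spec)
x = torch.rand((10, 5), dtype=torch.float32)
assert torch.allclose(x + x, model["double"](x))
assert torch.allclose(x * x, model["square"](x))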