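# jit
#
# Tests for `nn.Module.jit`: each section below defines a small `Layer`,
# compiles it with `mod.jit(spec=...)` (with debug off and on), and checks the
# returned torch tensors against a reference torch computation.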

import torch

import tvm
import tvm.testing
from tvm import tir
from tvm.relax.frontend.nn import spec
from tvm.relax.frontend import nn
class Layer(nn.Module):
    """Toy module whose forward pass adds its input to itself."""

    def __init__(self):
        # Stateless: there are no parameters or submodules to register.
        pass

    def forward(self, x: nn.Tensor):
        """Return the element-wise sum ``x + x``."""
        return nn.add(x, x)

# A single fixed-shape (10, 5) float32 input named "x".
forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32")}}
mod = Layer()

# Exercise both the regular and the debug compilation paths.
for debug in (False, True):
    model = mod.jit(spec=forward_spec, debug=debug)

    x = torch.rand((10, 5), dtype=torch.float32)
    y = model["forward"](x)

    assert isinstance(y, torch.Tensor)
    expected = x + x
    assert torch.allclose(expected, y)
class Layer(nn.Module):
    """Toy module that doubles ``x`` and reshapes it with a symbolic dimension."""

    def __init__(self):
        # Nothing to register; the layer has no weights.
        pass

    def forward(self, x: nn.Tensor, i: tir.Var):
        """Return ``x + x`` reshaped to ``(i, 5, 5)``."""
        doubled = nn.add(x, x)
        return nn.reshape(doubled, (i, 5, 5))

# "i" is declared as a plain Python int in the spec; forward uses it as the
# leading dimension of its reshape.
forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32"), "i": int}}
mod = Layer()

for debug in (False, True):
    model = mod.jit(spec=forward_spec, debug=debug)

    x = torch.rand((10, 5), dtype=torch.float32)
    y = model["forward"](x, 2)

    assert isinstance(y, torch.Tensor)
    expected = torch.reshape(x + x, (2, 5, 5))
    assert torch.allclose(expected, y)
class Layer(nn.Module):
    """Module that appends each input to a KV cache and returns a cache view."""

    def __init__(self):
        # Arguments presumably mean (initial sequence capacity, per-entry shape);
        # NOTE(review): confirm against the nn.KVCache signature.
        self.cache = nn.KVCache(10, [10, 5])

    def forward(self, x: nn.Tensor, total_seq_len: tir.Var):
        """Append ``x`` to the cache, then return ``cache.view(total_seq_len)``."""
        self.cache.append(x)
        return self.cache.view(total_seq_len)

# Each call appends one (1, 10, 5) slice to the cache; after `step` calls the
# returned view must equal every slice appended so far, concatenated on dim 0.
forward_spec = {
    "forward": {"x": spec.Tensor([1, 10, 5], dtype="float32"), "total_seq_len": int}
}
mod = Layer()

for debug in (False, True):
    # The pass context only wraps compilation, not the calls below.
    with tvm.transform.PassContext(opt_level=3):
        model = mod.jit(spec=forward_spec, debug=debug)

    appended = []
    for step in range(1, 4):
        chunk = torch.rand((1, 10, 5), dtype=torch.float32)
        appended.append(chunk)
        y = model["forward"](chunk, step)
        if step == 1:
            assert isinstance(y, torch.Tensor)
        assert torch.allclose(torch.concat(appended, dim=0), y)
class Layer(nn.Module):
    """Module taking a pair of tensors and returning their sum and difference."""

    def __init__(self):
        # No weights to register.
        pass

    def forward(self, x: tuple[nn.Tensor, nn.Tensor]):
        """Return ``(x[0] + x[1], x[0] - x[1])``.

        The assertion checks that ``x`` arrives as a Python tuple.
        """
        assert isinstance(x, tuple)
        lhs, rhs = x[0], x[1]
        return (nn.add(lhs, rhs), nn.subtract(lhs, rhs))

# Two (10, 5) float32 tensors passed together as a single tuple argument.
forward_spec = {
    "forward": {
        "x": tuple(spec.Tensor([10, 5], dtype="float32") for _ in range(2)),
    }
}
mod = Layer()

for debug in (False, True):
    model = mod.jit(spec=forward_spec, debug=debug)

    x0 = torch.rand((10, 5), dtype=torch.float32)
    x1 = torch.rand((10, 5), dtype=torch.float32)
    x = (x0, x1)
    y = model["forward"](x)

    total, diff = y[0], y[1]
    assert torch.allclose(x0 + x1, total)
    assert torch.allclose(x0 - x1, diff)
class Layer(nn.Module):
    """Module taking a list of two tensors; returns their sum and difference."""

    def __init__(self):
        # No weights to register.
        pass

    def forward(self, x: list[nn.Tensor]):
        """Return ``(x[0] + x[1], x[0] - x[1])`` as a tuple.

        The assertion checks that ``x`` arrives as a Python ``list``.
        """
        assert isinstance(x, list)
        first, second = x[0], x[1]
        total = nn.add(first, second)
        difference = nn.subtract(first, second)
        return total, difference

# Two (10, 5) float32 tensors passed together as a single *list* argument.
forward_spec = {
    "forward": {
        "x": [
            spec.Tensor([10, 5], dtype="float32"),
            spec.Tensor([10, 5], dtype="float32"),
        ]
    }
}
mod = Layer()

for debug in [False, True]:
    model = mod.jit(spec=forward_spec, debug=debug)

    x0 = torch.rand((10, 5), dtype=torch.float32)
    x1 = torch.rand((10, 5), dtype=torch.float32)
    # Pass a list (not a tuple) so the runtime call matches the list-typed
    # spec above; otherwise this section duplicates the tuple-input case.
    x = [x0, x1]
    y = model["forward"](x)

    assert torch.allclose(x0 + x1, y[0])
    assert torch.allclose(x0 - x1, y[1])
from typing import Tuple
class Layer(nn.Module):
    """Module taking a ``(tensor, tensor, int)`` tuple.

    Returns the sum and difference of the two tensors, plus the first tensor
    reshaped to ``(5, x[2], 5)``.
    """

    def __init__(self):
        # No weights to register.
        pass

    # Builtin ``tuple[...]`` generic, consistent with the other tuple/list
    # layers here; ``typing.Tuple`` is a deprecated alias (PEP 585).
    def forward(self, x: tuple[nn.Tensor, nn.Tensor, int]):
        x0 = x[0]
        x1 = x[1]
        y0 = nn.add(x0, x1)
        y1 = nn.subtract(x0, x1)
        # x[2] corresponds to the `int` entry of the spec and is used as the
        # middle reshape dimension.
        y2 = nn.reshape(x0, (5, x[2], 5))
        return (y0, y1, y2)

# A single tuple argument mixing two (10, 5) float32 tensors and a Python int;
# the int is used as the middle dimension of forward's reshape.
forward_spec = {
    "forward": {
        "x": (
            spec.Tensor([10, 5], dtype="float32"),
            spec.Tensor([10, 5], dtype="float32"),
            int,
        )
    }
}
mod = Layer()

for debug in (False, True):
    model = mod.jit(spec=forward_spec, debug=debug)

    x0 = torch.rand((10, 5), dtype=torch.float32)
    x1 = torch.rand((10, 5), dtype=torch.float32)
    x = (x0, x1, 2)
    outputs = model["forward"](x)
    y0, y1, y2 = outputs

    expectations = (x0 + x1, x0 - x1, torch.reshape(x0, (5, 2, 5)))
    for want, got in zip(expectations, (y0, y1, y2)):
        assert torch.allclose(want, got)