# Mutator

from typing import Any

import tvm
from tvm.relax.frontend import nn
def test_mutator_naming_basic():
    """Check the names ``nn.Mutator`` passes to ``visit_param`` for nested modules.

    ``Module3`` contains ``Module2``, which contains ``Module1``, which contains
    ``Module0``; each level owns one parameter with a distinct dtype so the
    visitor can tell which parameter it received. Besides checking each name,
    the test records every visited name and asserts that all four parameters
    were visited exactly once, so it cannot pass vacuously if the mutator
    skips a parameter.
    """

    class Module0(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param0 = nn.Parameter((32, 128), "float64")

    class Module1(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod0 = Module0()
            self.param1 = nn.Parameter((32, 128), "float32")

    class Module2(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod1 = Module1()
            self.param2 = nn.Parameter((32, 128), "float16")

    class Module3(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod2 = Module2()
            self.param3 = nn.Parameter((32, 128), "float8")

    # dtype -> dotted name the mutator is expected to report for that parameter.
    expected_names = {
        "float8": "mod3.param3",
        "float16": "mod3.mod2.param2",
        "float32": "mod3.mod2.mod1.param1",
        "float64": "mod3.mod2.mod1.mod0.param0",
    }
    visited = []

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            # Fail loudly on an unknown dtype rather than falling through and
            # returning None, which would be used as the replacement parameter.
            assert node.dtype in expected_names, f"unexpected parameter {name!r} ({node.dtype})"
            assert name == expected_names[node.dtype]
            visited.append(name)
            return node

    mod3 = Module3()
    mutator = Mutator()
    mutator.visit("mod3", mod3)
    # Every parameter must have been visited, and each exactly once.
    assert sorted(visited) == sorted(expected_names.values())


def test_mutator_naming_modulelist():
    """Check the names ``nn.Mutator`` passes to ``visit_param`` inside nested ``ModuleList``s.

    A 2x2 nest of ``nn.ModuleList`` holds four modules, each with one parameter
    of a distinct dtype. List elements are expected to be named by their index
    (e.g. ``mod_list.1.0.param``). The test also records every visited name and
    asserts all four parameters were visited exactly once, so it cannot pass
    vacuously if the mutator skips a list element.
    """

    class Module(nn.Module):
        def __init__(self, dtype) -> None:
            super().__init__()
            self.param = nn.Parameter((32, 128), dtype)

    # dtype -> dotted name the mutator is expected to report for that parameter.
    expected_names = {
        "float64": "mod_list.0.0.param",
        "float32": "mod_list.0.1.param",
        "float16": "mod_list.1.0.param",
        "float8": "mod_list.1.1.param",
    }
    visited = []

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            # Fail loudly on an unknown dtype rather than falling through and
            # returning None, which would be used as the replacement parameter.
            assert node.dtype in expected_names, f"unexpected parameter {name!r} ({node.dtype})"
            assert name == expected_names[node.dtype]
            visited.append(name)
            return node

    mod_list = nn.ModuleList(
        [
            nn.ModuleList([Module("float64"), Module("float32")]),
            nn.ModuleList([Module("float16"), Module("float8")]),
        ]
    )
    mutator = Mutator()
    mutator.visit("mod_list", mod_list)
    # Every parameter must have been visited, and each exactly once.
    assert sorted(visited) == sorted(expected_names.values())


def test_mutator_module():
    """Replacing a submodule from ``visit_module`` updates the parent's attribute.

    The mutator swaps every ``SubModule1`` for a fresh ``SubModule2`` and hands
    any other module back unchanged; after visiting, the parent's ``mod``
    attribute must hold the replacement.
    """

    class SubModule1(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class SubModule2(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SubModule1()

    class Mutator(nn.Mutator):
        def visit_module(self, name: str, node: nn.Module) -> Any:
            return SubModule2() if isinstance(node, SubModule1) else node

    replacer = Mutator()
    root = Module()
    assert isinstance(root.mod, SubModule1)
    root = replacer.visit("", root)
    assert isinstance(root.mod, SubModule2)


def test_mutator_modulelist():
    """Replacing an element from ``visit_module`` updates the ``nn.ModuleList`` in place.

    The mutator turns every ``Module3`` into a fresh ``Module1`` and returns any
    other module unchanged. Only the last list slot should change type; the
    first two must be left as they were.
    """

    class Module1(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module2(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module3(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Mutator(nn.Mutator):
        def visit_module(self, name: str, node: nn.Module) -> Any:
            if isinstance(node, Module3):
                return Module1()
            else:
                return node

    mutator = Mutator()
    module_list = nn.ModuleList([Module1(), Module2(), Module3()])
    assert isinstance(module_list[0], Module1)
    assert isinstance(module_list[1], Module2)
    assert isinstance(module_list[2], Module3)
    module_list = mutator.visit("", module_list)
    assert isinstance(module_list[0], Module1)
    assert isinstance(module_list[1], Module2)
    assert isinstance(module_list[2], Module1)


def test_mutator_effect():
    """Replacing an effect from ``visit_effect`` updates the owning module's attribute.

    The mutator swaps every ``Effect1`` for a fresh ``Effect2``; after visiting,
    the module's ``effect`` attribute must hold the replacement.
    """

    class Effect1(nn.Effect):
        pass

    class Effect2(nn.Effect):
        pass

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.effect = Effect1()

    class Mutator(nn.Mutator):
        def visit_effect(self, name: str, node: nn.Effect) -> Any:
            if isinstance(node, Effect1):
                return Effect2()
            # The return value is used as the replacement, so hand any other
            # effect back explicitly instead of implicitly returning None.
            return node

    mutator = Mutator()
    module = Module()
    assert isinstance(module.effect, Effect1)
    module = mutator.visit("", module)
    assert isinstance(module.effect, Effect2)


def test_mutator_param():
    """Replacing a parameter from ``visit_param`` updates the owning module's attribute.

    The mutator replaces every ``float16`` parameter with a ``float32``
    parameter of the same shape; after visiting, ``module.weight`` must report
    the new dtype.
    """

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter((128, 64), "float16")

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            if node.dtype == "float16":
                return nn.Parameter(node.shape, "float32")
            # The return value is used as the replacement, so hand any other
            # parameter back explicitly instead of implicitly returning None.
            return node

    mutator = Mutator()
    module = Module()
    assert module.weight.dtype == "float16"
    module = mutator.visit("", module)
    assert module.weight.dtype == "float32"


def test_mutator_recursively():
    """``nn.Mutator`` descends into submodules and replaces their parameters.

    The ``float16`` weight lives one level down, in ``module.mod``. The mutator
    replaces every ``float16`` parameter with a ``float32`` one of the same
    shape, and the nested weight must be updated after visiting the root.
    """

    class SubModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter((128, 64), "float16")

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SubModule()

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            if node.dtype == "float16":
                return nn.Parameter(node.shape, "float32")
            # The return value is used as the replacement, so hand any other
            # parameter back explicitly instead of implicitly returning None.
            return node

    mutator = Mutator()
    module = Module()
    assert module.mod.weight.dtype == "float16"
    module = mutator.visit("", module)
    assert module.mod.weight.dtype == "float32"