def test_mutator_naming_basic():
    """Check the dotted names passed to ``visit_param`` for nested modules.

    Four modules are nested ``Module3 -> Module2 -> Module1 -> Module0`` and
    each owns exactly one parameter with a distinct dtype, so the dtype tells
    us which parameter ``visit_param`` was called with.
    """

    class Module0(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param0 = nn.Parameter((32, 128), "float64")

    class Module1(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod0 = Module0()
            self.param1 = nn.Parameter((32, 128), "float32")

    class Module2(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod1 = Module1()
            self.param2 = nn.Parameter((32, 128), "float16")

    class Module3(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod2 = Module2()
            self.param3 = nn.Parameter((32, 128), "float8")

    # dtype of each parameter -> name the mutator is expected to report for it.
    expected_names = {
        "float8": "mod3.param3",
        "float16": "mod3.mod2.param2",
        "float32": "mod3.mod2.mod1.param1",
        "float64": "mod3.mod2.mod1.mod0.param0",
    }
    visited_dtypes = set()

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            # Fail loudly on an unknown parameter instead of silently
            # returning None for it.
            assert node.dtype in expected_names, f"unexpected dtype {node.dtype!r}"
            assert name == expected_names[node.dtype]
            visited_dtypes.add(node.dtype)
            return node

    mod3 = Module3()
    mutator = Mutator()
    mutator.visit("mod3", mod3)

    # The name assertions only run inside visit_param; without this check the
    # test would pass even if the hook were never invoked.
    assert visited_dtypes == set(expected_names)
def test_mutator_naming_modulelist():
    """Check the dotted names passed to ``visit_param`` inside nested lists.

    A ``ModuleList`` of two ``ModuleList`` objects holds four modules, each
    with a single parameter of a distinct dtype. The expected names use the
    list indices as path components (``mod_list.<outer>.<inner>.param``).
    """

    class Module(nn.Module):
        def __init__(self, dtype) -> None:
            super().__init__()
            self.param = nn.Parameter((32, 128), dtype)

    # dtype of each parameter -> name the mutator is expected to report for it.
    expected_names = {
        "float64": "mod_list.0.0.param",
        "float32": "mod_list.0.1.param",
        "float16": "mod_list.1.0.param",
        "float8": "mod_list.1.1.param",
    }
    visited_dtypes = set()

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            # Fail loudly on an unknown parameter instead of silently
            # returning None for it.
            assert node.dtype in expected_names, f"unexpected dtype {node.dtype!r}"
            assert name == expected_names[node.dtype]
            visited_dtypes.add(node.dtype)
            return node

    mod_list = nn.ModuleList(
        [
            nn.ModuleList([Module("float64"), Module("float32")]),
            nn.ModuleList([Module("float16"), Module("float8")]),
        ]
    )
    mutator = Mutator()
    mutator.visit("mod_list", mod_list)

    # The name assertions only run inside visit_param; without this check the
    # test would pass even if the hook were never invoked.
    assert visited_dtypes == set(expected_names)
def test_mutator_module():
    """A ``visit_module`` override can swap a sub-module for another one.

    ``Module.mod`` starts out as a ``SubModule1``; after the mutator has been
    run on the parent it is expected to be a ``SubModule2``.
    """

    class SubModule1(nn.Module):
        """Marker module type that the mutator replaces."""

    class SubModule2(nn.Module):
        """Marker module type that the mutator substitutes in."""

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SubModule1()

    class Mutator(nn.Mutator):
        def visit_module(self, name: str, node: nn.Module) -> Any:
            # Only SubModule1 instances are replaced; anything else passes through.
            return SubModule2() if isinstance(node, SubModule1) else node

    parent = Module()
    assert isinstance(parent.mod, SubModule1)

    mutated = Mutator().visit("", parent)
    assert isinstance(mutated.mod, SubModule2)
def test_mutator_modulelist():
    """A ``visit_module`` override can replace an element of a ``ModuleList``.

    The list starts as ``[Module1, Module2, Module3]``; the mutator turns every
    ``Module3`` into a fresh ``Module1`` and leaves the other entries alone.
    """

    class Module1(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module2(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Module3(nn.Module):
        def __init__(self) -> None:
            super().__init__()

    class Mutator(nn.Mutator):
        def visit_module(self, name: str, node: nn.Module) -> Any:
            if isinstance(node, Module3):
                return Module1()
            return node

    mutator = Mutator()
    module_list = nn.ModuleList([Module1(), Module2(), Module3()])
    assert isinstance(module_list[0], Module1)
    assert isinstance(module_list[1], Module2)
    assert isinstance(module_list[2], Module3)

    module_list = mutator.visit("", module_list)
    # Only the last entry should have changed type. (A leftover debug print of
    # module_list[2] was removed here; the assertions carry the check.)
    assert isinstance(module_list[0], Module1)
    assert isinstance(module_list[1], Module2)
    assert isinstance(module_list[2], Module1)
def test_mutator_effect():
    """A ``visit_effect`` override can replace an effect held by a module.

    ``Module.effect`` starts out as an ``Effect1``; after the mutator has been
    run on the module it is expected to be an ``Effect2``.
    """

    class Effect1(nn.Effect):
        pass

    class Effect2(nn.Effect):
        pass

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.effect = Effect1()

    class Mutator(nn.Mutator):
        def visit_effect(self, name: str, node: nn.Effect) -> Any:
            if isinstance(node, Effect1):
                return Effect2()
            # Hand unmatched effects back unchanged rather than implicitly
            # returning None, matching the visit_module hooks in this file.
            return node

    mutator = Mutator()
    module = Module()
    assert isinstance(module.effect, Effect1)

    module = mutator.visit("", module)
    assert isinstance(module.effect, Effect2)
def test_mutator_param():
    """A ``visit_param`` override can replace a parameter of a module.

    ``Module.weight`` is created as float16; the mutator returns a new float32
    parameter of the same shape, and the module is expected to hold it
    afterwards.
    """

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter((128, 64), "float16")

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            if node.dtype == "float16":
                return nn.Parameter(node.shape, "float32")
            # Hand other parameters back unchanged rather than implicitly
            # returning None for them.
            return node

    mutator = Mutator()
    module = Module()
    assert module.weight.dtype == "float16"

    module = mutator.visit("", module)
    assert module.weight.dtype == "float32"
def test_mutator_recursively():
    """A ``visit_param`` override also reaches parameters of sub-modules.

    The float16 parameter lives on ``Module.mod`` (a ``SubModule``), not on the
    module handed to ``visit``; it is still expected to be replaced by a
    float32 parameter of the same shape.
    """

    class SubModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter((128, 64), "float16")

    class Module(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SubModule()

    class Mutator(nn.Mutator):
        def visit_param(self, name: str, node: nn.Parameter) -> Any:
            if node.dtype == "float16":
                return nn.Parameter(node.shape, "float32")
            # Hand other parameters back unchanged rather than implicitly
            # returning None for them.
            return node

    mutator = Mutator()
    module = Module()
    assert module.mod.weight.dtype == "float16"

    module = mutator.visit("", module)
    assert module.mod.weight.dtype == "float32"