# MSC 基础 pass

In [1]:
import set_env

In [2]:
import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.relax import PyExprVisitor

from tvm.relay import testing
from tvm.relay.expr_functor import ExprVisitor
from tvm.relay.build_module import bind_params_by_name

from tvm.contrib.msc.core import transform as msc_transform
from tvm.contrib.msc.core import utils as msc_utils

## 为 `relax` 测试 `SetExprLayout`

In [3]:
# pylint: disable=import-outside-toplevel
# Optional-dependency guard: the relax cells below need torch/torchvision.
# Only ImportError is caught so that other failures raised while importing
# these packages are reported as-is instead of being masked by the
# "please install" hint.
try:
    import torch
    import torchvision
    from torch import fx
except ImportError:
    print("please install pytorch python package")

In [4]:
@tvm.relax.expr_functor.visitor
class RelaxLayoutChecker(PyExprVisitor):
    """Check that every relax binding value and constant carries a layout.

    The layout is read with ``msc_utils.get_expr_layout``; any expression for
    which it returns a falsy value is recorded as missing.
    """

    def check(self, expr):
        """Visit ``expr`` and assert that no layout is missing.

        Parameters
        ----------
        expr : tvm.IRModule, tvm.relax.Expr or tvm.relax.BindingBlock
            The IR to check. For an ``IRModule`` every ``relax.Function`` it
            contains is visited; other function kinds are skipped.

        Raises
        ------
        TypeError
            If ``expr`` is none of the supported types.
        AssertionError
            If any binding value or constant has no layout.
        """
        self._missing_exprs = []
        if isinstance(expr, tvm.IRModule):
            # An IRModule is neither a relax.Expr nor a BindingBlock; without
            # this branch nothing is visited and the check passes vacuously.
            for _, func in expr.functions.items():
                if isinstance(func, tvm.relax.Function):
                    self.visit_expr(func)
        elif isinstance(expr, tvm.relax.Expr):
            self.visit_expr(expr)
        elif isinstance(expr, tvm.relax.BindingBlock):
            self.visit_binding_block(expr)
        else:
            raise TypeError(f"Unsupported type for layout check: {type(expr)}")
        assert len(self._missing_exprs) == 0, f"Missing {len(self._missing_exprs)} layouts"

    def visit_var_binding_(self, binding) -> None:
        """Visit the binding normally, then record its value if it has no layout."""
        super().visit_var_binding_(binding)
        if not msc_utils.get_expr_layout(binding.value):
            self._missing_exprs.append(binding.value)

    def visit_constant_(self, op) -> None:
        """Visit the constant normally, then record it if it has no layout."""
        super().visit_constant_(op)
        if not msc_utils.get_expr_layout(op):
            self._missing_exprs.append(op)

In [5]:
# Build torchvision's ResNet-50 (default constructor arguments), trace it
# with torch.fx and import the traced graph into a relax module.
torch_model = torchvision.models.resnet50()
graph_model = fx.symbolic_trace(torch_model)
# A single input of shape [1, 3, 224, 224] and dtype float32.
input_info = [([1, 3, 224, 224], "float32")]
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
# Run the MSC layout pass, then check that layouts were attached.
mod = msc_transform.SetExprLayout()(mod)
RelaxLayoutChecker().check(mod)

## 为 `relay` 测试 `SetExprName`

In [6]:
class RelayNameChecker(ExprVisitor):
    """Verify that every relay Call and Constant has a name in its span.

    Expressions for which ``msc_utils.get_expr_name`` is falsy are collected
    during the visit and reported by ``check``.
    """

    def _record_if_unnamed(self, expr):
        """Remember ``expr`` when no name can be read from it."""
        if not msc_utils.get_expr_name(expr):
            self._missing_exprs.append(expr)

    def check(self, expr):
        """Visit ``expr`` and assert that no Call/Constant is unnamed."""
        self._missing_exprs = []
        super().visit(expr)
        missing = len(self._missing_exprs)
        assert missing == 0, f"Missing {missing} names"

    def visit_constant(self, expr):
        super().visit_constant(expr)
        self._record_if_unnamed(expr)

    def visit_call(self, expr):
        super().visit_call(expr)
        self._record_if_unnamed(expr)

# Load relay's ResNet-50 test workload (batch size 1, float32) and bind the
# returned params into "main" so they appear in the function body.
mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32")
mod["main"] = bind_params_by_name(mod["main"], params)
# as_relax=False: presumably selects the relay variant of the pass — confirm.
mod = msc_transform.SetExprName(as_relax=False)(mod)
# Check that every Call/Constant in "main" was given a name.
RelayNameChecker().check(mod["main"])

## 为 `relax` 测试 `SetExprName`

In [7]:
@tvm.relax.expr_functor.visitor
class RelaxNameChecker(PyExprVisitor):
    """Check that every relax binding value and constant has a name in its span.

    The name is read with ``msc_utils.get_expr_name``; any expression for
    which it returns a falsy value is recorded as missing.
    """

    def check(self, expr):
        """Visit ``expr`` and assert that no name is missing.

        Parameters
        ----------
        expr : tvm.IRModule, tvm.relax.Expr or tvm.relax.BindingBlock
            The IR to check. For an ``IRModule`` every ``relax.Function`` it
            contains is visited; other function kinds are skipped.

        Raises
        ------
        TypeError
            If ``expr`` is none of the supported types.
        AssertionError
            If any binding value or constant has no name.
        """
        self._missing_exprs = []
        if isinstance(expr, tvm.IRModule):
            # An IRModule is neither a relax.Expr nor a BindingBlock; without
            # this branch nothing is visited and the check passes vacuously.
            for _, func in expr.functions.items():
                if isinstance(func, tvm.relax.Function):
                    self.visit_expr(func)
        elif isinstance(expr, tvm.relax.Expr):
            self.visit_expr(expr)
        elif isinstance(expr, tvm.relax.BindingBlock):
            self.visit_binding_block(expr)
        else:
            raise TypeError(f"Unsupported type for name check: {type(expr)}")
        assert len(self._missing_exprs) == 0, f"Missing {len(self._missing_exprs)} names"

    def visit_var_binding_(self, binding) -> None:
        """Visit the binding normally, then record its value if it has no name."""
        super().visit_var_binding_(binding)
        if not msc_utils.get_expr_name(binding.value):
            self._missing_exprs.append(binding.value)

    def visit_constant_(self, op) -> None:
        """Visit the constant normally, then record it if it has no name."""
        super().visit_constant_(op)
        if not msc_utils.get_expr_name(op):
            self._missing_exprs.append(op)

# Rebuild torchvision's ResNet-50 (default constructor arguments), trace it
# with torch.fx and import it into a relax module.
torch_model = torchvision.models.resnet50()
graph_model = fx.symbolic_trace(torch_model)
# A single input of shape [1, 3, 224, 224] and dtype float32.
input_info = [([1, 3, 224, 224], "float32")]
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
# Run the MSC naming pass, then check that names were attached.
mod = msc_transform.SetExprName()(mod)
RelaxNameChecker().check(mod)

In [None]:
from torchvision.models import vgg