Computational graph#

MSCGraph is the core of MSC; it plays a role for the compiler similar to an IR (intermediate representation). An MSCGraph is a DAG (directed acyclic graph) form of a Relax.Function/Relay.Function and can be converted to and from Relax/Relay. The goal of building MSCGraph is to make it easier to develop compression algorithms and to manage weights (which matters during training). If the chosen runtime target does not support all Calls, one Relax/Relay module may own multiple MSCGraphs. MSCGraphs exist throughout the compilation process and are used to manage the model's compute information.

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/tutorials
import numpy as np
import torch
from torch import fx
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 1, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, data):
        x = self.conv(data)
        return self.relu(x)

# Give the input shape and data type
input_info = [((1, 3, 4, 4), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(M())
    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=False)

# mod, params = relax.frontend.detach_params(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 4, 4), dtype="float32")) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(inp_0, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.relu(lv)
            gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
from tvm.contrib.msc.core.frontend import translate

Build the MSC graph from Relax:

graph, weights = translate.from_relax(mod)
print(graph)
main <INPUTS: inp_0:0| OUTPUTS: relu:0>
ID_0 inp_0 <PARENTS: | CHILDERN: conv2d>
  OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OPTYPE: input

ID_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

ID_2 relu <PARENTS: conv2d| CHILDERN: >
  IN: conv2d:0<1,6,4,4|float32|NCHW>
  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
  OPTYPE: nn.relu
  SCOPE: block
type(graph)
tvm.contrib.msc.core.ir.graph.MSCGraph
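
translate.from_relax also returns the weights collected from the module. A quick look at them (a minimal sketch, assuming the returned mapping behaves like a dict of name -> tvm.nd.NDArray, which is how the codegen below consumes it):

for name, value in weights.items():
    # e.g. const (6, 3, 1, 1) float32
    print(name, value.shape, value.dtype)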

Export the serialized form so the graph can be reloaded:

print(graph.to_json())
{
  "name": "main", 
  "inputs": [
    "inp_0:0"
  ], 
  "outputs": [
    "relu:0"
  ], 
  "nodes": [
    {
      "index": 0, 
      "name": "inp_0", 
      "shared_ref": "", 
      "optype": "input", 
      "parents": [], 
      "inputs": [], 
      "outputs": [
        {
          "name": "inp_0:0", 
          "alias": "inp_0", 
          "dtype": "float32", 
          "layout": "NCHW", 
          "shape": [1, 3, 4, 4]
        }
      ], 
      "attrs": {}, 
      "weights": {}
    }, 
    {
      "index": 1, 
      "name": "conv2d", 
      "shared_ref": "", 
      "optype": "nn.conv2d", 
      "parents": [
        "inp_0"
      ], 
      "inputs": [
        "inp_0:0"
      ], 
      "outputs": [
        {
          "name": "conv2d:0", 
          "alias": "", 
          "dtype": "float32", 
          "layout": "NCHW", 
          "shape": [1, 6, 4, 4]
        }
      ], 
      "attrs": {
        "out_layout": "NCHW", 
        "data_layout": "NCHW", 
        "padding": "0,0,0,0", 
        "groups": "1", 
        "kernel_layout": "OIHW", 
        "strides": "1,1", 
        "dilation": "1,1", 
        "out_dtype": "float32"
      }, 
      "weights": {"weight": {
          "name": "const", 
          "alias": "", 
          "dtype": "float32", 
          "layout": "OIHW", 
          "shape": [6, 3, 1, 1]}}
    }, 
    {
      "index": 2, 
      "name": "relu", 
      "shared_ref": "", 
      "optype": "nn.relu", 
      "parents": [
        "conv2d"
      ], 
      "inputs": [
        "conv2d:0"
      ], 
      "outputs": [
        {
          "name": "relu:0", 
          "alias": "relu", 
          "dtype": "float32", 
          "layout": "NCHW", 
          "shape": [1, 6, 4, 4]
        }
      ], 
      "attrs": {}, 
      "weights": {}
    }
  ]
}
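
The export is plain JSON, so it can also be inspected with the standard library (a minimal sketch; this only parses the text and is not the MSC loader itself):

import json

graph_json = json.loads(graph.to_json())
print(graph_json["name"])  # main
for node in graph_json["nodes"]:
    print(node["index"], node["name"], node["optype"])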

Export a prototxt file for visualization:

print(graph.visualize(".temp/graph.prototxt"))
name: "main"
layer {
  name: "inp_0"
  type: "input"
  top: "inp_0"
  layer_param {
    idx: 0
    output_0: "inp_0:0(inp_0)<1,3,4,4|float32|NCHW>"
  }
}
layer {
  name: "conv2d"
  type: "nn_conv2d"
  top: "conv2d"
  bottom: "inp_0"
  layer_param {
    out_layout: "NCHW"
    out_dtype: "float32"
    groups: "1"
    kernel_layout: "OIHW"
    param_weight: "const<6,3,1,1|float32|OIHW>"
    strides: "1,1"
    idx: 1
    padding: "0,0,0,0"
    data_layout: "NCHW"
    dilation: "1,1"
    output_0: "conv2d:0<1,6,4,4|float32|NCHW>"
    input_0: "inp_0:0(inp_0)<1,3,4,4|float32|NCHW>"
  }
}
layer {
  name: "relu"
  type: "nn_relu"
  top: "relu"
  bottom: "conv2d"
  layer_param {
    idx: 2
    input_0: "conv2d:0<1,6,4,4|float32|NCHW>"
    output_0: "relu:0(relu)<1,6,4,4|float32|NCHW>"
  }
}

By converting the relax/relay structure into an MSCGraph, nodes can be looked up and traversed with DAG-style methods:

for node in graph.get_nodes():
    if node.optype == "nn.conv2d":
        print(f"conv2d node {node}")

for i in graph.get_inputs():
    print(f"input node {i}")

node = graph.find_node("conv2d")
print(f"core conv node {node}")
for p in node.parents:
    print(f"parent node {p}")
for c in node.children:
    print(f"child node {c}")
conv2d node ID_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

input node inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
core conv node ID_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

parent node ID_0 inp_0 <PARENTS: | CHILDERN: conv2d>
  OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OPTYPE: input

child node ID_2 relu <PARENTS: conv2d| CHILDERN: >
  IN: conv2d:0<1,6,4,4|float32|NCHW>
  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
  OPTYPE: nn.relu
  SCOPE: block

The basic structure of an MSCGraph is similar to ONNX: one MSCGraph contains multiple MSCJoints (compute nodes) and MSCTensors (data).
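
A compact way to see this node/tensor duality is to walk the graph and print each MSCJoint together with its output MSCTensors (a minimal sketch; the name property on nodes is an assumption, the other accessors are the ones used above):

for node in graph.get_nodes():
    # every MSCJoint owns one or more MSCTensor outputs
    print(node.name, node.optype, "->", [o.name for o in node.outputs])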

MSCJoint#

An MSCJoint plays the same role as relax.Expr, torch.nn.Module, onnx.Node, etc.: it is the smallest unit expressing compute logic in a graph. One MSCJoint corresponds to one relax.Expr or Function (when patterns are used to partition the graph, e.g. conv2d+bias fused into conv2d_bias). The difference between MSCJoint and Expr is that MSCJoint carries more topological information: besides the inputs corresponding to relax.Call.args, it also records parents, children, and outputs, which makes it easier to retrieve topological relations. The following describes one MSCJoint:

node = graph.find_node("conv2d")
type(node), node
(tvm.contrib.msc.core.ir.graph.MSCJoint,
 ID_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
   IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
   OUT: conv2d:0<1,6,4,4|float32|NCHW>
   OPTYPE: nn.conv2d
   SCOPE: block
   ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
   WEIGHTS: 
     weight: const<6,3,1,1|float32|OIHW>)

It contains several important members (illustrated by the sketch after this list):

  • ID (ID_1): the node's index, which is also the ordering value used when traversing the graph

  • NAME (conv2d): the node name; every node has a unique name, used to look up nodes

  • PARENTS/CHILDREN (inp_0/relu): the node's topological relations

  • ATTRS: node attributes; for backward compatibility they are all stored as strings and cast to the corresponding types at codegen time

  • IN/OUT: the node's inputs and outputs; each node has one or more outputs, each of which is an MSCTensor, while inputs are references to the outputs of the node's parents

  • WEIGHTS: the node's weights; each weight is a ref:MSCTensor pair, where ref denotes the weight type, e.g. "weight", "bias", "gamma". The ref is introduced because model compression handles different weight types differently, so weights need to be classified
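
The sketch below reads these members from the conv2d node. The parents/children/outputs accessors are the ones used earlier; the attribute and weight accessors (get_attrs, get_weights) are assumptions about the Python API and may differ between TVM versions:

node = graph.find_node("conv2d")
print(node.name, node.optype)                   # unique name and op type
print([p.name for p in node.parents])           # topological parents
print([c.name for c in node.children])          # topological children
print([o.name for o in node.outputs])           # outputs, each an MSCTensor
# ATTRS are kept as strings and cast to typed values at codegen time
print(node.get_attrs())                         # assumed accessor
# WEIGHTS are ref -> MSCTensor pairs ("weight", "bias", "gamma", ...)
for ref, weight in node.get_weights().items():  # assumed accessor
    print(ref, weight)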

MSCTensor#

This data structure has no direct counterpart in relax; it can be understood as an abstraction over NDArray. An MSCTensor describes the information of each node's outputs and weights. Through an MSCTensor the producer and consumers can be found, which makes it convenient to obtain context when manipulating a tensor. The MSCTensor format is modeled on tensorflow's tensor. The following describes one MSCTensor, with several important attributes:

type(node.outputs[0]), node.outputs
(tvm.contrib.msc.core.ir.graph.MSCTensor, [conv2d:0<1,6,4,4|float32|NCHW>])
  • Name (conv2d:0): the tensor name, in node_name:index format, the same convention as tensorflow. With this naming the tensor's producer can be looked up

  • Shape (1,6,4,4): the tensor shape; dynamic dimensions are represented as -1

  • Dtype (float32): the tensor's data type

  • Layout (NCHW): an attribute newly added in MSC, the data layout of the tensor. It matters during pruning and also serves as a reference when optimizing some compute ops (see the sketch after this list)
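
A small access sketch for these attributes on the conv2d output (output_at and name are used elsewhere on this page; treating dtype, layout and shape as plain properties is an assumption):

out = node.output_at(0)
print(out.name)     # "conv2d:0", node_name:index, the tensorflow convention
print(out.dtype)    # float32
print(out.layout)   # NCHW
print(out.shape)    # [1, 6, 4, 4]; a dynamic dim would show as -1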

output = node.output_at(0)

print(f"producer {graph.find_producer(output.name)}")
for c in graph.find_consumers(output.name):
    print(f"has consumer {c}")
producer ID_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

has consumer ID_2 relu <PARENTS: conv2d| CHILDERN: >
  IN: conv2d:0<1,6,4,4|float32|NCHW>
  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
  OPTYPE: nn.relu
  SCOPE: block

Codegen module#

The codegen module is used together with MSCGraph to translate an MSCGraph into a Python or C++ script.

from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
from tvm.contrib.msc.framework.torch import codegen as torch_codegen
from tvm.contrib.msc.framework.torch.frontend import translate as torch_translate
from tvm.contrib.msc.core import utils as msc_utils

build_folder = msc_utils.msc_dir(".temp/tvm_test")
# translate the MSCGraph (plus weights) back into a Relax IRModule
mod = tvm_codegen.to_relax(graph, weights, build_folder=build_folder)
mod.show()

build_folder = msc_utils.msc_dir(".temp/torch_test")
# rebuild the MSCGraph from the torch fx model, then generate a torch.nn.Module from it
graph, weights = torch_translate.from_torch(torch_fx_model, input_info, via_relax=True)
model = torch_codegen.to_torch(graph, weights, build_folder=build_folder)
model
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 4, 4), dtype="float32")) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
        with R.dataflow():
            conv2d: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(inp_0, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            relu: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.relu(conv2d)
            R.output(relu)
        return relu

# Metadata omitted. Use show_meta=True in script() method to show it.
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(folder.relpath(graph.name + ".pth"))
main(
  (conv2d): Conv2d(3, 6, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (relu): ReLU()
)
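
Before looking at the generated files, a quick sanity check that the code-generated torch module matches the traced model (a sketch; it assumes to_torch has already bound the translated weights into model, as the torch.load call in the warning above suggests):

data = torch.randn(1, 3, 4, 4)
with torch.no_grad():
    expected = torch_fx_model(data)
    got = model(data)
np.testing.assert_allclose(expected.numpy(), got.numpy(), rtol=1e-5, atol=1e-5)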

The generated code snippets:

# .temp/tvm_test/main.py
import os
import numpy as np
from typing import List, Dict, Any
import tvm
from tvm.contrib.msc.core import utils as msc_utils
from tvm import relax

# Define the helpers
def load_data(name: str, shape: List[int], dtype: str) -> np.ndarray:
  path = os.path.join("baseline", name + ".bin")
  if os.path.isfile(path):
    data = np.fromfile(path, dtype=dtype).reshape(shape)
  else:
    data = np.ones((shape)).astype(dtype)
  return data


# Define the graph
def main(res_0: relax.Var) -> tvm.IRModule:
  inputs = [res_0]
  # Define the weights
  weight_1 = relax.Var("const", relax.TensorStructInfo([6, 3, 1, 1], "float32"))
  inputs.append(weight_1)
  # Define the module
  block_builder = relax.BlockBuilder()
  with block_builder.function(name="main", params=inputs.copy()):
    # conv2d(nn.conv2d): <res_0> -> <res_1>
    res_1 = relax.op.nn.conv2d(res_0, weight_1, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
    res_1 = block_builder.emit(res_1, name_hint="conv2d")
    # relu(nn.relu): <res_1> -> <res_2>
    res_2 = relax.op.nn.relu(res_1)
    res_2 = block_builder.emit(res_2, name_hint="relu")
    # Emit the outputs
    block_builder.emit_func_output(res_2)
  mod = block_builder.finalize()
  return mod


# Define the test
if __name__ == "__main__":
  # Prepare test datas
  inputs = {}
  golden = {}
  inputs["inp_0"] = load_data("inp_0", [1, 3, 4, 4], "float32")
  golden["relu"] = load_data("relu", [1, 6, 4, 4], "float32")
  # Build and inference the graph
  res_0 = relax.Var("inp_0", relax.TensorStructInfo([1, 3, 4, 4], "float32"))
  # Build Module
  mod = main(res_0)
  # Load weights
  with open("main_params.bin", "rb") as f:
    params = tvm.runtime.load_param_dict(f.read())
  bind_params = tvm.relax.transform.BindParams("main", params)
  mod = bind_params(mod)
  target = tvm.target.Target("llvm")
  mod = tvm.relax.transform.LegalizeOps()(mod)
  with tvm.transform.PassContext(opt_level=3):
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
  f_main = vm["main"]
  outputs = f_main(inputs["inp_0"])
  msc_utils.compare_arrays(golden, outputs, verbose="detail")
# .temp/torch_test/main.py
import os
import numpy as np
from typing import List, Dict, Any
import tvm
from tvm.contrib.msc.core import utils as msc_utils
import torch
from torch import nn
from torch.nn import functional

# Define the helpers
def load_data(name: str, shape: List[int], dtype: str) -> np.ndarray:
  path = os.path.join("baseline", name + ".bin")
  if os.path.isfile(path):
    data = np.fromfile(path, dtype=dtype).reshape(shape)
  else:
    data = np.ones((shape)).astype(dtype)
  return data


# Define the graph
class main(torch.nn.Module):
  def __init__(self: torch.nn.Module) -> torch.nn.Module:
    super(main, self).__init__()
    # conv2d(nn.conv2d): <res_0> -> <res_1>
    self.conv2d = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=[1, 1], stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1, bias=False)
    # relu(nn.relu): <res_1> -> <res_2>
    self.relu = nn.ReLU()

  def forward(self: torch.nn.Module, res_0: torch.Tensor) -> List[torch.Tensor]:
    # conv2d(nn.conv2d): <res_0> -> <res_1>
    res_1 = self.conv2d(res_0)
    # relu(nn.relu): <res_1> -> <res_2>
    res_2 = self.relu(res_1)
    outputs = res_2
    return outputs


# Define the test
if __name__ == "__main__":
  # Prepare test datas
  inputs = {}
  golden = {}
  inputs["inp_0"] = load_data("inp_0", [1, 3, 4, 4], "float32")
  golden["relu"] = load_data("relu", [1, 6, 4, 4], "float32")
  # Build and inference the graph
  # Build Model
  model = main()
  # Load weights
  weights = torch.load("main.pth")
  model.load_state_dict(weights)
  res_0 = torch.from_numpy(inputs["inp_0"])
  outputs = model(res_0)
  msc_utils.compare_arrays(golden, outputs, verbose="detail")

Summary#

The MSC architecture is designed to solve the problem of jointly optimizing across multiple systems. Its technical approach borrows compiler ideas: computational graphs are built in different systems on top of a unified intermediate description. At the same time, MSC separates compute logic from compression algorithms; by decoupling the compression algorithms it builds a general model-compression platform.

Parser + MSCGraph + Codegen form the path along which information travels between frameworks in MSC. The core is MSCGraph as the carrier of compute information; its design resembles common DAG-style IR formats, with MSCJoint representing compute nodes and MSCTensor representing data.