
Introduction to MSCGraph

%cd ..
from pathlib import Path

temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
/media/pc/data/lxw/ai/tvm-book/doc/read/msc

MSCGraph is the core of MSC and plays a role for the compiler similar to an IR (intermediate representation). MSCGraph is a DAG (directed acyclic graph) format of Relax. The goal of building MSCGraph is to make developing compression algorithms and managing weights (which matters during training) easier. If the chosen runtime target does not support all of the Calls, a Relax/Relay module will own multiple MSCGraphs. MSCGraphs exist throughout the compilation stages and are used to manage the model's computation information.

%%file graph/model.py
import numpy as np
import torch
from torch import fx
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 1, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, data):
        x = self.conv(data)
        return self.relu(x)

def get_model(input_info):
    # Convert the frontend model into an IRModule
    with torch.no_grad():
        torch_fx_model = fx.symbolic_trace(M())
        mod = from_fx(torch_fx_model, input_info, keep_params_as_input=False)
    # mod, params = relax.frontend.detach_params(mod)
    return mod, torch_fx_model
Overwriting graph/model.py

Building the MSC Graph

from tvm.contrib.msc.core.frontend import translate

Build an MSC graph from Relax:

from graph.model import get_model
input_info = [((1, 3, 4, 4), "float32")] # specify the input shape and data type
mod, _ = get_model(input_info)
mod.show()
graph, weights = translate.from_relax(mod)
print(graph)
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 4, 4), dtype="float32")) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(inp_0, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.relu(lv)
            gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
main <INPUTS: inp_0:0| OUTPUTS: relu:0>
N_0 inp_0 <PARENTS: | CHILDERN: conv2d>
  OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OPTYPE: input

N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

N_2 relu <PARENTS: conv2d| CHILDERN: >
  IN: conv2d:0<1,6,4,4|float32|NCHW>
  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
  OPTYPE: nn.relu
  SCOPE: block
type(graph)
tvm.contrib.msc.core.ir.graph.MSCGraph

Export a serialized file for loading the graph:

print(graph.to_json())
{
  "name": "main", 
  "inputs": [
    "inp_0:0"
  ], 
  "outputs": [
    "relu:0"
  ], 
  "nodes": [
    {
      "index": 0, 
      "name": "inp_0", 
      "shared_ref": "", 
      "optype": "input", 
      "parents": [], 
      "inputs": [], 
      "outputs": [
        {
          "name": "inp_0:0", 
          "alias": "inp_0", 
          "dtype": "float32", 
          "layout": "NCHW", 
          "shape": [1, 3, 4, 4], 
          "prims": []
        }
      ], 
      "attrs": {}, 
      "weights": {}
    }, 
    {
      "index": 1, 
      "name": "conv2d", 
      "shared_ref": "", 
      "optype": "nn.conv2d", 
      "parents": [
        "inp_0"
      ], 
      "inputs": [
        "inp_0:0"
      ], 
      "outputs": [
        {
          "name": "conv2d:0", 
          "alias": "", 
          "dtype": "float32", 
          "layout": "NCHW", 
          "shape": [1, 6, 4, 4], 
          "prims": []
        }
      ], 
      "attrs": {
        "out_layout": "NCHW", 
        "data_layout": "NCHW", 
        "padding": "0,0,0,0", 
        "groups": "1", 
        "kernel_layout": "OIHW", 
        "strides": "1,1", 
        "dilation": "1,1", 
        "out_dtype": "float32"
      }, 
      "weights": {"weight": {
          "name": "const", 
          "alias": "", 
          "dtype": "float32", 
          "layout": "OIHW", 
          "shape": [6, 3, 1, 1], 
          "prims": []}}
    }, 
    {
      "index": 2, 
      "name": "relu", 
      "shared_ref": "", 
      "optype": "nn.relu", 
      "parents": [
        "conv2d"
      ], 
      "inputs": [
        "conv2d:0"
      ], 
      "outputs": [
        {
          "name": "relu:0", 
          "alias": "relu", 
          "dtype": "float32", 
          "layout": "NCHW", 
          "shape": [1, 6, 4, 4], 
          "prims": []
        }
      ], 
      "attrs": {}, 
      "weights": {}
    }
  ], 
  "prims": []
}
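
The JSON string can also be written to disk and loaded back. A minimal sketch, assuming a MSCGraph.from_json classmethod is available on tvm.contrib.msc.core.ir.graph.MSCGraph (verify against your TVM version):

from tvm.contrib.msc.core.ir.graph import MSCGraph

# Persist the serialized graph next to the other artifacts
graph_file = temp_dir / "graph.json"
graph_file.write_text(graph.to_json())

# Reload the graph; `from_json` is an assumed classmethod -- check your TVM build
loaded = MSCGraph.from_json(graph_file.read_text())
print(loaded.find_node("conv2d"))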

Export a prototxt file for visualization:

print(graph.visualize(f"{temp_dir}/graph.prototxt"))
name: "main"
layer {
  name: "inp_0"
  type: "input"
  top: "inp_0"
  layer_param {
    idx: 0
    output_0: "inp_0:0(inp_0)<1,3,4,4|float32|NCHW>"
  }
}
layer {
  name: "conv2d"
  type: "nn_conv2d"
  top: "conv2d"
  bottom: "inp_0"
  layer_param {
    out_layout: "NCHW"
    out_dtype: "float32"
    groups: "1"
    kernel_layout: "OIHW"
    param_weight: "const<6,3,1,1|float32|OIHW>"
    strides: "1,1"
    idx: 1
    padding: "0,0,0,0"
    data_layout: "NCHW"
    dilation: "1,1"
    output_0: "conv2d:0<1,6,4,4|float32|NCHW>"
    input_0: "inp_0:0(inp_0)<1,3,4,4|float32|NCHW>"
  }
}
layer {
  name: "relu"
  type: "nn_relu"
  top: "relu"
  bottom: "conv2d"
  layer_param {
    idx: 2
    input_0: "conv2d:0<1,6,4,4|float32|NCHW>"
    output_0: "relu:0(relu)<1,6,4,4|float32|NCHW>"
  }
}

By translating the relax/relay structure into an MSCGraph, nodes can be looked up and traversed in DAG style:

for node in graph.get_nodes():
    if node.optype == "nn.conv2d":
        print(f"conv2d node {node}")

for i in graph.get_inputs():
    print(f"input node {i}")

node = graph.find_node("conv2d")
print(f"core conv node {node}")
for p in node.parents:
    print(f"parent node {p}")
for c in node.children:
    print(f"child node {c}")
conv2d node N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

input node inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
core conv node N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

parent node N_0 inp_0 <PARENTS: | CHILDERN: conv2d>
  OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OPTYPE: input

child node N_2 relu <PARENTS: conv2d| CHILDERN: >
  IN: conv2d:0<1,6,4,4|float32|NCHW>
  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
  OPTYPE: nn.relu
  SCOPE: block

The basic structure of an MSCGraph is similar to ONNX: an MSCGraph contains multiple MSCJoints (compute nodes) and MSCTensors (data).
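
As a quick sketch of this structure, each MSCJoint can be listed together with its output MSCTensors (reusing the accessors shown above; reading node.name as a plain object field is assumed):

# Walk the DAG: each node is a MSCJoint, each node output a MSCTensor
for node in graph.get_nodes():
    out_names = ", ".join(str(o.name) for o in node.outputs)
    print(f"{node.name} ({node.optype}) -> {out_names}")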

MSCJoint

A MSCJoint plays the same role as relax.Expr, torch.nn.Module, onnx.Node, etc.: it is the smallest unit that expresses compute logic in a graph. One MSCJoint corresponds to one relax.Expr or Function (when patterns are used to partition the graph, e.g. fusing conv2d+bias into conv2d_bias). The difference between a MSCJoint and an Expr is that a MSCJoint carries more topological information: not only the inputs corresponding to relax.Call.args, but also parents, children, and outputs, which makes retrieving topological relations much easier.

Below is an example description of a MSCJoint:

node = graph.find_node("conv2d")
type(node), node
(tvm.contrib.msc.core.ir.graph.MSCJoint,
 N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
   IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
   OUT: conv2d:0<1,6,4,4|float32|NCHW>
   OPTYPE: nn.conv2d
   SCOPE: block
   ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
   WEIGHTS: 
     weight: const<6,3,1,1|float32|OIHW>)

It contains several important members (a usage sketch follows the list):

  • ID (N_1): the node index, which is also the ordering value used when traversing the graph

  • NAME (conv2d): the node name; each node has a unique name, used to look up nodes

  • PARENTS/CHILDREN (inp_0/relu): the node's topological relations

  • ATTRS: the node attributes; they are stored as strings for backward compatibility and cast to the corresponding types at codegen time

  • IN/OUT (inp_0:0/conv2d:0): the node's inputs and outputs; each node has one or more outputs, each of which is a MSCTensor, while the inputs are references to the parents' outputs

  • WEIGHTS: the node weights; each weight is a ref: MSCTensor pair, where ref denotes the weight type, such as "weight", "bias", or "gamma". The refs exist mainly because model compression operates differently on different weight types, so the weights need to be classified
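
A short sketch of accessing these members on the conv2d node. parents, children, and outputs appear elsewhere in this section; get_attr is an assumed accessor (verify against the MSCJoint API in your TVM build):

node = graph.find_node("conv2d")
# Topology: parents/children are resolved MSCJoint references, not just call args
print("parents: ", [str(p.name) for p in node.parents])
print("children:", [str(c.name) for c in node.children])
# Each output is a MSCTensor owned by the node
print("outputs: ", [str(o.name) for o in node.outputs])
# ATTRS are stored as strings and cast to typed values at codegen time;
# `get_attr` is an assumption here
print("strides: ", node.get_attr("strides"))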

MSCTensor

This data structure has no counterpart in relax; it can be understood as an abstraction of NDArray. A MSCTensor describes the information of each node's outputs and weights. Through a MSCTensor the producer and consumers can be found, which makes it convenient to obtain context when operating on a tensor. The MSCTensor format is modeled after TensorFlow's tensors. Below is the description of a MSCTensor; it contains several important attributes:

type(node.outputs[0]), node.outputs
(tvm.contrib.msc.core.ir.graph.MSCTensor, [conv2d:0<1,6,4,4|float32|NCHW>])
  • Name (conv2d:0): the tensor name, in node_name:index format, the same as TensorFlow; this notation makes it possible to find the tensor's producer

  • Shape (1,6,4,4): the tensor's shape; dynamic dimensions are represented as -1

  • Dtype (float32): the tensor's data type

  • Layout (NCHW): an attribute newly added in MSC, describing the tensor's data layout. It is important during pruning and also serves as a reference for optimizing some compute operators
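
A minimal sketch printing these attributes for the conv2d output. name and output_at are used elsewhere in this section; reading dtype, layout, and shape as plain object fields is an assumption (verify against your TVM build):

out = graph.find_node("conv2d").output_at(0)
print("name:  ", out.name)    # follows the node_name:index convention
print("dtype: ", out.dtype)   # assumed object field
print("layout:", out.layout)  # assumed object field
print("shape: ", [int(i) for i in out.shape])  # assumed object field; -1 marks dynamic dims

A MSCTensor also records its producer and consumers, which can be queried from the graph: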

output = node.output_at(0)

print(f"producer {graph.find_producer(output.name)}")
for c in graph.find_consumers(output.name):
    print(f"has consumer {c}")
producer N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
  OUT: conv2d:0<1,6,4,4|float32|NCHW>
  OPTYPE: nn.conv2d
  SCOPE: block
  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW 
  WEIGHTS: 
    weight: const<6,3,1,1|float32|OIHW>

has consumer N_2 relu <PARENTS: conv2d| CHILDERN: >
  IN: conv2d:0<1,6,4,4|float32|NCHW>
  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
  OPTYPE: nn.relu
  SCOPE: block