Introduction to MSCGraph#
%cd ..
from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
/media/pc/data/lxw/ai/tvm-book/doc/read/msc
MSCGraph is the core of MSC; it plays a role for the compiler similar to an IR (intermediate representation). An MSCGraph is a DAG (directed acyclic graph) form of Relax. The goal of building MSCGraph is to make it easier to develop compression algorithms and to manage weights (which matters during training). If the chosen runtime target does not support all of the Calls, the Relax/Relay module will own multiple MSCGraphs. MSCGraphs exist throughout the compilation stages and are used to manage the model's computation information.
%%file graph/model.py
import numpy as np
import torch
from torch import fx
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 1, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, data):
        x = self.conv(data)
        return self.relu(x)


def get_model(input_info):
    # Translate the frontend model into an IRModule
    with torch.no_grad():
        torch_fx_model = fx.symbolic_trace(M())
        mod = from_fx(torch_fx_model, input_info, keep_params_as_input=False)
        # mod, params = relax.frontend.detach_params(mod)
    return mod, torch_fx_model
Overwriting graph/model.py
Building the MSC Graph#
from tvm.contrib.msc.core.frontend import translate
Build the MSC graph from relax:
from graph.model import get_model
input_info = [((1, 3, 4, 4), "float32")]  # given input shape and dtype
mod, _ = get_model(input_info)
mod.show()
graph, weights = translate.from_relax(mod)
print(graph)
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 4, 4), dtype="float32")) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(inp_0, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.relu(lv)
            gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
main <INPUTS: inp_0:0| OUTPUTS: relu:0>
N_0 inp_0 <PARENTS: | CHILDERN: conv2d>
OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
OPTYPE: input
N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
OUT: conv2d:0<1,6,4,4|float32|NCHW>
OPTYPE: nn.conv2d
SCOPE: block
ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW
WEIGHTS:
weight: const<6,3,1,1|float32|OIHW>
N_2 relu <PARENTS: conv2d| CHILDERN: >
IN: conv2d:0<1,6,4,4|float32|NCHW>
OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
OPTYPE: nn.relu
SCOPE: block
type(graph)
tvm.contrib.msc.core.ir.graph.MSCGraph
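Besides the graph itself, translate.from_relax also returns the extracted weights. A minimal sketch for inspecting them, assuming weights behaves as a mapping from MSC weight name to tvm.nd.NDArray (the exact container type is an assumption, not shown above):
for name, data in weights.items():  # assumed: name -> NDArray mapping
    print(f"{name}: shape={data.shape}, dtype={data.dtype}")
For the model above this should show a single convolution weight of shape (6, 3, 1, 1).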
Export the serialized file, from which the graph can be loaded:
print(graph.to_json())
{
"name": "main",
"inputs": [
"inp_0:0"
],
"outputs": [
"relu:0"
],
"nodes": [
{
"index": 0,
"name": "inp_0",
"shared_ref": "",
"optype": "input",
"parents": [],
"inputs": [],
"outputs": [
{
"name": "inp_0:0",
"alias": "inp_0",
"dtype": "float32",
"layout": "NCHW",
"shape": [1, 3, 4, 4],
"prims": []
}
],
"attrs": {},
"weights": {}
},
{
"index": 1,
"name": "conv2d",
"shared_ref": "",
"optype": "nn.conv2d",
"parents": [
"inp_0"
],
"inputs": [
"inp_0:0"
],
"outputs": [
{
"name": "conv2d:0",
"alias": "",
"dtype": "float32",
"layout": "NCHW",
"shape": [1, 6, 4, 4],
"prims": []
}
],
"attrs": {
"out_layout": "NCHW",
"data_layout": "NCHW",
"padding": "0,0,0,0",
"groups": "1",
"kernel_layout": "OIHW",
"strides": "1,1",
"dilation": "1,1",
"out_dtype": "float32"
},
"weights": {"weight": {
"name": "const",
"alias": "",
"dtype": "float32",
"layout": "OIHW",
"shape": [6, 3, 1, 1],
"prims": []}}
},
{
"index": 2,
"name": "relu",
"shared_ref": "",
"optype": "nn.relu",
"parents": [
"conv2d"
],
"inputs": [
"conv2d:0"
],
"outputs": [
{
"name": "relu:0",
"alias": "relu",
"dtype": "float32",
"layout": "NCHW",
"shape": [1, 6, 4, 4],
"prims": []
}
],
"attrs": {},
"weights": {}
}
],
"prims": []
}
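Since to_json emits plain JSON, the exported graph can also be inspected with standard tooling. A minimal sketch that lists node names and op types from the string above:
import json

graph_dict = json.loads(graph.to_json())
for n in graph_dict["nodes"]:
    print(n["index"], n["name"], n["optype"])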
Export a prototxt file for visualization:
print(graph.visualize(f"{temp_dir}/graph.prototxt"))
name: "main"
layer {
name: "inp_0"
type: "input"
top: "inp_0"
layer_param {
idx: 0
output_0: "inp_0:0(inp_0)<1,3,4,4|float32|NCHW>"
}
}
layer {
name: "conv2d"
type: "nn_conv2d"
top: "conv2d"
bottom: "inp_0"
layer_param {
out_layout: "NCHW"
out_dtype: "float32"
groups: "1"
kernel_layout: "OIHW"
param_weight: "const<6,3,1,1|float32|OIHW>"
strides: "1,1"
idx: 1
padding: "0,0,0,0"
data_layout: "NCHW"
dilation: "1,1"
output_0: "conv2d:0<1,6,4,4|float32|NCHW>"
input_0: "inp_0:0(inp_0)<1,3,4,4|float32|NCHW>"
}
}
layer {
name: "relu"
type: "nn_relu"
top: "relu"
bottom: "conv2d"
layer_param {
idx: 2
input_0: "conv2d:0<1,6,4,4|float32|NCHW>"
output_0: "relu:0(relu)<1,6,4,4|float32|NCHW>"
}
}
By converting the relax/relay structure into an MSCGraph, nodes can be looked up and traversed in DAG style:
for node in graph.get_nodes():
    if node.optype == "nn.conv2d":
        print(f"conv2d node: {node}")
for i in graph.get_inputs():
    print(f"graph input: {i}")
node = graph.find_node("conv2d")
print(f"core conv node: {node}")
for p in node.parents:
    print(f"parent node: {p}")
for c in node.children:
    print(f"child node: {c}")
conv2d node: N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
OUT: conv2d:0<1,6,4,4|float32|NCHW>
OPTYPE: nn.conv2d
SCOPE: block
ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW
WEIGHTS:
weight: const<6,3,1,1|float32|OIHW>
graph input: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
core conv node: N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
OUT: conv2d:0<1,6,4,4|float32|NCHW>
OPTYPE: nn.conv2d
SCOPE: block
ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW
WEIGHTS:
weight: const<6,3,1,1|float32|OIHW>
parent node: N_0 inp_0 <PARENTS: | CHILDERN: conv2d>
OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
OPTYPE: input
child node: N_2 relu <PARENTS: conv2d| CHILDERN: >
IN: conv2d:0<1,6,4,4|float32|NCHW>
OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
OPTYPE: nn.relu
SCOPE: block
The basic structure of an MSCGraph is similar to ONNX: one MSCGraph contains multiple MSCJoints (compute nodes) and MSCTensors (data).
MSCJoint#
An MSCJoint plays the same role as relax.Expr, torch.nn.Module, onnx.Node, etc.: it is the smallest unit expressing compute logic in a graph. One MSCJoint corresponds to one relax.Expr or Function (when patterns are used to partition the graph, e.g. conv2d + bias fused into conv2d_bias). The difference between an MSCJoint and an Expr is that an MSCJoint carries more topology information: besides the inputs that correspond to relax.Call.args, it also records parents, children, and outputs, which makes it easier to query topological relations.
Below is an example description of an MSCJoint:
node = graph.find_node("conv2d")
type(node), node
(tvm.contrib.msc.core.ir.graph.MSCJoint,
N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
OUT: conv2d:0<1,6,4,4|float32|NCHW>
OPTYPE: nn.conv2d
SCOPE: block
ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW
WEIGHTS:
weight: const<6,3,1,1|float32|OIHW>)
It contains several important members; a small access sketch follows the list:
- ID (ID_1): the node's index, also the ordering value used when traversing the graph
- NAME (conv2d): the node name; every node has a unique name, which is used to look the node up
- PARENTS/CHILDREN (inp_0 / relu): the node's topological relations
- ATTRS: node attributes; for backward compatibility they are all stored as strings and cast to the proper types at codegen time
- IN/OUT: node inputs and outputs; every node has one or more outputs, each of which is an MSCTensor, while the inputs are references to the outputs of the parents
- WEIGHTS: node weights; each weight is a ref: MSCTensor pair, where ref denotes the weight type, such as "weight", "bias", "gamma". The ref is defined because model compression handles different weight types differently, so the weights need to be classified
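These members can be read directly from the node object. A minimal sketch using the conv2d node above; optype, parents, children, and outputs are accessors already used in this notebook, while reading name as a plain attribute is an assumption:
node = graph.find_node("conv2d")
print(node.name, node.optype)            # NAME / OPTYPE (name as attribute is assumed)
print([p.name for p in node.parents])    # PARENTS
print([c.name for c in node.children])   # CHILDREN
print([o.name for o in node.outputs])    # OUT: output tensors (each an MSCTensor)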
MSCTensor#
This data structure has no direct counterpart in Relax; it can be thought of as an abstraction of an NDArray. An MSCTensor describes the information of each node's outputs and weights. Through an MSCTensor you can find its producer and consumers, which makes it easy to obtain context when operating on a tensor. The MSCTensor format follows the design of TensorFlow's tensor. Below is the description of one MSCTensor, with its important attributes listed after the cell (a small access sketch follows the list):
type(node.outputs[0]), node.outputs
(tvm.contrib.msc.core.ir.graph.MSCTensor, [conv2d:0<1,6,4,4|float32|NCHW>])
- Name (conv2d:0): the tensor name, in the node_name:index format used by TensorFlow; with this naming the producer of a tensor can be looked up
- Shape (1,6,4,4): the tensor shape; dynamic dimensions are represented as -1
- Dtype (float32): the tensor data type
- Layout (NCHW): an attribute added in MSC describing the tensor's data layout; it is important during pruning and also serves as a reference when optimizing some compute ops
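A minimal sketch reading these attributes from the conv2d output tensor; output_at and name are used in the next cell as well, while accessing shape, dtype, and layout as plain attributes is an assumption:
t = node.output_at(0)
print(t.name)    # Name, in node_name:index format
print(t.shape)   # Shape (assumed attribute)
print(t.dtype)   # Dtype (assumed attribute)
print(t.layout)  # Layout (assumed attribute)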
output = node.output_at(0)
print(f"producer {graph.find_producer(output.name)}")
for c in graph.find_consumers(output.name):
    print(f"has consumer {c}")
producer N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>
IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>
OUT: conv2d:0<1,6,4,4|float32|NCHW>
OPTYPE: nn.conv2d
SCOPE: block
ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW
WEIGHTS:
weight: const<6,3,1,1|float32|OIHW>
has consumer N_2 relu <PARENTS: conv2d| CHILDERN: >
IN: conv2d:0<1,6,4,4|float32|NCHW>
OUT: relu:0(relu)<1,6,4,4|float32|NCHW>
OPTYPE: nn.relu
SCOPE: block