Interpreting GraphExecutorCodegen
A two-headed network as an introduction
Create a small network with two outputs:
import numpy as np
import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name
x = relay.var("x", shape=(1, 1, 8, 8), dtype="int8")
w = relay.var("w", shape=(2, 1, 3, 3), dtype="int8")
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
mod = tvm.IRModule.from_expr(relay.Tuple([conv2d, relu]))
mod["main"] = bind_params_by_name(mod["main"],
{"w": tvm.nd.array(np.ones(shape=(2, 1, 3, 3),
dtype="int8"))})
rt_lib = relay.build(mod, target="llvm")
rt_lib.params.keys(), rt_lib.params["p0"].shape, rt_lib.params["p0"].dtype
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
(dict_keys(['p0']), (2, 1, 3, 3), 'int8')
The network structure is as follows:
print(rt_lib.ir_mod)
def @main(%x: Tensor[(1, 1, 8, 8), int8]) {
%0 = nn.conv2d(%x, meta[relay.Constant][0], padding=[0, 0, 0, 0]);
%1 = nn.relu(%0);
(%0, %1)
}
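Before digging into the JSON, it helps to actually run what relay.build produced. Below is a small sketch (not part of the original walkthrough; dev and m are local names) that drives the graph executor with a dummy input and confirms the module exposes the two heads:
from tvm.contrib import graph_executor

dev = tvm.cpu()
m = graph_executor.GraphModule(rt_lib["default"](dev))
m.set_input("x", np.ones(shape=(1, 1, 8, 8), dtype="int8"))
m.run()
# Two heads: the raw conv2d result and the relu applied to it
print(m.get_num_outputs())                            # 2
print(m.get_output(0).shape, m.get_output(1).shape)   # (1, 2, 6, 6) (1, 2, 6, 6)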
Inspect the graph JSON:
import toml
bunch = eval(rt_lib.graph_json)
print(toml.dumps(bunch))
arg_nodes = [ 0, 1,]
heads = [ [ 2, 0, 0,], [ 3, 0, 0,],]
node_row_ptr = [ 0, 1, 2, 3, 4,]
[[nodes]]
op = "null"
name = "x"
inputs = []
[[nodes]]
op = "null"
name = "p0"
inputs = []
[[nodes]]
op = "tvm_op"
name = "tvmgen_default_fused_nn_conv2d"
inputs = [ [ 0, 0, 0,], [ 1, 0, 0,],]
[nodes.attrs]
num_outputs = "1"
num_inputs = "2"
flatten_data = "0"
func_name = "tvmgen_default_fused_nn_conv2d"
out_layout = ""
kernel_layout = "OIHW"
data_layout = "NCHW"
hash = "8f5bab575bcb83dc"
[[nodes]]
op = "tvm_op"
name = "tvmgen_default_fused_nn_relu"
inputs = [ [ 2, 0, 0,],]
[nodes.attrs]
num_outputs = "1"
num_inputs = "1"
flatten_data = "0"
func_name = "tvmgen_default_fused_nn_relu"
hash = "fd6e720bc47ba75c"
[attrs]
dltype = [ "list_str", [ "int8", "int8", "int8", "int8",],]
device_index = [ "list_int", [ 1, 1, 1, 1,],]
storage_id = [ "list_int", [ 0, 1, 2, 3,],]
shape = [ "list_shape", [ [ 1, 1, 8, 8,], [ 2, 1, 3, 3,], [ 1, 2, 6, 6,], [ 1, 2, 6, 6,],],]
Interpreting CreateGraphCodegenMod
Source code
The enum class of graph node types is defined as:
/*! \brief Node types */
enum GraphNodeType {
  kGraphNop,
  kGraphInputNode,
  kGraphOpNode,
};
A Python equivalent is:
from enum import Enum

class GraphNodeType(Enum):
    """Node type enum.

    Attrs:
        kGraphNop: non-operator node
        kGraphInputNode: an input node of the graph, i.e. a placeholder/variable/input, or a constant/param
        kGraphOpNode: operator node
    """
    kGraphNop: int = 0
    kGraphInputNode: int = 1
    kGraphOpNode: int = 2
The base node class is defined as follows:
/*! \brief Base Node class */
class GraphNode {
 public:
  GraphNode() {}
  virtual void Save(dmlc::JSONWriter* writer) const {}
  virtual void Load(dmlc::JSONReader* reader) {}
  virtual GraphNodeType Type() const { return kGraphNop; }
  virtual ~GraphNode() {}

 public:
  int num_outputs_{1};
  std::string name_;
  GraphAttrs attrs_;
};
A Python equivalent is as follows:
from typing import Any
from dataclasses import dataclass
from abc import ABC, abstractmethod

GraphAttrs = dict[str, Any]

@dataclass
class GraphNode(ABC):
    name: str
    attrs: GraphAttrs

    @abstractmethod
    def Save(self, writer) -> None:
        ...

    @abstractmethod
    def Load(self, reader) -> None:
        ...

    @abstractmethod
    def Type(self) -> GraphNodeType:
        return GraphNodeType.kGraphNop
The input node:
/*! \brief Input Node */
class GraphInputNode : public GraphNode {
 public:
  GraphInputNode() {}
  GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
    name_ = name;
    attrs_ = attrs;
  }

  GraphNodeType Type() const override { return kGraphInputNode; }

  void Save(dmlc::JSONWriter* writer) const override {
    const std::string op_name{"null"};
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_name);
    writer->WriteObjectKeyValue("name", this->name_);
    writer->WriteObjectKeyValue("inputs", std::list<int>());
    writer->EndObject();
  }

  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& attrs) {
    auto ptr = std::make_shared<GraphInputNode>(name, attrs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }
};
A Python version:
@dataclass
class GraphInputNode(GraphNode):
    inputs: list[int]

    def Type(self) -> GraphNodeType:
        return GraphNodeType.kGraphInputNode

    def Save(self, writer) -> None:
        bunch = {
            "op": "null",
            "name": self.name,
            "inputs": []
        }
        # write the dict to the writer handle
        ...

    def Load(self, reader) -> None:
        ...

    def make_node_ptr(self):
        # make_node(name, attrs)
        ...
Similarly, a Python version of the operator node class:
@dataclass
class GraphNodeRef:
    ident: int        # index of the referenced node
    index: int = 0    # output index of the referenced node
    version: int = 0  # reserved; always 0 in this graph JSON

@dataclass
class GraphOpNode(GraphNode):
    nd_attrs: GraphAttrs
    op_name: str
    inputs: list[GraphNodeRef]
    num_outputs: int = 1

    def __post_init__(self):
        self.attrs["func_name"] = self.op_name
        self.attrs["flatten_data"] = "0"
        self.attrs["num_inputs"] = str(len(self.inputs))
        self.attrs["num_outputs"] = str(self.num_outputs)

    def Type(self) -> GraphNodeType:
        return GraphNodeType.kGraphOpNode

    def Save(self, writer) -> None:
        bunch = {
            "op": "tvm_op",
            "name": self.name,
            "attrs": self.attrs,
            "inputs": self.inputs
        }
        # write the dict to the writer handle
        ...

    def Load(self, reader) -> None:
        ...

    def make_node_ptr(self):
        # make_node(name, nd_attrs, op_name, inputs, attrs, num_outputs)
        ...
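To tie the Python classes together, here is an illustrative sketch (not code from the original codegen) that rebuilds the four nodes of the two-headed network with GraphInputNode / GraphOpNode; GraphNodeRef references producer nodes the same way the inputs field of the JSON does:
x_node = GraphInputNode(name="x", attrs={}, inputs=[])
p0_node = GraphInputNode(name="p0", attrs={}, inputs=[])
conv_node = GraphOpNode(
    name="tvmgen_default_fused_nn_conv2d", attrs={}, nd_attrs={},
    op_name="tvmgen_default_fused_nn_conv2d",
    inputs=[GraphNodeRef(0), GraphNodeRef(1)])  # consumes node 0 (x) and node 1 (p0)
relu_node = GraphOpNode(
    name="tvmgen_default_fused_nn_relu", attrs={}, nd_attrs={},
    op_name="tvmgen_default_fused_nn_relu",
    inputs=[GraphNodeRef(2)])                   # consumes node 2 (the conv2d)
print(conv_node.Type())   # GraphNodeType.kGraphOpNode
print(conv_node.attrs)    # filled by __post_init__: func_name, flatten_data, num_inputs, num_outputs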
Now to the main topic:
The code generator GraphExecutorCodegen
Code generator for the graph executor; it produces output containing the graph JSON, the runtime module, and the module's parameters.
@dataclass
class LoweredOutput:
    graph_json: str
    lowered_funcs: dict[str, tvm.IRModule]
    external_mods: list[tvm.IRModule]
    params: dict[str, tvm.runtime.NDArray]

@dataclass
class GraphExecutorCodegen:
    mod: tvm.runtime.Module
    targets: list[tvm.target.Target]

    def GetStorageInfo(self, expr) -> "tvm.relay.backend.StorageInfo":
        """Get the storage info of a single expression."""
        ...

    def Codegen(self, mod: tvm.IRModule,
                func: relay.Function,
                mod_name: str) -> "tvm.relay.backend.LoweredOutput":
        """
        1. Plan memory and update the workspace size before lowering.
        2. Get the lowered_main_func.
        3. Convert all parameters into input nodes.
        4. Collect any runtime modules generated by external codegen.
        5. Collect any constants extracted by external codegen.
        6. Collect any constants extracted during lowering.
        7. Separate the functions in the module by target.
        8. Save the graph JSON into the output.
        """
        ...
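The fields of LoweredOutput line up with what relay.build already handed back above; as a quick illustration (only reusing rt_lib from the earlier cell):
# graph_json  -> the JSON string consumed by the graph executor at runtime
print(type(rt_lib.graph_json))       # <class 'str'>
# lowered_funcs / external_mods -> compiled and packed into the runtime module
print(rt_lib.lib)
# params -> constants extracted during codegen ("p0" is the bound conv weight)
print(list(rt_lib.params.keys()))    # ['p0']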
Back to the two-headed network example
Let's now read the graph JSON fields carefully.
Since the two-headed network has two outputs,
heads = [ [ 2, 0, 0,], [ 3, 0, 0,],]
gives the indices of the two output nodes, while arg_nodes = [ 0, 1,]
marks the positions of the parameter (input) nodes.
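A quick check against the parsed bunch dict confirms this reading:
print([bunch["nodes"][nid]["name"] for nid, _, _ in bunch["heads"]])
# ['tvmgen_default_fused_nn_conv2d', 'tvmgen_default_fused_nn_relu']
print([bunch["nodes"][nid]["name"] for nid in bunch["arg_nodes"]])
# ['x', 'p0']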
A Python implementation:
from dataclasses import field

@dataclass
class GraphAttrs:
    """
    Args:
        dltype: the data type of each node, in order.
        device_index: the device assigned to each node, in order.
        storage_id: the memory slot id of each node in the storage layout; at runtime a parameter can be looked up by its storage_id.
        shape: the shape of each node (possibly of rank k).
    """
    dltype: list
    device_index: list
    storage_id: list
    shape: list

@dataclass
class GraphNodeAttrs:
    """
    Args:
        flatten_data: whether the data needs to be flattened before execution.
        func_name: the name of the fused function, corresponding to a symbol in the library generated by the Relay compilation process.
        num_inputs: the number of inputs of this node.
        num_outputs: the number of outputs this node produces.
    """
    func_name: str
    num_inputs: str
    num_outputs: str
    flatten_data: str = "0"
    hash: str | None = None

@dataclass
class GraphNode:
    """
    Args:
        op: the operation type; `null` means the node is a placeholder/variable/input, `tvm_op` means the node can be executed.
        name: the node name.
        inputs: the positions of this operation's inputs; a list of `(nodeid, index, version)` tuples (optional).
    """
    op: str
    name: str
    inputs: list[int] = field(default_factory=list)
    attrs: Any = None

@dataclass
class GraphJson:
    """
    Args:
        arg_nodes: indices of the parameter nodes, i.e. the placeholder/variable/input nodes of the graph, or constants/params.
        heads: indices of the output nodes.
        node_row_ptr: stores the history of the forward path, so that some operators can be skipped when building subgraphs for inference tasks.
        attrs: may contain version numbers or similar useful information.
        nodes: the nodes, which are either placeholders or computable nodes.
    """
    arg_nodes: list[int]
    heads: list[GraphNodeRef]
    node_row_ptr: list[int]
    attrs: GraphAttrs
    nodes: list[GraphNode]

    def __post_init__(self):
        self.heads = [GraphNodeRef(*head) for head in self.heads]
        self.attrs = GraphAttrs(**self.attrs)
        self.nodes = [GraphNode(**node) for node in self.nodes]
Note
The code is maintained in the tvm_book API.
from dataclasses import asdict
from tvm_book.tvm_utils.graph_json import GraphJson
from tvm_book.data.dataclass import TensorType
@dataclass
class Node:
    inputs: list[TensorType]
    outputs: list[TensorType]
    attrs: dict[str, Any]
graph_json = GraphJson(**eval(rt_lib.graph_json))
Convert it to a dict:
asdict(graph_json).keys()
dict_keys(['arg_nodes', 'heads', 'node_row_ptr', 'attrs', 'nodes'])
Other fields:
graph_json.heads
[GraphNodeRef(ident=2, index=0, version=0),
GraphNodeRef(ident=3, index=0, version=0)]
graph_json.attrs
GraphAttrs(dltype=['list_str', ['int8', 'int8', 'int8', 'int8']], device_index=['list_int', [1, 1, 1, 1]], storage_id=['list_int', [0, 1, 2, 3]], shape=['list_shape', [[1, 1, 8, 8], [2, 1, 3, 3], [1, 2, 6, 6], [1, 2, 6, 6]]])
graph_json.nodes
[GraphNode(op='null', name='x', inputs=[], attrs=None),
GraphNode(op='null', name='p0', inputs=[], attrs=None),
GraphNode(op='tvm_op', name='tvmgen_default_fused_nn_conv2d', inputs=[[0, 0, 0], [1, 0, 0]], attrs={'num_outputs': '1', 'num_inputs': '2', 'flatten_data': '0', 'func_name': 'tvmgen_default_fused_nn_conv2d', 'out_layout': '', 'kernel_layout': 'OIHW', 'data_layout': 'NCHW', 'hash': '8f5bab575bcb83dc'}),
GraphNode(op='tvm_op', name='tvmgen_default_fused_nn_relu', inputs=[[2, 0, 0]], attrs={'num_outputs': '1', 'num_inputs': '1', 'flatten_data': '0', 'func_name': 'tvmgen_default_fused_nn_relu', 'hash': 'fd6e720bc47ba75c'})]
graph_json.attrs.shape
['list_shape', [[1, 1, 8, 8], [2, 1, 3, 3], [1, 2, 6, 6], [1, 2, 6, 6]]]
attrs = []
dtypes = graph_json.attrs.dltype[1]
device_indexes = graph_json.attrs.device_index[1]
storage_ids = graph_json.attrs.storage_id[1]
shapes = graph_json.attrs.shape[1]
for shape, dtype, storage_id, device_index, node in zip(shapes, dtypes, storage_ids, device_indexes, graph_json.nodes):
    attr = {
        "storage_id": storage_id,
        "device_index": device_index,
        "inputs": node.inputs,
        "op": node.op,
        "op_type": TensorType(shape=shape, dtype=dtype, name=node.name),
    }
    if node.op == "tvm_op":  # merge the per-node attrs of executable nodes
        attr.update(**node.attrs)
    attrs.append(attr)
attrs
[{'storage_id': 0,
  'device_index': 1,
  'inputs': [],
  'op': 'null',
  'op_type': TensorType(shape=[1, 1, 8, 8], dtype='int8', name='x')},
 {'storage_id': 1,
  'device_index': 1,
  'inputs': [],
  'op': 'null',
  'op_type': TensorType(shape=[2, 1, 3, 3], dtype='int8', name='p0')},
 {'storage_id': 2,
  'device_index': 1,
  'inputs': [[0, 0, 0], [1, 0, 0]],
  'op': 'tvm_op',
  'op_type': TensorType(shape=[1, 2, 6, 6], dtype='int8', name='tvmgen_default_fused_nn_conv2d'),
  'num_outputs': '1',
  'num_inputs': '2',
  'flatten_data': '0',
  'func_name': 'tvmgen_default_fused_nn_conv2d',
  'out_layout': '',
  'kernel_layout': 'OIHW',
  'data_layout': 'NCHW',
  'hash': '8f5bab575bcb83dc'},
 {'storage_id': 3,
  'device_index': 1,
  'inputs': [[2, 0, 0]],
  'op': 'tvm_op',
  'op_type': TensorType(shape=[1, 2, 6, 6], dtype='int8', name='tvmgen_default_fused_nn_relu'),
  'num_outputs': '1',
  'num_inputs': '1',
  'flatten_data': '0',
  'func_name': 'tvmgen_default_fused_nn_relu',
  'hash': 'fd6e720bc47ba75c'}]
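As a closing illustration (a sketch layered on the variables defined above, not part of the original analysis), the storage_id / shape / dltype attributes are already enough to estimate the per-slot buffer sizes the graph executor has to allocate:
sizes = {}
for storage_id, shape, dtype in zip(storage_ids, shapes, dtypes):
    nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
    # nodes sharing a storage_id reuse the same slot, so keep the largest request
    sizes[storage_id] = max(sizes.get(storage_id, 0), nbytes)
print(sizes)   # {0: 64, 1: 18, 2: 72, 3: 72}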