# TensorFlow 1 前端

参考: [TVM Tensorflow 前端](https://xinetzone.github.io/tvm/docs/arch/frontend/tensorflow.html)

下面以 [mobilenet_v2 float_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) 为例，展示 TensorFlow 前端的用法。

先运行简单的测试：

In [1]:
import os
import warnings
# Silence TensorFlow's C++ logging; set before `import tensorflow` so it takes effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
# Only show Python-side TensorFlow messages at ERROR level and above.
tf.get_logger().setLevel("ERROR")
warnings.simplefilter("ignore")
# Prefer the TF1-compatible API (`tf.compat.v1`); fall back to the top-level
# `tf` module when the compat namespace is unavailable (e.g. a TF1 install).
try:
    tf1 = tf.compat.v1
except (ImportError, AttributeError):
    tf1 = tf
import numpy as np
import set_env # load TVM (project-local helper; must come before the tvm imports)
import tvm.relay.testing.tf as tf_testing
import tvm
from tvm import relay

In [None]:
# Random NHWC input (batch 1, 224x224, 3 channels) shared by TF and TVM runs.
shape = 1, 224, 224, 3
data = np.random.uniform(size=shape).astype("float32")
# Output node and input placeholder names inside the frozen graph.
output_name = "MobilenetV2/Predictions/Reshape_1"
input_name = "input"
# TF feed dicts are keyed by tensor name ("<op name>:0").
input_dict = {f"{input_name}:0": data}
with tf.Graph().as_default():
    # Fetch the checkpoint archive and load the frozen .pb named by the
    # second argument.
    graph_def = tf_testing.get_workload(
        "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
        "mobilenet_v2_1.4_224_frozen.pb",
    )
    # Call the utility to import the graph definition into the default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    with tf1.Session() as sess:
        # Add shapes to the graph (annotated GraphDef replaces the previous one).
        graph_def = tf_testing.AddShapesToGraphDef(sess, output_name)
        # Get the TensorFlow reference result for the random input.
        out_tensor = sess.graph.get_tensor_by_name(f"{output_name}:0")
        tf_output = sess.run(out_tensor, input_dict)
        # TVM compilation: import into Relay with the NHWC input shape given
        # explicitly; returns the Relay module and its parameter dict.
        mod, params = relay.frontend.from_tensorflow(
            graph_def,
            shape={
                input_name: shape
            }
        )

## TensorFlow 数据布局变换

原始模型输入布局为 NHWC，可将其转换为 NCHW：

In [3]:
# Layouts requested per op: the data layout becomes NCHW, and the second
# entry ("default") leaves the remaining layout choice to ConvertLayout.
# 'image.resize2d' has no explicit entry here.
desired_layouts = {
    op_name: ["NCHW", "default"]
    for op_name in ("nn.conv2d", "nn.max_pool2d", "nn.avg_pool2d")
}

# Pipeline: first drop unused functions to tidy the module, then convert
# the listed ops to the NCHW layout.
seq = tvm.transform.Sequential(
    [
        relay.transform.RemoveUnusedFunctions(),
        relay.transform.ConvertLayout(desired_layouts),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)

In [None]:
from tvm.contrib import graph_executor

# Compile the layout-converted Relay module for the host CPU.
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
# `relay.build_config` is deprecated; use PassContext, as the other cells do.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)

# Run the compiled graph on the same NHWC input that was fed to TensorFlow.
m = graph_executor.GraphModule(lib["default"](dev))
m.set_input(input_name, data)
m.run()
tvm_output = [m.get_output(i).numpy() for i in range(m.get_num_outputs())]

# The TVM result must match the TensorFlow reference within tolerance.
np.testing.assert_allclose(
    np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-2, atol=1e-3
)

## 将 TF1 模型转换为 PyTorch 模型

简化模型：

In [5]:
# Simplify the model with relay.quantize's prerequisite optimization passes
# at opt_level 3, passing `params` along. Presumably this binds the params as
# constants and folds them; see the TVM docs for the exact pass list.
with tvm.transform.PassContext(opt_level=3):
    mod = relay.quantize.prerequisite_optimize(mod, params)

将模型输入由 NHWC 布局替换为 NCHW 布局：

In [71]:
from tvm.relay.dataflow_pattern import (
    wildcard, is_op, is_var,
    # FunctionPattern,
    DFPatternCallback,
    rewrite
)

class InputNHWC2NCHW(DFPatternCallback):
    """Replace an NHWC input variable that is immediately transformed to NCHW.

    Matches ``layout_transform(var)``. When the transform goes from ``NHWC``
    to ``NCHW`` and the variable is 4-D, the whole match is replaced by a new
    variable with the same name, dtype and span whose shape is already
    ``(N, C, H, W)``. The ``layout_transform`` disappears, so the rewritten
    graph must be fed NCHW data. Any other match is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.x = is_var()
        self.layout_transform = is_op("layout_transform")(self.x)
        self.pattern = self.layout_transform

    def callback(self, pre, post, node_map):
        """Return the NCHW replacement variable, or ``post`` when not applicable."""
        x = node_map[self.x][0]
        attrs = node_map[self.layout_transform][0].attrs
        # Check the layouts before touching the shape. Unpacking first would
        # raise for any matched layout_transform of a non-4-D variable
        # (e.g. NCHW4c -> NCHW) instead of simply skipping it.
        if attrs.src_layout != "NHWC" or attrs.dst_layout != "NCHW":
            return post
        inp_type = relay.transform.InferTypeLocal(x)
        if len(inp_type.shape) != 4:
            return post
        N, H, W, C = inp_type.shape
        return relay.var(
            x.name_hint, shape=(N, C, H, W), dtype=inp_type.dtype, span=x.span
        )

In [102]:
# Rewrite the body of `main`: the NHWC input + layout_transform pair becomes
# a single NCHW input variable (see InputNHWC2NCHW).
expr = rewrite(InputNHWC2NCHW(), mod["main"].body)
# Wrap the rewritten body in a fresh module and re-infer types in one step.
run_mod = relay.transform.InferType()(tvm.IRModule.from_expr(expr))
# Simplify again so the rewritten graph is folded like the original one.
with tvm.transform.PassContext(opt_level=3):
    run_mod = relay.quantize.prerequisite_optimize(run_mod, params)

验证数值一致性：

In [90]:
from tvm.contrib import graph_executor

# Compile the NCHW-input module and check it still matches the
# NHWC-input TVM result computed earlier (`tvm_output`).
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
# `relay.build_config` is deprecated; use PassContext, as the other cells do.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(run_mod, target, params=params)

m = graph_executor.GraphModule(lib["default"](dev))
# The rewritten graph expects NCHW, so transpose the NHWC input (0, 3, 1, 2).
m.set_input(input_name, data.transpose((0, 3, 1, 2)))
m.run()
new_tvm_output = [m.get_output(i).numpy() for i in range(m.get_num_outputs())]

np.testing.assert_allclose(
    np.squeeze(tvm_output[0]), np.squeeze(new_tvm_output[0]), rtol=1e-2, atol=1e-3
)

In [92]:
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
from tvm.contrib.msc.framework.torch import codegen

In [118]:
# Translate the NCHW Relay module (plus params, with opt_level 3) into an MSC
# graph and its weights.
# NOTE: `graph`/`weights` are reassigned by the Relax-based translation in a
# later cell.
graph, weights = translate.from_relay(run_mod, params, opt_config={"opt_level": 3})

In [119]:
from tvm.relax.testing import relay_translator

# Convert the Relay `main` function into a Relax module for the LLVM target.
# NOTE(review): the MSC graph printed from this module consists of call_tir
# nodes, i.e. the ops appear already lowered to TIR by this translation.
target = tvm.target.Target("llvm", host="llvm")
relax_mod = relay_translator.from_relay(run_mod["main"], target)



In [120]:
# Redundant with the earlier import of `translate`; harmless.
from tvm.contrib.msc.core.frontend import translate

# Translate the Relax module into an MSC graph + weights and dump the graph.
graph, weights = translate.from_relax(relax_mod)
print(graph)

main <INPUTS: input:0| OUTPUTS: call_tir_159:0>
ID_0 input <PARENTS: | CHILDERN: call_tir>
  OUT: input:0(input)<1,3,224,224|float32>
  OPTYPE: input

ID_1 call_tir <PARENTS: input| CHILDERN: call_tir_1>
  IN: input:0(input)<1,3,224,224|float32>
  OUT: call_tir:0<1,1,224,224,3|float32>
  OPTYPE: call_tir
  SCOPE: block

ID_2 const <PARENTS: | CHILDERN: call_tir_1>
  OUT: const:0<12,1,3,3,3,4|float32>
  OPTYPE: constant
  WEIGHTS: 
    const: const<12,1,3,3,3,4|float32>

ID_3 call_tir_1 <PARENTS: call_tir,const| CHILDERN: call_tir_2>
  IN: call_tir:0<1,1,224,224,3|float32>,const:0<12,1,3,3,3,4|float32>
  OUT: call_tir_1:0<1,12,112,112,4|float32>
  OPTYPE: call_tir
  SCOPE: block

ID_4 const_1 <PARENTS: | CHILDERN: call_tir_2>
  OUT: const_1:0<1,12,1,1,4|float32>
  OPTYPE: constant
  WEIGHTS: 
    const: const_1<1,12,1,1,4|float32>

ID_5 call_tir_2 <PARENTS: call_tir_1,const_1| CHILDERN: call_tir_3>
  IN: call_tir_1:0<1,12,112,112,4|float32>,const_1:0<1,12,1,1,4|float32>
  OUT: call_tir_

In [121]:
# Generate a PyTorch model from the MSC graph and weights.
# NOTE(review): this currently raises "Unsupported torch op(call_tir)" (see the
# traceback below). `graph` came from `translate.from_relax(relax_mod)`, whose
# nodes are call_tir ops, and the torch codegen has no mapping for call_tir.
# TODO: feed a graph made of high-level ops instead (e.g. the one returned by
# `translate.from_relay` earlier) -- confirm the torch codegen supports it.
model = codegen.to_torch(graph, weights)

InternalError: Traceback (most recent call last):
  6: _ZN3tvm7runtime13PackedFun
  5: tvm::runtime::TypedPackedFunc<tvm::runtime::Map<tvm::runtime::String, tvm::runtime::String, void, void> (tvm::contrib::msc::MSCGraph const&, tvm::runtime::String const&, tvm::runtime::String const&)>::AssignTypedLambda<tvm::contrib::msc::__mk_TVM0::{lambda(tvm::contrib::msc::MSCGraph const&, tvm::runtime::String const&, tvm::runtime::String const&)#1}>(tvm::contrib::msc::__mk_TVM0::{lambda(tvm::contrib::msc::MSCGraph const&, tvm::runtime::String const&, tvm::runtime::String const&)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  4: tvm::contrib::msc::PyCodeGen<tvm::contrib::msc::TorchCodeGenConfig, tvm::contrib::msc::TorchCodeGenHelper>::GetSources(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
  3: tvm::contrib::msc::PyCodeGen<tvm::contrib::msc::TorchCodeGenConfig, tvm::contrib::msc::TorchCodeGenHelper>::CodeGenScript()
  2: tvm::contrib::msc::TorchCodeGen::CodeGenGraph()
  1: tvm::contrib::msc::PyCodeGen<tvm::contrib::msc::TorchCodeGenConfig, tvm::contrib::msc::TorchCodeGenHelper>::CodeGenNode(tvm::contrib::msc::MSCJoint const&, bool)
  0: tvm::contrib::msc::TorchCodeGen::GetOpCodes(tvm::contrib::msc::MSCJoint const&)
  SCOPE: block
  OPTYPE: call_tir
  OUT: call_tir:0<1,1,224,224,3|float32>
  IN: input:0(input)<1,3,224,224|float32>
  File "/media/pc/data/lxw/ai/tvm/src/contrib/msc/framework/torch/codegen.cc", line 144
InternalError: Check failed: (it != ops_map->end()) is false: Unsupported torch op(call_tir): ID_1 call_tir <PARENTS: input| CHILDERN: call_tir_1>