# TensorFlow (pb) 转 ONNX

参考: [TVM Tensorflow 前端](https://xinetzone.github.io/tvm/docs/arch/frontend/tensorflow.html)

下面以 [mobilenet_v2 float_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) 为例，展示将 TensorFlow 冻结图（pb）模型转换为 ONNX 模型的过程：

In [1]:
import numpy as np
import os
import warnings
# Must be set before TensorFlow is imported below; "3" is intended to silence
# TensorFlow's native (C++) log output.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
# Restrict TensorFlow's Python logger to ERROR-level records and above.
tf.get_logger().setLevel("ERROR")
# Hide Python warnings for the remainder of the session.
warnings.simplefilter("ignore")
import tensorflow as tf
from tensorflow.core.framework.graph_pb2 import GraphDef

In [2]:
import set_env # 加载 TVM
import tvm.relay.testing.tf as tf_testing
class MobilenetV2(tf.keras.Model):
    """Keras wrapper around the frozen MobileNetV2 (1.4, 224) TensorFlow graph.

    The frozen ``GraphDef`` is obtained through
    ``tf_testing.get_workload_official`` and imported inside :meth:`call`,
    which takes an NCHW batch and transposes it to NHWC before feeding the
    imported graph.
    """

    def __init__(self, *args, **kwargs):
        """Locate the frozen ``.pb`` file and parse it into ``self.graph_def``."""
        super().__init__(*args, **kwargs)
        frozen_pb = tf_testing.get_workload_official(
            "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
            "mobilenet_v2_1.4_224_frozen.pb",
        )
        self.graph_def = self._read_graph_def(frozen_pb)
        self.output_names = ['output']

    def _read_graph_def(self, frozen_path):
        """Return the ``GraphDef`` parsed from the binary file at ``frozen_path``."""
        parsed = GraphDef()
        with open(frozen_path, 'rb') as pb_file:
            parsed.ParseFromString(pb_file.read())
        return parsed

    @tf.function(input_signature=[tf.TensorSpec([1, 3, 224, 224],
                                                 tf.float32, name="input")])
    def call(self, x):
        """Feed ``x`` (NCHW, float32) through the imported frozen graph.

        Returns the tensor ``MobilenetV2/Predictions/Reshape_1:0`` of the
        imported graph.
        """
        nchw = tf.convert_to_tensor(x, tf.float32)  # make sure the input is a tensor
        nhwc = tf.transpose(nchw, perm=(0, 2, 3, 1))  # NCHW -> NHWC
        # Alternative tensors that could be returned instead:
        #   "MobilenetV2/Logits/AvgPool:0"
        #   "MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd:0"
        fetched = tf.graph_util.import_graph_def(
            self.graph_def,
            input_map={'input:0': nhwc},
            return_elements=['MobilenetV2/Predictions/Reshape_1:0'],
        )
        return fetched[0]

In [4]:
import os

import tf2onnx
import onnx

# Convert the Keras wrapper to ONNX. The input is named "data" here, which is
# the name the TVM import step later in this document uses to address it.
input_signature = [tf.TensorSpec([1, 3, 224, 224], tf.float32, name="data")]
model = MobilenetV2()
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature)
# Make sure the output directory exists; saving into a missing directory
# would otherwise fail with FileNotFoundError.
os.makedirs(".temp", exist_ok=True)
onnx.save(onnx_model, ".temp/mobilenet_v2_tf.onnx")

I0000 00:00:1726190666.909964 1341751 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 2
I0000 00:00:1726190668.185528 1341751 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 2


测试一致性（比较 TensorFlow 与 TVM 的输出）:

In [5]:
from PIL import Image
image_size = 224
path = 'images/Giant_Panda_in_Beijing_Zoo_1.jpg'  # path of the image to classify

with Image.open(path) as im:
    if im.mode != "RGB":
        # Image.convert returns a new image instead of converting in place,
        # so the result must be rebound or the conversion is silently lost.
        im = im.convert("RGB")
    im = im.resize((image_size, image_size))
    image = np.asarray(im)
image = image/128 - 1  # scale 8-bit pixel values from [0, 255] to roughly [-1, 1)
images = np.expand_dims(image, 0)  # add a leading batch axis: HWC -> NHWC
images = images.transpose((0, 3, 1, 2))  # NHWC -> NCHW

In [None]:
# Build a fresh model instance and run it on the preprocessed NCHW batch.
# tf_output is the TensorFlow reference that the TVM result is compared
# against later in this document.
model = MobilenetV2()
tf_output = model(images)
# Print the Keras summary of the wrapper model.
model.summary()

In [8]:
import set_env
import tvm
from tvm import relay
from tvm.relay.frontend import from_onnx

# Input name and shape; "data" matches the input name used for the ONNX export.
shape_dict = {"data": [1, 3, 224, 224]}
# Import the in-memory ONNX model into Relay.
# NOTE(review): freeze_params=True presumably embeds the weights as constants
# in the Relay module - confirm against the TVM ONNX frontend documentation.
mod, params = from_onnx(
    onnx_model,
    shape_dict,
    freeze_params=True
)
# Build for the "llvm" target at optimization level 3.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, "llvm", params=params)
inputs_dict = {"data": images}
# Instantiate the built module on the CPU device and run it with the same
# batch that was fed to the TensorFlow model.
mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.cpu()))
mlib_proxy.run(**inputs_dict)
# Consistency check: TensorFlow and TVM outputs must agree within tolerance.
np.testing.assert_allclose(
    tf_output.numpy(), 
    mlib_proxy.get_output(0).numpy(),
    rtol=1e-02, atol=1e-5
)

In [9]:
from tvm.contrib.msc.core.frontend import translate

# Translate the Relay module and its parameters into an MSC graph and its weights.
# NOTE(review): the exact effect of opt_config={"opt_level": 3} is defined by
# the MSC translate API - confirm there.
graph, weights = translate.from_relay(mod, params, opt_config={"opt_level": 3})

In [10]:
# Bare expression: the notebook echoes the graph's textual form as the cell output.
graph

main <INPUTS: data:0| OUTPUTS: reshape_2:0>
ID_0 data <PARENTS: | CHILDERN: msc.conv2d_bias_52>
  OUT: data:0(data)<1,3,224,224|float32>
  OPTYPE: input

ID_1 msc.conv2d_bias_52 <PARENTS: data| CHILDERN: clip>
  IN: data:0(data)<1,3,224,224|float32>
  OUT: msc.conv2d_bias_52:0<1,48,112,112|float32>
  OPTYPE: msc.conv2d_bias
  SCOPE: block
  ATTRS: out_dtype= strides=2,2 kernel_layout=OIHW groups=1 channels=48 kernel_size=3,3 axis=1 padding=0,0,1,1 data_layout=NCHW dilation=1,1 out_layout= 
  WEIGHTS: 
    weight: const_104<48,3,3,3|float32>
    bias: const_105<48|float32>

ID_2 clip <PARENTS: msc.conv2d_bias_52| CHILDERN: msc.conv2d_bias_51>
  IN: msc.conv2d_bias_52:0<1,48,112,112|float32>
  OUT: clip:0<1,48,112,112|float32>
  OPTYPE: clip
  SCOPE: block
  ATTRS: a_min=0.000000 a_max=6.000000 

ID_3 msc.conv2d_bias_51 <PARENTS: clip| CHILDERN: clip_1>
  IN: clip:0<1,48,112,112|float32>
  OUT: msc.conv2d_bias_51:0<1,48,112,112|float32>
  OPTYPE: msc.conv2d_bias
  SCOPE: block
  ATTRS: o