# TensorFlow2 Keras 推理
下面以 resnet_v2_50 模型为例进行展示。
需要先克隆 TensorFlow 的 models 项目(其中包含 research/slim 目录),然后执行如下操作。
import os

# -1 means "no CUDA device": the notebook runs on CPU only.
m_gpu = -1
# The variables are set before TensorFlow is imported, as in the original
# statement order.
# NOTE(review): CUDA_LAUNCH_BLOCKING normally takes 0 or 1; the value "-1" is
# kept exactly as the original wrote it -- confirm it is intended.
for cuda_env_name in ('CUDA_VISIBLE_DEVICES', 'CUDA_LAUNCH_BLOCKING'):
    os.environ[cuda_env_name] = str(m_gpu)

import tensorflow as tf

# Use the TF1 compatibility namespace when it exists, otherwise `tf` itself.
try:
    tf1 = tf.compat.v1
except (ImportError, AttributeError):
    tf1 = tf

# Only let TensorFlow's Python logger report errors.
tf.get_logger().setLevel('ERROR')
2023-06-21 16:49:34.172559: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:49:34.247466: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:49:34.248487: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-21 16:49:35.798317: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
切换到 models/research/slim 目录下:
# IPython magic: make the slim directory the working directory so that the
# `nets` package imported below can be found.
%cd /media/pc/data/lxw/ai/tasks/models/research/slim
/media/pc/data/lxw/ai/tasks/models/research/slim
将 TF1(slim)模型封装为 TF2 Keras 模型:
from nets import resnet_v2
import tf_slim as slim
class ResnetV2_50(tf.keras.Model):
    """Keras wrapper around the TF-Slim ``resnet_v2_50`` network.

    ``call`` accepts a float32 batch of shape ``(1, 3, 299, 299)`` (fixed by
    its ``tf.function`` input signature, input name ``"data"``) and returns
    the softmax of the 1001-class logits.
    """

    def __init__(self, trainable=False, *args, **kwargs):
        """Forward extra arguments to Keras and record the ``trainable`` flag.

        Args:
            trainable: Stored on the model and passed to slim as
                ``is_training`` when the network is built.
        """
        super().__init__(*args, **kwargs)
        self.trainable = trainable

    @tf.function(
        input_signature=[
            tf.TensorSpec([1, 3, 299, 299], tf.float32, name="data"),
        ]
    )
    @tf1.keras.utils.track_tf1_style_variables
    def call(self, x):
        """Run the slim network on ``x`` and return class probabilities."""
        # NCHW -> NHWC before handing the batch to the slim network.
        nhwc_batch = tf.transpose(x, perm=[0, 2, 3, 1])
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            # The second return value (the end points) is not used.
            logits, _ = resnet_v2.resnet_v2_50(
                nhwc_batch,
                num_classes=1001,
                global_pool=True,
                is_training=self.trainable,
                scope="resnet_v2_50",
            )
        return tf.nn.softmax(logits)
预处理:
from PIL import Image
import numpy as np
from nets import resnet_v2
from tvm_book.data.classification import ImageFolderDataset
import tf_slim as slim
import tensorflow as tf
@tf.function
def preprocessing(
    image,
    use_grayscale=False,
    central_fraction=0.875,
    central_crop=True,
    height=299,
    width=299,
    mean: tuple[float, ...] = (0.485, 0.456, 0.406),
    std: tuple[float, ...] = (1, 1, 1)
):
    """Prepare one image for evaluation.

    The steps run in this order: conversion to float32 (only when the input
    has another dtype), optional grayscale conversion, optional central crop,
    optional bilinear resize to ``(height, width)``, and finally
    ``(image - mean) / std``.

    Args:
        image: A single, unbatched image (a batch axis is added temporarily
            for the resize).
        use_grayscale: Convert the RGB image to grayscale when true.
        central_fraction: Fraction of the image kept by the central crop.
        central_crop: Enable the central crop (also requires a truthy
            ``central_fraction``).
        height: Target height; the resize is skipped unless both ``height``
            and ``width`` are truthy.
        width: Target width.
        mean: Values subtracted from the image.
        std: Values the image is divided by.

    Returns:
        The preprocessed image tensor.
    """
    needs_float_conversion = image.dtype != tf.float32
    if needs_float_conversion:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    if use_grayscale:
        image = tf.image.rgb_to_grayscale(image)

    if central_crop and central_fraction:
        image = tf.image.central_crop(
            image, central_fraction=central_fraction
        )

    if height and width:
        # Add a leading batch axis for the resize and drop it afterwards.
        batched = tf.expand_dims(image, axis=0)
        resized = tf.image.resize(
            batched,
            size=[height, width],
            method='bilinear',
            preserve_aspect_ratio=False,
            antialias=False,
        )
        image = tf.squeeze(resized, axis=[0])

    centered = tf.subtract(image, mean)
    return tf.divide(centered, std)
# Load one validation sample and preprocess it.
root = "/media/pc/data/lxw/home/data/datasets/ILSVRC/val"
valset = ImageFolderDataset(root)
image, label_id = valset[1001]
model_dir = 'temp/resnet_v2_50'
# remove_dir(model_dir)  # left disabled, as in the original

# Same values as the defaults of `preprocessing`, spelled out explicitly.
preprocess_options = dict(
    use_grayscale=False,
    central_fraction=0.875,
    central_crop=True,
    height=299,
    width=299,
    mean=(0.485, 0.456, 0.406),
    std=(1, 1, 1),
)
processed_image = preprocessing(image, **preprocess_options)

# Add a batch axis, then move channels forward (NHWC -> NCHW).
np_processed_images = processed_image.numpy()[np.newaxis, ...]
np_processed_images = np_processed_images.transpose(0, 3, 1, 2)
2023-06-21 16:49:39.660093: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-06-21 16:49:39.660172: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: Alg
2023-06-21 16:49:39.660183: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: Alg
2023-06-21 16:49:39.660370: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 530.30.2
2023-06-21 16:49:39.660427: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 530.30.2
2023-06-21 16:49:39.660443: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 530.30.2
前向推理:
model = ResnetV2_50()
# Call the model once on an all-ones batch before restoring the checkpoint,
# as the original does.
warmup_batch = tf.ones(shape=(1, 3, 299, 299), dtype=tf.float32)
model(warmup_batch)

ckpt_path = "/media/pc/data/board/arria10/lxw/tests/npu_user_demos/models/resnet50_v2_tf/weight/resnet_v2_50.ckpt"
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(ckpt_path)  # update the model parameters

# Run the restored model on the preprocessed sample.
outputs = model(np_processed_images).numpy()
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
warnings.warn('`layer.apply` is deprecated and '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
warnings.warn('`layer.updates` will be removed in a future version. '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
# Print the Keras summary of the wrapped model (parameter counts).
model.summary()
Model: "resnet_v2_50"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
=================================================================
Total params: 25,615,849
Trainable params: 0
Non-trainable params: 25,615,849
_________________________________________________________________
打印标签信息:
from tvm_book.data.imagenet.classification import ImageNet1kAttr
# Print the ground-truth label followed by the top-k predictions.
imagenet1k_attr = ImageNet1kAttr()
# Class indices sorted by descending score.
sorted_inds = outputs[0].argsort()[::-1]
topk = 5
print(f"真实标签:{imagenet1k_attr.classes_long[label_id]}")
for sorted_ind in sorted_inds[:topk]:
    # The network has 1001 outputs; the original shifts indices by one to map
    # them onto the 1000 ImageNet labels (index 0 is presumably the slim
    # "background" class -- confirm).
    class_id = sorted_ind - 1
    if class_id >= 0:
        label = imagenet1k_attr.classes_long[class_id]
    else:
        # Without this guard, index -1 would silently wrap around to the
        # last ImageNet label.
        label = "background"
    print(f"{class_id}: {label.ljust(38)}\t{outputs[0, sorted_ind]}")
真实标签:water ouzel, dipper
20: water ouzel, dipper 0.9207783937454224
143: oystercatcher, oyster catcher 0.014078204520046711
141: redshank, Tringa totanus 0.0032907347194850445
146: albatross, mollymawk 0.0032017454504966736
139: ruddy turnstone, Arenaria interpres 0.002742304001003504
将模型及其参数保存下来,然后重新加载:
# Disabled alternative kept from the original: wrap the model in a functional
# Keras model before saving.
# # model = ResnetV2_50()
# inputs = tf.keras.Input(shape=(224, 224, 3), dtype=tf.float32, name="data")
# outputs = model(inputs)
# model2 = tf.keras.Model(inputs=inputs, outputs=outputs, name="resnet_v2_50_model")
# model2.save(module_with_signature_path)

# Save the subclassed model to this path.
module_with_signature_path = "/tmp/resnet_v2_50_keras"
model.save(filepath=module_with_signature_path)
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
warnings.warn('`layer.apply` is deprecated and '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
warnings.warn('`layer.updates` will be removed in a future version. '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
# Reload the SavedModel and run its default serving signature.
imported_with_signatures = tf.saved_model.load(module_with_signature_path)
infer = imported_with_signatures.signatures['serving_default']
labeling = infer(tf.constant(np_processed_images))
from tvm_book.data.imagenet.classification import ImageNet1kAttr
outputs = labeling['output_1'].numpy()
imagenet1k_attr = ImageNet1kAttr()
# Class indices sorted by descending score.
sorted_inds = outputs[0].argsort()[::-1]
topk = 5
print(f"真实标签:{imagenet1k_attr.classes_long[label_id]}")
for sorted_ind in sorted_inds[:topk]:
    # The network has 1001 outputs; the original shifts indices by one to map
    # them onto the 1000 ImageNet labels (index 0 is presumably the slim
    # "background" class -- confirm).
    class_id = sorted_ind - 1
    if class_id >= 0:
        label = imagenet1k_attr.classes_long[class_id]
    else:
        # Without this guard, index -1 would silently wrap around to the
        # last ImageNet label.
        label = "background"
    print(f"{class_id}: {label.ljust(38)}\t{outputs[0, sorted_ind]}")
真实标签:water ouzel, dipper
20: water ouzel, dipper 0.9207783937454224
143: oystercatcher, oyster catcher 0.014078204520046711
141: redshank, Tringa totanus 0.0032907347194850445
146: albatross, mollymawk 0.0032017454504966736
139: ruddy turnstone, Arenaria interpres 0.002742304001003504
## 转换为 ONNX 模型
import tf2onnx
import onnx
# Export the Keras model to ONNX; the batch dimension is left dynamic here.
input_signature = [
    tf.TensorSpec([None, 3, 299, 299], tf.float32, name="data"),
]
onnx_model, external_tensor_storage = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
)
onnx.save(onnx_model, "/tmp/resnet_v2_50_tf.onnx")
2023-06-21 16:50:08.883734: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-06-21 16:50:08.883960: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-06-21 16:50:13.670424: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-06-21 16:50:13.671135: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
构建库:
import set_env
from tvm.relay.frontend import from_onnx
# Import the ONNX file into Relay with a fixed input shape.
shape_dict = {"data": [1, 3, 299, 299]}
graph_def = onnx.load("/tmp/resnet_v2_50_tf.onnx")
mod, params = from_onnx(graph_def, shape=shape_dict, freeze_params=True)
推理:
import tvm
from tvm import relay
# Compile the Relay module for the local CPU.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, "llvm", params=params)

# Create a graph executor on the CPU and run it on the preprocessed sample.
cpu_device = tvm.cpu()
mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib["default"](cpu_device))
inputs_dict = {"data": np_processed_images}
mlib_proxy.run(**inputs_dict)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
验证一致性:
# Compare the SavedModel output with the TVM output.
tf_probabilities = labeling['output_1'].numpy()
tvm_probabilities = mlib_proxy.get_output(0).numpy()
np.testing.assert_allclose(
    tf_probabilities,
    tvm_probabilities,
    rtol=1e-07,
    atol=1e-5,
)
## 转换为 TFLite 模型
import tensorflow as tf
# Convert the SavedModel written earlier into a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(
    module_with_signature_path
)
tflite_model = converter.convert()

# Write the serialized model to disk.
with open('temp/resnet_v2_50.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)
2023-06-21 16:52:14.629868: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-06-21 16:52:14.630049: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2023-06-21 16:52:14.649710: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/resnet_v2_50_keras
2023-06-21 16:52:14.679437: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-06-21 16:52:14.679522: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /tmp/resnet_v2_50_keras
2023-06-21 16:52:14.765580: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:353] MLIR V1 optimization pass is not enabled
2023-06-21 16:52:14.785594: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2023-06-21 16:52:15.679311: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /tmp/resnet_v2_50_keras
2023-06-21 16:52:15.933247: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 1283554 microseconds.
2023-06-21 16:52:16.844598: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-06-21 16:52:18.902510: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2116] Estimated count of arithmetic ops: 13.119 G ops, equivalently 6.559 G MACs
加载 TFLite 模型:
import tflite
# Read the serialized TFLite model back from disk and parse it.
with open('temp/resnet_v2_50.tflite', "rb") as tflite_file:
    tflite_model_buf = tflite_file.read()
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

# NOTE(review): shape_dict and dtype_dict are keyed by "data"; confirm that
# this matches the input tensor name stored in the TFLite model.
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict=shape_dict,
    dtype_dict={"data": "float32"},
)

# Request NCHW for these operators ('image.resize2d' is left untouched, as in
# the original).
desired_layouts = {
    op_name: ['NCHW', 'default']
    for op_name in ('nn.conv2d', 'nn.max_pool2d', 'nn.avg_pool2d')
}

# Remove unused functions, then convert the layout from NHWC to NCHW.
layout_passes = [
    relay.transform.RemoveUnusedFunctions(),
    relay.transform.ConvertLayout(desired_layouts),
]
seq = tvm.transform.Sequential(layout_passes)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
验证结果一致性:
# Compile the TFLite-derived Relay module for the local CPU.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, "llvm", params=params)

# Find the real graph input: every parameter of `main` that is not a weight
# in `params`. The TFLite converter may name the input differently from
# "data", and `GraphModule.run(**kwargs)` skips names the graph does not
# know, so feeding {"data": ...} can leave the input unset. That is the
# likely cause of a TF/TVM output mismatch on this path.
input_names = [
    var.name_hint
    for var in mod["main"].params
    if var.name_hint not in params
]
if len(input_names) != 1:
    raise ValueError(
        f"expected exactly one model input, found {input_names}"
    )
inputs_dict = {input_names[0]: np_processed_images}

mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.cpu()))
# Set the input by explicit key rather than through run(**kwargs).
for input_name, input_value in inputs_dict.items():
    mlib_proxy.set_input(input_name, input_value)
mlib_proxy.run()

# Compare the SavedModel output with the TVM output.
np.testing.assert_allclose(
    labeling['output_1'].numpy(),
    mlib_proxy.get_output(0).numpy(),
    rtol=1e-07, atol=1e-5
)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[18], line 7
5 mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.cpu()))
6 mlib_proxy.run(**inputs_dict)
----> 7 np.testing.assert_allclose(
8 labeling['output_1'].numpy(),
9 mlib_proxy.get_output(0).numpy(),
10 rtol=1e-07, atol=1e-5
11 )
[... skipping hidden 1 frame]
File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/numpy/testing/_private/utils.py:844, in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)
840 err_msg += '\n' + '\n'.join(remarks)
841 msg = build_err_msg([ox, oy], err_msg,
842 verbose=verbose, header=header,
843 names=('x', 'y'), precision=precision)
--> 844 raise AssertionError(msg)
845 except ValueError:
846 import traceback
AssertionError:
Not equal to tolerance rtol=1e-07, atol=1e-05
Mismatched elements: 978 / 1001 (97.7%)
Max absolute difference: 0.92047846
Max relative difference: 3068.9646
x: array([[3.429268e-05, 1.693668e-05, 3.029113e-05, ..., 1.208637e-05,
9.920573e-06, 2.882769e-05]], dtype=float32)
y: array([[1.464797e-04, 1.271962e-04, 3.613982e-04, ..., 5.739641e-05,
8.909408e-05, 1.503596e-03]], dtype=float32)
警告
TFLite 转换后的模型推理结果与 TensorFlow 的输出不一致(见上方的 AssertionError),该问题暂时搁置。