{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "wJcYs_ERTnnI" }, "outputs": [], "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "HMUDt0CiUJk9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# 将您的 TFLite 代码迁移到 TF2\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "meUTrR4I6m1C" }, "source": [ "[TensorFlow Lite](https://tensorflow.google.cn/lite/guide) (TFLite) 是一套工具,可帮助开发者在设备端(移动、嵌入式和物联网设备)上运行机器学习推断。[TFLite 转换器](https://tensorflow.google.cn/lite/convert)可以将现有 TF 模型转换为可在设备端高效运行的优化 TFLite 模型格式。\n", "\n", "在本文档中,您将了解需要对 TF 到 TFLite 的转换代码进行哪些更改,然后是几个实现相同目标的示例。\n", "\n", "## TF 到 TFLite 转换代码的更改\n", "\n", "- 如果您使用的是旧版 TF1 模型格式(例如,Keras 文件、冻结的 GraphDef、检查点、tf.Session 等),请将其更新为 TF1/TF2 SavedModel,并使用 TF2 转换器 API `tf.lite.TFLiteConverter.from_saved_model(...)` 将其转换为 TFLite 模型(请参见表 1)。\n", "\n", "- 更新转换器 API 标志(请参见表 2)。\n", "\n", "- 移除旧版 API,例如 `tf.lite.constants`。(例如:将 `tf.lite.constants.INT8` 替换为 `tf.int8`)\n", "\n", "// 表 1 // TFLite Python 转换器 API 更新\n", "\n", "TF1 API | TF2 API\n", "--- | ---\n", "`tf.lite.TFLiteConverter.from_saved_model('saved_model/',..)` | *支持*\n", "`tf.lite.TFLiteConverter.from_keras_model_file('model.h5',..)` | *已移除(更新为 SavedModel 格式)*\n", "`tf.lite.TFLiteConverter.from_frozen_graph('model.pb',..)` | *已移除(更新为 SavedModel 格式)*\n", "`tf.lite.TFLiteConverter.from_session(sess,...)` | *已移除(更新为 SavedModel 格式)*" ] }, { "cell_type": "markdown", "metadata": { "id": "Rf75rjeedigq" }, "source": [ "<style> .table {margin-left: 0 !important;} </style>" ] }, { "cell_type": "markdown", "metadata": { "id": "XbVlZNizW1-Y" }, "source": [ "// 表 2 // TFLite Python 转换器 API 标志更新\n", "\n", "TF1 API | TF2 API\n", "--- | ---\n", "`allow_custom_ops`
`optimizations`
`representative_dataset`
`target_spec`
`inference_input_type`
`inference_output_type`
`experimental_new_converter`
`experimental_new_quantizer` | *支持*







\n", "`input_tensors`
`output_tensors`
`input_arrays_with_shape`
`output_arrays`
`experimental_debug_info_func` | *已移除(不支持的转换器 API 参数)*




\n", "`change_concat_input_ranges`
`default_ranges_stats`
`get_input_arrays()`
`inference_type`
`quantized_input_stats`
`reorder_across_fake_quant` | *已移除(不支持的量化工作流)*





\n", "`conversion_summary_dir`
`dump_graphviz_dir`
`dump_graphviz_video` | *已移除(改为使用 [Netron](https://lutzroeder.github.io/netron/) 或 [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py) 呈现模型)*


\n", "`output_format`
`drop_control_dependency` | *已移除(TF2 中不支持的功能)*
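{ "cell_type": "markdown", "metadata": {}, "source": [ "For example, here is a minimal sketch of how the supported TF2 flags above might be combined for full-integer quantization. The `saved_model_dir` and `representative_dataset` arguments are placeholders that you supply for your own model and calibration data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "# Illustrative sketch: `saved_model_dir` and `representative_dataset` are\n", "# placeholders for your own model and calibration data generator.\n", "def convert_to_int8_tflite(saved_model_dir, representative_dataset):\n", "  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n", "  converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "  converter.representative_dataset = representative_dataset\n", "  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n", "  # TF1 code used tf.lite.constants.INT8 here; in TF2, pass the dtype directly.\n", "  converter.inference_input_type = tf.int8\n", "  converter.inference_output_type = tf.int8\n", "  return converter.convert()" ] },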

" ] }, { "cell_type": "markdown", "metadata": { "id": "YdZSoIXEbhg-" }, "source": [ "## 示例\n", "\n", "您现在将演练一些示例,将旧版 TF1 模型转换为 TF1/TF2 SavedModel,然后将其转换为 TF2 TFLite 模型。\n", "\n", "### 安装\n", "\n", "从必要的 TensorFlow 导入开始。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iE0vSfMXumKI", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1\n", "import numpy as np\n", "\n", "import logging\n", "logger = tf.get_logger()\n", "logger.setLevel(logging.ERROR)\n", "\n", "import shutil\n", "def remove_dir(path):\n", " try:\n", " shutil.rmtree(path)\n", " except:\n", " pass" ] }, { "cell_type": "markdown", "metadata": { "id": "89VllCprnFto" }, "source": [ "创建所有必要的 TF1 模型格式。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bwq8EFiwjzjx", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create a TF1 SavedModel\n", "SAVED_MODEL_DIR = \"tf_saved_model/\"\n", "remove_dir(SAVED_MODEL_DIR)\n", "with tf1.Graph().as_default() as g:\n", " with tf1.Session() as sess:\n", " input = tf1.placeholder(tf.float32, shape=(3,), name='input')\n", " output = input + 2\n", " # print(\"result: \", sess.run(output, {input: [0., 2., 4.]}))\n", " tf1.saved_model.simple_save(\n", " sess, SAVED_MODEL_DIR,\n", " inputs={'input': input}, \n", " outputs={'output': output})\n", "print(\"TF1 SavedModel path: \", SAVED_MODEL_DIR)\n", "\n", "# Create a TF1 Keras model\n", "KERAS_MODEL_PATH = 'tf_keras_model.h5'\n", "model = tf1.keras.models.Sequential([\n", " tf1.keras.layers.InputLayer(input_shape=(128, 128, 3,), name='input'),\n", " tf1.keras.layers.Dense(units=16, input_shape=(128, 128, 3,), activation='relu'),\n", " tf1.keras.layers.Dense(units=1, name='output')\n", "])\n", "model.save(KERAS_MODEL_PATH, save_format='h5')\n", "print(\"TF1 Keras Model path: \", KERAS_MODEL_PATH)\n", "\n", "# Create a TF1 frozen GraphDef model\n", "GRAPH_DEF_MODEL_PATH = tf.keras.utils.get_file(\n", " 'mobilenet_v1_0.25_128',\n", " origin='https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz',\n", " untar=True,\n", ") + '/frozen_graph.pb'\n", "\n", "print(\"TF1 frozen GraphDef path: \", GRAPH_DEF_MODEL_PATH)" ] }, { "cell_type": "markdown", "metadata": { "id": "EzMBpG5rdt-7" }, "source": [ "### 1. 
{ "cell_type": "markdown", "metadata": { "id": "EzMBpG5rdt-7" }, "source": [ "### 1. Convert a TF1 SavedModel to a TFLite model\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GFWIlVridt_F" }, "source": [ "#### Before: Converting with TF1\n", "\n", "Here is typical code for TF1-style TFLite conversion.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dzXHHBQRdt_F", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "converter = tf1.lite.TFLiteConverter.from_saved_model(\n", "    saved_model_dir=SAVED_MODEL_DIR,\n", "    input_arrays=['input'],\n", "    input_shapes={'input' : [3]}\n", ")\n", "converter.optimizations = {tf.lite.Optimize.DEFAULT}\n", "converter.change_concat_input_ranges = True\n", "tflite_model = converter.convert()\n", "# Ignore warning: \"Use '@tf.function' or '@defun' to decorate the function.\"" ] }, { "cell_type": "markdown", "metadata": { "id": "NUptsxK_MUy2" }, "source": [ "#### After: Converting with TF2\n", "\n", "Convert the TF1 SavedModel directly to a TFLite model, with the smaller v2 converter flag set." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0OyBjZ6Kdt_F", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Convert TF1 SavedModel to a TFLite model.\n", "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=SAVED_MODEL_DIR)\n", "converter.optimizations = {tf.lite.Optimize.DEFAULT}\n", "tflite_model = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "yiwu3sso__fH" }, "source": [ "### 2. Convert a TF1 Keras model file to a TFLite model" ] }, { "cell_type": "markdown", "metadata": { "id": "9WTPvPih__fR" }, "source": [ "#### Before: Converting with TF1\n", "\n", "Here is typical code for TF1-style TFLite conversion." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9EXO0xYq__fR", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "converter = tf1.lite.TFLiteConverter.from_keras_model_file(model_file=KERAS_MODEL_PATH)\n", "converter.optimizations = {tf.lite.Optimize.DEFAULT}\n", "converter.change_concat_input_ranges = True\n", "tflite_model = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "9l6ppTtTZ5Bz" }, "source": [ "#### After: Converting with TF2\n", "\n", "First, convert the TF1 Keras model file to a TF2 SavedModel, then convert it to a TFLite model with the smaller v2 converter flag set." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IGB5ZMGl__fR", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Convert TF1 Keras model file to TF2 SavedModel.\n", "model = tf.keras.models.load_model(KERAS_MODEL_PATH)\n", "model.save(filepath='saved_model_2/')\n", "\n", "# Convert TF2 SavedModel to a TFLite model.\n", "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_2/')\n", "tflite_model = converter.convert()" ] },
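{ "cell_type": "markdown", "metadata": {}, "source": [ "The resulting `tflite_model` can also be sanity-checked in-process before deployment. As a minimal sketch, run the model converted above on random input data with the TFLite `Interpreter`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Sanity-check the converted model with the TFLite interpreter.\n", "# The random input below is for illustration only.\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "interpreter.allocate_tensors()\n", "\n", "input_details = interpreter.get_input_details()[0]\n", "output_details = interpreter.get_output_details()[0]\n", "\n", "random_input = np.random.uniform(size=input_details['shape']).astype(np.float32)\n", "interpreter.set_tensor(input_details['index'], random_input)\n", "interpreter.invoke()\n", "print(\"TFLite output shape:\", interpreter.get_tensor(output_details['index']).shape)" ] },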
{ "cell_type": "markdown", "metadata": { "id": "v5Zf6G4M-sZz" }, "source": [ "### 3. Convert a TF1 frozen GraphDef to a TFLite model\n" ] }, { "cell_type": "markdown", "metadata": { "id": "DzCJOV7AUlGZ" }, "source": [ "#### Before: Converting with TF1\n", "\n", "Here is typical code for TF1-style TFLite conversion." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "r7RvcdRv6lll", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "converter = tf1.lite.TFLiteConverter.from_frozen_graph(\n", "    graph_def_file=GRAPH_DEF_MODEL_PATH,\n", "    input_arrays=['input'],\n", "    input_shapes={'input' : [1, 128, 128, 3]},\n", "    output_arrays=['MobilenetV1/Predictions/Softmax'],\n", ")\n", "converter.optimizations = {tf.lite.Optimize.DEFAULT}\n", "converter.change_concat_input_ranges = True\n", "tflite_model = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "ZdIogJsKaMNH" }, "source": [ "#### After: Converting with TF2\n", "\n", "First, convert the TF1 frozen GraphDef to a TF1 SavedModel, then convert it to a TFLite model with the smaller v2 converter flag set.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Oigap0TZxjWG", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "## Convert TF1 frozen Graph to TF1 SavedModel.\n", "\n", "# Load the graph as a v1.GraphDef\n", "import pathlib\n", "gdef = tf.compat.v1.GraphDef()\n", "gdef.ParseFromString(pathlib.Path(GRAPH_DEF_MODEL_PATH).read_bytes())\n", "\n", "# Convert the GraphDef to a tf.Graph\n", "with tf.Graph().as_default() as g:\n", "  tf.graph_util.import_graph_def(gdef, name=\"\")\n", "\n", "# Look up the input and output tensors.\n", "input_tensor = g.get_tensor_by_name('input:0')\n", "output_tensor = g.get_tensor_by_name('MobilenetV1/Predictions/Softmax:0')\n", "\n", "# Save the graph as a TF1 SavedModel.\n", "remove_dir('saved_model_3/')\n", "with tf.compat.v1.Session(graph=g) as s:\n", "  tf.compat.v1.saved_model.simple_save(\n", "      session=s,\n", "      export_dir='saved_model_3/',\n", "      inputs={'input': input_tensor},\n", "      outputs={'output': output_tensor})\n", "\n", "# Convert TF1 SavedModel to a TFLite model.\n", "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_3/')\n", "converter.optimizations = {tf.lite.Optimize.DEFAULT}\n", "tflite_model = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "MFbsddkOw4Wl" }, "source": [ "# Further reading\n", "\n", "- See the [TFLite guide](https://tensorflow.google.cn/lite/guide) to learn more about the workflows and the latest features.\n", "- If you are using TF1 code or legacy TF1 model formats (Keras `.h5` files, frozen GraphDef `.pb` files, etc.), update your code and migrate your models to the [TF2 SavedModel format](https://tensorflow.google.cn/guide/saved_model).\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "tflite.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }