TensorFlow 前端#

TensorFlow 前端有助于将 TensorFlow 模型导入 TVM。

受支持版本:

  • 1.12 and below

测试模型:

  • Inception (V1/V2/V3/V4)

  • Resnet (All)

  • Mobilenet (V1/V2 All)

  • Vgg (16/19)

  • BERT (Base/3-layer)

为推理准备模型#

移除不必要的节点#

导出过程将删除许多在推理中不需要的节点,但不幸的是会留下一些剩余的节点。应该手动删除的节点有:

将 None 尺寸的值转换为常量#

TVM 对动态张量形状的支持非常有限。应将 None 的尺寸替换为常量。例如,模型可能会接受形状为 (None,20) 的输入。这应该转换为类似于 (1,20) 的形状。应相应地修改模型,以确保这些形状在整个图中匹配。

导出#

TensorFlow 前端期望冻结的 protobuf(.pb)或 saved_model 的模型作为输入。它目前不支持 checkpoint(.ckpt)。TensorFlow 前端需要的 graphdef 可以从活动的 sess 中提取,或者使用 TFParser 辅助类来获取。

为了准备模型进行推断,应使用多种转换方式导出模型。同时,设置 `add_shapes=True` 非常重要,因为这会将每个节点的输出形状嵌入到图中。以下是一种导出模型的 protobuf 函数,只需提供会话即可:”

import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

def export_pb(session):
    with tf.gfile.GFile("myexportedmodel.pb", "wb") as f:
        inputs = ["myinput1", "myinput2"] # replace with your input names
        outputs = ["myoutput1"] # replace with your output names
        graph_def = session.graph.as_graph_def(add_shapes=True)
        graph_def = tf.graph.util.convert_variables_to_constants(session, graph_def, outputs)
        graph_def = TransformGraph(
            graph_def,
            inputs,
            outputs,
            [
                "remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)",
                "sort_by_execution_order", # sort by execution order after each transform to ensure correct node ordering
                "remove_attribute(attribute_name=_XlaSeparateCompiledGradients)",
                "remove_attribute(attribute_name=_XlaCompile)",
                "remove_attribute(attribute_name=_XlaScope)",
                "sort_by_execution_order",
                "remove_device",
                "sort_by_execution_order",
                "fold_batch_norms",
                "sort_by_execution_order",
                "fold_old_batch_norms",
                "sort_by_execution_order"
            ]
        )
        f.write(graph_def.SerializeToString())

另一种方法是 导出并冻结图

导入模型#

明确的形状(Explicit Shape):#

为确保形状可以在整个图中被知晓,在调用 `from_tensorflow` 时请传递 `shape` 参数。该字典将输入名称映射到输入形状。请参考这些测试用例获取示例。

数据布局#

大多数 TensorFlow 模型都以 NHWC 布局发布。NCHW 布局通常可以提供更好的性能,尤其是在 GPU 上。TensorFlow 前端可以通过向 `from_tensorflow` 传递参数 `layout='NCHW'` 来自动转换模型的数据布局。

最佳实践#

  • 使用静态张量形状而不是动态形状(移除 `None` 维度)。

  • 使用静态 RNN 而不是动态 RNN,因为 `TensorArray` 目前还不受支持。

受支持算子#

  • Abs

  • Add

  • AddN

  • All

  • Any

  • ArgMax

  • ArgMin

  • AvgPool

  • BatchMatMul

  • BatchMatMulV2

  • BatchNormWithGlobalNormalization

  • BatchToSpaceND

  • BiasAdd

  • BroadcastTo

  • Cast

  • Ceil

  • CheckNumerics

  • ClipByValue

  • Concat

  • ConcatV2

  • Conv2D

  • Cos

  • Tan

  • CropAndResize

  • DecodeJpeg

  • DepthwiseConv2dNative

  • DepthToSpace

  • Dilation2D

  • Equal

  • Elu

  • Enter

  • Erf

  • Exit

  • Exp

  • ExpandDims

  • Fill

  • Floor

  • FloorDiv

  • FloorMod

  • FusedBatchNorm

  • FusedBatchNormV2

  • Gather

  • GatherNd

  • GatherV2

  • Greater

  • GreaterEqual

  • Identity

  • IsFinite

  • IsInf

  • IsNan

  • LeakyRelu

  • LeftShift

  • Less

  • LessEqual

  • Log

  • Log1p

  • LoopCond

  • LogicalAnd

  • LogicalOr

  • LogicalNot

  • LogSoftmax

  • LRN

  • LSTMBlockCell

  • MatMul

  • Max

  • MaxPool

  • Maximum

  • Mean

  • Merge

  • Min

  • Minimum

  • MirrorPad

  • Mod

  • Mul

  • Neg

  • NextIteration

  • NotEqual

  • OneHot

  • Pack

  • Pad

  • PadV2

  • Pow

  • Prod

  • Range

  • Rank

  • RealDiv

  • Relu

  • Relu6

  • Reshape

  • ResizeBilinear

  • ResizeBicubic

  • ResizeNearestNeighbor

  • ReverseV2

  • RightShift

  • Round

  • Rsqrt

  • Select

  • Selu

  • Shape

  • Sigmoid

  • Sign

  • Sin

  • Size

  • Slice

  • Softmax

  • Softplus

  • SpaceToBatchND

  • SpaceToDepth,

  • Split

  • SplitV

  • Sqrt

  • Square

  • SquareDifference

  • Squeeze

  • StridedSlice

  • Sub

  • Sum

  • Switch

  • Tanh

  • TensorArrayV3

  • TensorArrayScatterV3

  • TensorArrayGatherV3

  • TensorArraySizeV3

  • TensorArrayWriteV3

  • TensorArrayReadV3

  • TensorArraySplitV3

  • TensorArrayConcatV3

  • Tile

  • TopKV2

  • Transpose

  • TruncateMod

  • Unpack

  • UnravelIndex

  • Where

  • ZerosLike