tvm.relay.analysis.count_layers 源代码
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities that enable counting the number of layers in a graph."""
import tvm
from tvm import relay
from ..expr_functor import ExprVisitor
class LayerCounter(ExprVisitor):
"""A visitor pass that computes the deepest chain of specified ops in graph."""
def __init__(self, valid_ops):
self.depth_count = 0
self.deepest_count = 0
self.valid_ops = [relay.op.get(op) for op in valid_ops]
super().__init__()
def visit_call(self, call):
if call.op in self.valid_ops:
self.depth_count += 1
current_count = self.depth_count
self.deepest_count = max(self.deepest_count, current_count)
for arg in call.args:
self.visit(arg)
self.depth_count = current_count
def count(self):
return self.deepest_count
[文档]
def count_layers(expr, valid_ops):
"""Determine the number of layers of specified ops in a graph.
This pass computes only the deepest chain of ops rather than the
total number of ops in a graph. Thus, if there are two parallel
convolutions (for example), they would be considered a single layer.
Parameters
----------
expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule.
The input expression.
valid_ops: List[str]
A list of the operations that should be included in the count.
Returns
-------
layer_count : int
The number of layers of the specified operations found in the graph.
"""
if isinstance(expr, tvm.ir.IRModule):
expr = expr["main"]
count_pass = LayerCounter(valid_ops)
count_pass.visit(expr)
return count_pass.count()