# Source code for tvm.relay.analysis.call_graph
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Call graph used in Relay."""
from ...ir import IRModule
from ...runtime import Object
from ..expr import GlobalVar
from . import _ffi_api
# [docs]
class CallGraph(Object):
    """Call graph built from a Relay IR module."""

    def __init__(self, module):
        """Build the call graph for the given module.

        Parameters
        ----------
        module : tvm.ir.IRModule
            The IR module used to create the call graph.
        """
        self.__init_handle_by_constructor__(_ffi_api.CallGraph, module)

    @property
    def module(self):
        """The Relay IR module contained in this call graph.

        Returns
        -------
        ret : tvm.ir.IRModule
            The contained IRModule.
        """
        return _ffi_api.GetModule(self)

    def ref_count(self, var):
        """Return the number of references to a global var.

        Parameters
        ----------
        var : Union[String, tvm.relay.GlobalVar]
            The name of the global var, or the global var itself.

        Returns
        -------
        ret : int
            The number of references to the global var.

        Raises
        ------
        TypeError
            If ``var`` cannot be resolved to a GlobalVar.
        """
        gvar = self._get_global_var(var)
        return _ffi_api.GetRefCountGlobalVar(self, gvar)

    def global_call_count(self, var):
        """Return the number of global function calls made from a global var.

        Parameters
        ----------
        var : Union[String, tvm.relay.GlobalVar]
            The name of the global var, or the global var itself.

        Returns
        -------
        ret : int
            The number of global function calls from the given var.

        Raises
        ------
        TypeError
            If ``var`` cannot be resolved to a GlobalVar.
        """
        gvar = self._get_global_var(var)
        return _ffi_api.GetGlobalVarCallCount(self, gvar)

    def is_recursive(self, var):
        """Return whether the function bound to a global var is recursive.

        Parameters
        ----------
        var : Union[String, tvm.relay.GlobalVar]
            The name of the global var, or the global var itself.

        Returns
        -------
        ret : Boolean
            Whether the function corresponding to ``var`` is recursive.

        Raises
        ------
        TypeError
            If ``var`` cannot be resolved to a GlobalVar.
        """
        gvar = self._get_global_var(var)
        return _ffi_api.IsRecursive(self, gvar)

    def _get_global_var(self, var):
        """Resolve a name or a GlobalVar to a GlobalVar.

        Parameters
        ----------
        var : Union[String, tvm.relay.GlobalVar]
            The name of the global var, or the global var itself.

        Returns
        -------
        ret : tvm.relay.GlobalVar
            The global var.

        Raises
        ------
        TypeError
            If the resolved value is not a GlobalVar.
        """
        resolved = var
        # A string is treated as a name and looked up in the contained module.
        if isinstance(resolved, str):
            resolved = self.module.get_global_var(resolved)

        if not isinstance(resolved, GlobalVar):
            raise TypeError("var should be either a string or GlobalVar")
        return resolved

    def print_var(self, var):
        """Return the call graph of one global function as a string.

        Parameters
        ----------
        var : Union[String, tvm.relay.GlobalVar]
            The name of the global function, or its global variable.

        Returns
        -------
        ret : String
            The call graph represented as a string.

        Raises
        ------
        TypeError
            If ``var`` cannot be resolved to a GlobalVar.
        """
        gvar = self._get_global_var(var)
        return _ffi_api.PrintCallGraphGlobalVar(self, gvar)

    def __str__(self):
        """Return the call graph printed in topological order."""
        return _ffi_api.PrintCallGraph(self)