解读 InferTypeLocal()

解读 InferTypeLocal()#

import set_env
import tvm
from tvm import relay
relay.transform.InferTypeLocal??
Signature: relay.transform.InferTypeLocal(expr)
Source:   
def InferTypeLocal(expr):
    """Infer the type of a single expr, reusing type information to do so.

    This populates the checked_type field in expr. We assume existing type information
    in the graph is correct!

    Parameters
    ----------
    expr: relay.Expr
        The expression we want to know the type of

    Returns
    -------
    type: relay.Type
        The type of the expression
    """
    return _ffi_api.InferTypeLocal(expr)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/transform/transform.py
Type:      function

InferTypeLocal() 函数的作用是推断单个 expr 的类型,并重用类型信息来实现这一点。它会填充表达式中的 checked_type 字段。我们假设计算图中现有的类型信息是正确的!

参数:

  • expr: relay.Expr,我们想要知道其类型的表达式

返回值:

  • type: relay.Type,表达式的类型

/*!
 * \brief Infer the type of an expression, reusing existing type information.
 *
 * The result of type checking is a new expression with unambiguous
 * type information filled in for the given node only. The local
 * version can use existing type information populated throughout
 * the expression and assumes this information is correct. The local
 * version also avoids examining large amounts of the graph assuming
 * type information is filled in properly which makes it much faster if we
 * iteratively call type inference.
 *
 * \return The type of the expression.
 */
TVM_DLL Type InferTypeLocal(const Expr& expr);

这个函数的作用是推断表达式的类型,并重用现有的类型信息。

类型检查的结果是一个新的表达式,其中给定节点的不明确的类型信息被填充。局部版本可以使用整个表达式中填充的现有类型信息,并假设这些信息是正确的。局部版本还避免了检查大量的计算图,假设类型信息被正确填充,如果我们迭代地调用类型推断,这会使其更快。

返回值:表达式的类型。

Type InferTypeLocal(const Expr& expr) {
  /*
  This type inference differs from InferType in that it uses existing type information
  to avoid recursing over much of the graph, and it only examines the type of the input
  node. This makes it faster if you need to run type inference iteratively throughout
  a pass for example.

  However, it assumes any existing populated type inference is correct! If some populated
  type inference is incorrect, an incorrect type may be returned or a type error will be
  raised. If you know not all populated type fields are correct with the current graph,
  you should use InferType() instead.
  */
  SameTypedSubgraphExtractor subgraph_extractor;
  Expr sub_graph = subgraph_extractor(expr);

  Type result_type;
  result_type = relay::InferType(sub_graph)->checked_type();

  expr->checked_type_ = result_type;
  return result_type;
}

TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const Expr& expr) {
  return InferTypeLocal(expr);
});

InferType 不同的是,它使用现有的类型信息来避免在计算图中递归遍历很多部分,并且只检查输入节点的类型。如果您需要在传递过程中迭代运行类型推断,这将使其更快。

但是,它假设任何现有填充的类型推断都是正确的!如果某些填充的类型推断是错误的,可能会返回错误类型的结果或引发类型错误。如果您知道当前计算图中并非所有填充的类型字段都是正确的,则应使用 InferType() 代替。

该函数首先创建 SameTypedSubgraphExtractor 对象,然后使用该对象从给定的表达式中提取子图。接下来,它调用 relay::InferType() 函数来推断子图的类型,并将结果存储在 result_type 变量中。最后,它将 result_type 分配给 expr 对象的 checked_type_ 属性,并将其作为函数的返回值返回。