AnnotateUsedMemory C++ 源码

AnnotateUsedMemory C++ 源码

/*!
 * \brief Annotates the minimum required memory of each primitive function callsite.
 *
 * The pass analyzes the liveness of the input/output tensors at each primitive function
 * callsite and sums the sizes of the tensors that are live there. The result is attached to the
 * called primitive function as a "used_memory" annotation: a list holding the number of bytes
 * required at each of its callsites. In addition, the containing function is annotated with an
 * "io_used_memory" annotation, the total memory (in bytes) required for its IO tensors.
 *
 * Type inference (the InferType pass) is expected to have been run before this pass.
 *
 * Note: This pass does not support dynamic shapes; it is the user's responsibility to ensure
 * this pass isn't applied where dynamic shapes may be input.
 */
TVM_DLL Pass AnnotateUsedMemory();

这个 Pass 用于计算每个原始(primitive)函数调用点所需的最小内存。它分析每个函数调用点处输入/输出张量的活跃性(liveness),并累加这些活跃张量所需的总字节数。计算结果以名为 "used_memory" 的注解添加到被调用的原始函数上。该注解是一个列表,依次记录该函数每个调用点所需的内存大小(以字节为单位)。此外,包含这些调用的外层函数会被添加名为 "io_used_memory" 的注解,表示其 IO(输入/输出)张量所需的总内存。

需要注意的是,此 Pass 不支持动态形状:实现中遇到动态形状时会直接触发 ICHECK 失败。因此用户需要自行确保不会在可能出现动态形状输入的场景中应用此 Pass。

简单的例子:

修改前:

def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
  let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
    nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
  };
  let %x_1 = %x_0(%input);
  %x_1
}

修改后:

def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
  let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=[32]) -> Tensor[(1, 2, 2, 4), int8] {
    nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
  };
  let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
  %x_1
}

在上面的简单示例中,io_used_memory 和 used_memory 的值相同(均为 32 字节)。这是因为只有一个原始函数调用点,而该调用点处活跃的张量恰好就是 main 函数的输入 %input 和输出 %x_1。两者各为 1×2×2×4 个 int8,即各 16 字节。

class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
 public:
  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
                            const transform::LivenessAnalysis& lva)
      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}

  /*!
   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
   * added to the input function which refers to the total size required for the IO
   * tensors.
   */
  Function operator()(const Function& func) {
    uint64_t io_used_memory = 0;

    // Inputs
    for (const Var& param : func->params) {
      Type type = param->checked_type();
      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
      ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
      io_used_memory += CalculateRelayExprSizeBytes(type);
    }

    // Outputs
    Type type = func->body->checked_type();
    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
    ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
    io_used_memory += CalculateRelayExprSizeBytes(type);

    Expr new_func_body = VisitExpr(func->body);
    Function new_func = WithFields(func, func->params, new_func_body);
    return WithAttr(std::move(new_func), "io_used_memory",
                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
  }

  /*!
   * \brief Establish which let bindings have primitive function values.
   */
  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) override {
    if (const auto* func_node = value.as<FunctionNode>()) {
      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
          << "Expect top-level functions to be primitive.";
      let_bound_prim_func_.insert(var);
    }
    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
  }

  /*!
   * \brief Visit let nodes and perform one of two actions depending on their value:
   *
   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
   *               primitive functions.
   *
   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
   *                   previous analysis at the callsite.
   *
   */
  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
    Var let_var = post_let_node->var;
    Expr let_value = IgnoreOnDevice(post_let_node->value);

    if (let_value->IsInstance<CallNode>()) {
      Call callsite = Downcast<Call>(let_value);
      if (CheckPrimitiveFunctionCall(callsite)) {
        Var call_op = Downcast<Var>(callsite->op);

        // Find all the vars that are live at the callsite. This is done by merging the
        // in and out varset's and then removing the var that references the primitive
        // function itself since we don't want this included in the calculation.
        const transform::ControlFlowGraph::NodePtr cfg_node =
            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
        live_tensors.insert(live_out.begin(), live_out.end());
        live_tensors.erase(call_op);

        // Calculate size of live tensors and store to allow annotation when the function
        // gets visited.
        uint64_t used_memory = 0;
        for (const auto& var : live_tensors) {
          Type type = var->checked_type();
          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
          ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
          used_memory += CalculateRelayExprSizeBytes(type);
        }
        IntImm annotation(DataType::UInt(64), used_memory);
        used_memory_annotations_[call_op].push_back(annotation);
      }
    } else if (let_value->IsInstance<FunctionNode>()) {
      Function func = Downcast<Function>(let_value);
      ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end())
          << "Could not find used_memory value for primitive function bound at "
          << let_var->name_hint();
      Array<IntImm> used_memory = used_memory_annotations_[let_var];
      used_memory_annotations_.erase(let_var);

      Function new_func = WithAttr(std::move(func), "used_memory",
                                   Array<IntImm>(used_memory.rbegin(), used_memory.rend()));
      return Let(let_var, new_func, post_let_node->body, post_let_node->span);
    }

    return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
  }

 private:
  /*!
   * \brief Check if a call is a primitive function callsite.
   */
  bool CheckPrimitiveFunctionCall(const Call& callsite) {
    if (auto var = callsite->op.as<Var>()) {
      if (let_bound_prim_func_.find(var.value()) != let_bound_prim_func_.end()) {
        return true;
      }
    }
    return false;
  }

  /*! \brief Control flow graph representation of the main function. */
  transform::ControlFlowGraph control_flow_graph_;
  /*! \brief Liveness analysis of the main function. */
  transform::LivenessAnalysis liveness_;
  /*! \brief Var's that reference primitive functions. */
  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_prim_func_;
  /*! \brief Stores the calculated uint64 used_memory values so they can be annotated on the
   * relevant function. */
  std::unordered_map<Var, Array<IntImm>, ObjectPtrHash, ObjectPtrEqual> used_memory_annotations_;
};

这段代码定义了一个名为 AnnotateUsedMemoryMutator 的类,它继承自 transform::DeviceAwareExprMutator。该类有两个作用:一是在每个被 let 绑定的原始函数上添加 "used_memory" 注解,记录其各个调用点所需的内存;二是在输入函数(即被变换的外层函数)上添加 "io_used_memory" 注解,表示 IO 张量所需的总大小。

AnnotateUsedMemoryMutator 类的构造函数接受三个参数:IR 模块、控制流图和活跃性分析结果。其中 IR 模块被传给基类 DeviceAwareExprMutator,控制流图和活跃性分析结果则被复制保存到私有成员变量 control_flow_graph_ 和 liveness_ 中。

类中的 operator() 方法接受一个函数作为输入,并返回一个新的函数,其中包含了 "io_used_memory" 注解。这个方法首先计算输入和输出张量的总大小,然后遍历函数体中的表达式,对每个 let 节点进行处理。对于调用原始函数的 call 节点,它会计算活跃张量的大小并将结果存储起来;对于绑定原始函数的 function 节点,它会使用之前存储的结果来添加 "used_memory" 注解。

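此外,类中还重写了 PreVisitLetBinding_ 和 PostVisitLet_ 这两个遍历钩子,并定义了私有辅助方法 CheckPrimitiveFunctionCall。PreVisitLetBinding_ 记录哪些 let 变量绑定了原始函数;PostVisitLet_ 在遍历过程中处理调用点,并为函数添加注解;CheckPrimitiveFunctionCall 用于判断一个调用是否是对原始函数的调用。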