AnnotateUsedMemory C++ 源码
/*!
 * \brief Annotates the minimum required memory of each primitive function callsite by analyzing
 * the liveness of the input/output tensors at each function callsite and calculating the total
 * amount of memory these tensors require. This is added as a "used_memory" annotation to the
 * function in question as a list of the number of bytes for each callsite. In addition, the
 * containing function is annotated with an "io_used_memory" annotation which refers to the total
 * memory required for the IO tensors.
 *
 * Note: This pass does not support dynamic shapes; it is the user's responsibility to check that
 * this pass isn't applied where dynamic shapes may be input.
 */
TVM_DLL Pass AnnotateUsedMemory();
这段代码用于分析每个原始(primitive)函数调用站点所需的最小内存。它通过分析每个函数调用站点处输入/输出张量的活跃性(liveness),并计算这些张量所需的总内存来实现这一目标。计算结果以名为 "used_memory" 的注解添加到被调用的原始函数上,以字节为单位列出每个调用站点所需的内存大小。此外,包含这些调用的外层函数会被添加一个 "io_used_memory" 注解,表示其 IO 张量所需的总内存。
需要注意的是,此 Pass 不支持动态形状,用户需要自行检查是否在可能输入动态形状的情况下应用了此 Pass。
简单的例子:
修改前:
def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
};
let %x_1 = %x_0(%input);
%x_1
}
修改后:
def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=[32]) -> Tensor[(1, 2, 2, 4), int8] {
nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
};
let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
%x_1
}
在上面的简单示例中,io_used_memory
和 used_memory
是相同的,因为只有一个原始函数。
class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
public:
AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
const transform::LivenessAnalysis& lva)
: DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
/*!
* \brief Mutates the input function. In addition, an "io_used_memory" annotation is
* added to the input function which refers to the total size required for the IO
* tensors.
*/
Function operator()(const Function& func) {
uint64_t io_used_memory = 0;
// Inputs
for (const Var& param : func->params) {
Type type = param->checked_type();
ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
io_used_memory += CalculateRelayExprSizeBytes(type);
}
// Outputs
Type type = func->body->checked_type();
ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
io_used_memory += CalculateRelayExprSizeBytes(type);
Expr new_func_body = VisitExpr(func->body);
Function new_func = WithFields(func, func->params, new_func_body);
return WithAttr(std::move(new_func), "io_used_memory",
tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
}
/*!
* \brief Establish which let bindings have primitive function values.
*/
std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) override {
if (const auto* func_node = value.as<FunctionNode>()) {
ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
<< "Expect top-level functions to be primitive.";
let_bound_prim_func_.insert(var);
}
return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
}
/*!
* \brief Visit let nodes and perform one of two actions depending on their value:
*
* 1. CallNode - Calculate "used_memory" annotation value at the callsite of
* primitive functions.
*
* 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
* previous analysis at the callsite.
*
*/
Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
Var let_var = post_let_node->var;
Expr let_value = IgnoreOnDevice(post_let_node->value);
if (let_value->IsInstance<CallNode>()) {
Call callsite = Downcast<Call>(let_value);
if (CheckPrimitiveFunctionCall(callsite)) {
Var call_op = Downcast<Var>(callsite->op);
// Find all the vars that are live at the callsite. This is done by merging the
// in and out varset's and then removing the var that references the primitive
// function itself since we don't want this included in the calculation.
const transform::ControlFlowGraph::NodePtr cfg_node =
control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
live_tensors.insert(live_out.begin(), live_out.end());
live_tensors.erase(call_op);
// Calculate size of live tensors and store to allow annotation when the function
// gets visited.
uint64_t used_memory = 0;
for (const auto& var : live_tensors) {
Type type = var->checked_type();
ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
used_memory += CalculateRelayExprSizeBytes(type);
}
IntImm annotation(DataType::UInt(64), used_memory);
used_memory_annotations_[call_op].push_back(annotation);
}
} else if (let_value->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(let_value);
ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end())
<< "Could not find used_memory value for primitive function bound at "
<< let_var->name_hint();
Array<IntImm> used_memory = used_memory_annotations_[let_var];
used_memory_annotations_.erase(let_var);
Function new_func = WithAttr(std::move(func), "used_memory",
Array<IntImm>(used_memory.rbegin(), used_memory.rend()));
return Let(let_var, new_func, post_let_node->body, post_let_node->span);
}
return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
}
private:
/*!
* \brief Check if a call is a primitive function callsite.
*/
bool CheckPrimitiveFunctionCall(const Call& callsite) {
if (auto var = callsite->op.as<Var>()) {
if (let_bound_prim_func_.find(var.value()) != let_bound_prim_func_.end()) {
return true;
}
}
return false;
}
/*! \brief Control flow graph representation of the main function. */
transform::ControlFlowGraph control_flow_graph_;
/*! \brief Liveness analysis of the main function. */
transform::LivenessAnalysis liveness_;
/*! \brief Var's that reference primitive functions. */
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_prim_func_;
/*! \brief Stores the calculated uint64 used_memory values so they can be annotated on the
* relevant function. */
std::unordered_map<Var, Array<IntImm>, ObjectPtrHash, ObjectPtrEqual> used_memory_annotations_;
};
这段代码定义了一个名为 AnnotateUsedMemoryMutator 的类,该类继承自 transform::DeviceAwareExprMutator。这个类有两个主要作用:一是在输入函数上添加名为 "io_used_memory" 的注解,表示 IO 张量所需的总大小;二是在每个被 let 绑定的原始函数上添加 "used_memory" 注解,记录各调用站点所需的内存字节数。
AnnotateUsedMemoryMutator
类有一个构造函数,它接受三个参数:IR 模块、控制流图和活跃度分析。这些参数用于初始化类的私有成员变量。
类中的 operator()
方法接受一个函数作为输入,并返回一个新的函数,其中包含了 "io_used_memory"
注解。这个方法首先计算输入和输出张量的总大小,然后遍历函数体中的表达式,对每个 let
节点进行处理。对于调用原始函数的 call
节点,它会计算活跃张量的大小并将结果存储起来;对于绑定原始函数的 function 节点,它会使用之前存储的结果来添加 "used_memory"
注解。
此外,类还包含一些辅助方法,如PreVisitLetBinding_
、PostVisitLet_
和CheckPrimitiveFunctionCall
,这些方法用于在遍历过程中处理特定的 let
节点和调用站点。