analysis-get-calibration_data

import tvm.testing
import numpy as np

import tvm
import tvm.relay.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay.analysis import get_calibration_data

The get_calibration_data() function collects the calibration data for a given Relay graph. The calibration data consists of the input and output values of each function. The returned dictionary uses each function's GlobalVar as the key; users can then drill down further using "inputs" and "outputs" as keys.

There are a few limitations:

  1. The input module (graph) cannot contain control flow.

  2. The input arguments of a function cannot be tuples (outputs can be tuples); see the sketch after this list.

  3. Only top-level functions are handled (i.e., nested functions are not handled).

  4. Only functions that have the Compiler attribute set are handled.
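A minimal sketch of limitation 2, reusing the "test_graph" Compiler name from the example below; the variable names are illustrative only:

import tvm
from tvm import relay

# Hypothetical: a function whose input argument is a tuple.
t = relay.var("t", type_annotation=relay.TupleType(
    [relay.TensorType((8, 8)), relay.TensorType((8, 8))]))
body = relay.TupleGetItem(t, 0) + relay.TupleGetItem(t, 1)
bad = relay.Function([t], body).with_attr("Compiler", "test_graph")
# A module that calls `bad` violates limitation 2, so get_calibration_data
# cannot handle it; tuple *outputs*, in contrast, are fine.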

The function takes two arguments:

  • mod: the input module (tvm.IRModule) from which to collect the calibration data.

  • data: the input data (Dict[str, NDArray]) used to run the module.

It returns a dictionary keyed by tvm.relay.GlobalVar whose values are dictionaries holding the input and output data.
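A minimal access sketch (the names here are illustrative; the full walkthrough below exercises the real API):

# data = get_calibration_data(mod, {"x": x_nd, "y": y_nd})
# for gvar, record in data.items():    # keys are tvm.relay.GlobalVar
#     inputs = record["inputs"]        # NDArrays, in call-argument order
#     outputs = record["outputs"]      # NDArrays, one per output field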

def check_data_size(mod, data):
    # Every function except "main" should have one calibration record,
    # with one input per parameter and one output per tuple field.
    assert len(data) == len(mod.functions) - 1
    for key, value in mod.functions.items():
        if key.name_hint != "main":
            assert len(data[key]["inputs"]) == len(value.params)
            if isinstance(value.body, relay.Tuple):
                assert len(data[key]["outputs"]) == len(value.body.fields)
            else:
                assert len(data[key]["outputs"]) == 1

Testing calibration data on a simple graph

# A module with two subgraphs
mod = tvm.IRModule()

x0 = relay.var("x0", shape=(8, 8))
y0 = relay.var("y0", shape=(8, 8))
z0 = x0 + y0
z1 = x0 - y0
z2 = relay.Tuple((z0, z1))
f0 = relay.Function([x0, y0], z2)
f0 = f0.with_attr("Compiler", "test_graph")
g0 = relay.GlobalVar("g0")
mod[g0] = f0
mod = relay.transform.InferType()(mod)

x1 = relay.var("x1", shape=(8, 8))
y1 = relay.var("y1", shape=(8, 8))
z1 = x1 - y1
f1 = relay.Function([x1, y1], z1)
f1 = f1.with_attr("Compiler", "test_graph")
g1 = relay.GlobalVar("g1")
mod[g1] = f1
mod = relay.transform.InferType()(mod)

x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
z = relay.var("z", shape=(8, 8))
c0 = relay.Call(g0, [x, y])
c1 = relay.Call(g1, [relay.TupleGetItem(c0, 0), z])
fm = relay.Function([x, y, z], c1)
mod["main"] = fm
mod = relay.transform.InferType()(mod)

x_data = np.random.rand(8, 8).astype("float32")
y_data = np.random.rand(8, 8).astype("float32")
z_data = np.random.rand(8, 8).astype("float32")
mod.show()
def @g0(%x0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler="test_graph") -> (Tensor[(8, 8), float32], Tensor[(8, 8), float32]) {
  %0 = add(%x0, %y0) /* ty=Tensor[(8, 8), float32] */;
  %1 = subtract(%x0, %y0) /* ty=Tensor[(8, 8), float32] */;
  (%0, %1) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */
}

def @g1(%x1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler="test_graph") -> Tensor[(8, 8), float32] {
  subtract(%x1, %y1) /* ty=Tensor[(8, 8), float32] */
}

def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %z: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(8, 8), float32] {
  %2 = @g0(%x, %y) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */;
  %3 = %2.0 /* ty=Tensor[(8, 8), float32] */;
  @g1(%3, %z) /* ty=Tensor[(8, 8), float32] */
}
data = get_calibration_data(mod, {"x": x_data, "y": y_data, "z": z_data})
data.keys()
dict_keys([I.GlobalVar("g0"), I.GlobalVar("g1")])
# Check the counts and ordering
check_data_size(mod, data)
tvm.testing.assert_allclose(data[g0]["inputs"][0].numpy(), x_data)
tvm.testing.assert_allclose(data[g0]["inputs"][1].numpy(), y_data)
tvm.testing.assert_allclose(data[g0]["outputs"][0].numpy(), x_data + y_data)
tvm.testing.assert_allclose(data[g0]["outputs"][1].numpy(), x_data - y_data)
tvm.testing.assert_allclose(data[g1]["inputs"][0].numpy(), x_data + y_data)
tvm.testing.assert_allclose(data[g1]["inputs"][1].numpy(), z_data)
tvm.testing.assert_allclose(data[g1]["outputs"][0].numpy(), x_data + y_data - z_data)
from tvm.relay.analysis import _ffi_api
output_map = _ffi_api.get_calibrate_output_map(mod)
run_mod = _ffi_api.get_calibrate_module(mod)
output_map
{I.GlobalVar("g0"): [0, 2, 2], I.GlobalVar("g1"): [4, 2, 1]}
run_mod.show()
def @g0(%x0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler=None, Inline=1) -> (Tensor[(8, 8), float32], Tensor[(8, 8), float32]) {
  %0 = add(%x0, %y0) /* ty=Tensor[(8, 8), float32] */;
  %1 = subtract(%x0, %y0) /* ty=Tensor[(8, 8), float32] */;
  (%0, %1) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */
}

def @g1(%x1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler=None, Inline=1) -> Tensor[(8, 8), float32] {
  subtract(%x1, %y1) /* ty=Tensor[(8, 8), float32] */
}

def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %z: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) {
  %2 = @g0(%x, %y) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */;
  %3 = %2.0 /* ty=Tensor[(8, 8), float32] */;
  %4 = %2.0;
  %5 = %2.1;
  %6 = @g1(%3, %z) /* ty=Tensor[(8, 8), float32] */;
  (%x, %y, %4, %5, %3, %z, %6)
}
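Reading run_mod's seven-element output tuple with output_map, where each triple is [offset, number of inputs, number of outputs]:

# g0: [0, 2, 2] -> inputs are elements 0..1 (%x, %y), outputs are 2..3 (%4, %5)
# g1: [4, 2, 1] -> inputs are elements 4..5 (%3, %z), the output is 6 (%6)

The original module, by contrast, is left untouched: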
mod.show()
def @g0(%x0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler="test_graph") -> (Tensor[(8, 8), float32], Tensor[(8, 8), float32]) {
  %0 = add(%x0, %y0) /* ty=Tensor[(8, 8), float32] */;
  %1 = subtract(%x0, %y0) /* ty=Tensor[(8, 8), float32] */;
  (%0, %1) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */
}

def @g1(%x1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler="test_graph") -> Tensor[(8, 8), float32] {
  subtract(%x1, %y1) /* ty=Tensor[(8, 8), float32] */
}

def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %z: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(8, 8), float32] {
  %2 = @g0(%x, %y) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */;
  %3 = %2.0 /* ty=Tensor[(8, 8), float32] */;
  @g1(%3, %z) /* ty=Tensor[(8, 8), float32] */
}

Testing calibration data with DNNL

def test_mobilenet_dnnl():
    if not tvm.get_global_func("relay.ext.dnnl", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    ishape = (1, 3, 224, 224)
    mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype=dtype)

    mod = transform.AnnotateTarget(["dnnl"])(mod)
    mod = transform.MergeCompilerRegions()(mod)
    mod = transform.PartitionGraph()(mod)

    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
    data = get_calibration_data(mod, {"data": i_data, **params})

    # Check the counts and ordering
    check_data_size(mod, data)

src/relay/analysis/get_calibration_data.cc

Obtaining the calibration data takes two steps (a Python sketch combining them follows the list):

  1. Prepare the module that produces the tensor values (GetCalibrateModule).

  2. Generate the mapping between those values and the functions (GetCalibrateOutputMap).
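A sketch of how the Python-side get_calibration_data plausibly ties these two steps together; the pass and executor choices here are assumptions built from standard Relay APIs, not a verbatim copy of the implementation:

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.analysis import _ffi_api

def get_calibration_data_sketch(mod, data):
    output_map = _ffi_api.get_calibrate_output_map(mod)
    run_mod = _ffi_api.get_calibrate_module(mod)  # step 1: values-producing module
    run_mod = transform.Inline()(run_mod)         # honor the Inline=1 attribute
    executor = relay.create_executor("graph", mod=run_mod, device=tvm.cpu(0))
    ref_res = executor.evaluate()(**data)         # flat tuple of NDArrays
    calib = {}
    for gvar, indices in output_map.items():      # step 2: slice per function
        offset, in_len, out_len = (int(i) for i in indices)
        calib[gvar] = {
            "inputs": ref_res[offset : offset + in_len],
            "outputs": ref_res[offset + in_len : offset + in_len + out_len],
        }
    return calib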

/*!
 * \brief This function returns a module that will be used by
 * the relay graph executor for collecting the calibration data.
 * To do that, we first make all inputs and outputs of each
 * function into the final output (i.e., the final output is a
 * tuple of tensors). Then, we change the compiler attribute of
 * each function. Finally, we mark all functions to be inlined.
 */

class Collector : public ExprRewriter {
 public:
  explicit Collector(const IRModule& module) : module_(module) {}

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    // check if the function implementation is available
    // intrinsic functions are excluded for now
    if (call->op->IsInstance<GlobalVarNode>()) {
      auto var = Downcast<GlobalVar>(call->op);
      ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
      // we only handle functions with Compiler attribute set
      auto func = Downcast<Function>(module_->Lookup(var));
      if (func->GetAttr<String>(attr::kCompiler)) {
        // collect all the inputs and outputs
        for (const auto& it : call->args) new_outputs_.push_back(it);
        new_outputs_.push_back(post);
      }
    }
    return post;
  }

  Array<Expr> GetNewOutputs() { return new_outputs_; }

 private:
  const IRModule& module_;
  Array<Expr> new_outputs_;
};

The Collector class inherits from ExprRewriter and performs the collection step: it gathers the inputs and outputs of every offloaded call so that GetCalibrateModule can make them the final output of the module (i.e., a tuple of tensors) for the Relay graph executor to evaluate. Changing each function's Compiler attribute and marking the functions for inlining, mentioned in the comment above, happen later in GetCalibrateModule itself.

Collector has a constructor that takes an IRModule. It overrides ExprRewriter's Rewrite_ method for CallNode expressions. Rewrite_ first checks that the callee's implementation is available (intrinsic operators are excluded for now), then only handles functions whose Compiler attribute is set. For such a call it pushes all the arguments, followed by the rewritten call itself, onto the new_outputs_ array; GetNewOutputs returns the collected array.
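For readers who prefer the Python side, a rough analogue of Collector using relay's ExprVisitor; this is a sketch that mirrors the traversal logic only, not part of the library:

from tvm import relay
from tvm.relay.expr_functor import ExprVisitor

class CollectorSketch(ExprVisitor):
    """Collect the arguments and results of calls to Compiler-annotated functions."""

    def __init__(self, mod):
        super().__init__()
        self.mod = mod
        self.new_outputs = []

    def visit_call(self, call):
        super().visit_call(call)  # recurse into the arguments first (post-order)
        if isinstance(call.op, relay.GlobalVar):
            func = self.mod[call.op]
            if func.attrs and "Compiler" in func.attrs:
                self.new_outputs.extend(call.args)  # the call's inputs
                self.new_outputs.append(call)       # the call itself (its outputs)

On the example module, running visit(mod["main"].body) on a CollectorSketch(mod) instance leaves new_outputs holding x, y, the call to g0, the TupleGetItem feeding g1, z, and the call to g1, in that post-order.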

Expr FlattenOutputTuple(const Array<Expr>& exprs) {
  Array<Expr> fields;
  for (const auto& it : exprs) {
    ICHECK(it->checked_type_.defined());
    if (auto* tn = it->checked_type_.as<TupleTypeNode>()) {
      // TODO(seanlatias): for now input argument cannot be a tuple
      ICHECK(it->IsInstance<CallNode>());
      for (size_t i = 0; i < tn->fields.size(); i++) {
        fields.push_back(TupleGetItem(it, i));
      }
    } else {
      fields.push_back(it);
    }
  }
  return Tuple(fields);
}

IRModule GetCalibrateModule(IRModule module) {
  auto glob_funcs = module->functions;
  // module is mutable, hence, we make a copy of it.
  module.CopyOnWrite();
  for (const auto& pair : glob_funcs) {
    if (auto opt = pair.second.as<Function>()) {
      // we only collect the outputs for main function
      if (pair.first->name_hint == "main") {
        auto func = opt.value();
        Collector collector(module);
        PostOrderRewrite(func->body, &collector);
        auto new_outputs = collector.GetNewOutputs();
        Expr tuple = FlattenOutputTuple(new_outputs);
        func = Function(func->params, tuple, tuple->checked_type_, func->type_params, func->attrs);
        module->Update(pair.first, func);
      }
    }
  }
  // reset the attribute of functions for running graph executor
  for (const auto& pair : glob_funcs) {
    if (auto opt = pair.second.as<Function>()) {
      auto func = opt.value();
      if (func->GetAttr<String>(attr::kCompiler)) {
        // we need to inline the functions in order to run graph runtime
        func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1));
        // reset the compiler attribute to null for llvm execution
        func = WithAttr(std::move(func), attr::kCompiler, NullValue<ObjectRef>());
        module->Update(pair.first, func);
      }
    }
  }
  return module;
}

This code defines two functions: FlattenOutputTuple and GetCalibrateModule.

FlattenOutputTuple takes an array of Expr and iterates over it. For each element it first checks that the element's type has been inferred. If the element is tuple-typed, it pushes one TupleGetItem per tuple field onto a new fields array (for now such an element must be a CallNode, because input arguments cannot be tuples); otherwise it pushes the element itself. Finally it returns a Tuple built from fields.
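The same flattening, sketched in Python (it assumes, as the ICHECK does, that type inference has already run):

from tvm import relay

def flatten_output_tuple_sketch(exprs):
    fields = []
    for e in exprs:
        if isinstance(e.checked_type, relay.TupleType):
            # expand a tuple-typed call into one TupleGetItem per field
            fields.extend(relay.TupleGetItem(e, i)
                          for i in range(len(e.checked_type.fields)))
        else:
            fields.append(e)
    return relay.Tuple(fields)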

GetCalibrateModule takes an IRModule and turns it into the calibration module. It first takes a copy-on-write handle so the original module is not mutated. It then walks all global functions but only rewrites the one named "main": it runs Collector over the body to gather the inputs and outputs of every offloaded call, flattens them into one tuple with FlattenOutputTuple, and replaces the function body with that tuple. Finally, it resets the attributes of every function that carries a Compiler attribute: Inline is set to 1 so the function can be inlined for the graph executor, and Compiler is reset to null so it is compiled with the regular (LLVM) backend. The processed module is returned.

/*!
 * \brief This function generates the output mapping between
 * the calibration data and each function. The key is a
 * GlobalVar that corresponds to each function and the value
 * is an array of integers. The size of the array is always
 * three. The first value is the offset that points to the start.
 * The second value is the number of inputs. The third value
 * is the number of outputs.
 */

class OutputMapper : public ExprRewriter {
 public:
  OutputMapper(Map<GlobalVar, Array<Integer>>* output_map, const IRModule& module, size_t* offset)
      : output_map_(output_map), module_(module), offset_(offset) {}

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    if (call->op->IsInstance<GlobalVarNode>()) {
      auto var = Downcast<GlobalVar>(call->op);
      ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
      ICHECK_EQ(output_map_->count(var), 0)
          << "Repeated function call " << var << " is not supported.";
      auto func = Downcast<Function>(module_->Lookup(var));
      // we only handle functions with Compiler attribute set
      if (func->GetAttr<String>(attr::kCompiler)) {
        Array<Integer> info;
        // the first value is the offset
        info.push_back(Integer(*offset_));
        // the second value is the number of inputs
        info.push_back(Integer(call->args.size()));
        // the third value is the number of outputs
        // we need to check if the output is a tuple
        size_t out_size = 1;
        if (auto* tn = func->body.as<TupleNode>()) {
          info.push_back(Integer(tn->fields.size()));
          out_size = tn->fields.size();
        } else {
          info.push_back(Integer(1));
        }
        output_map_->Set(var, info);
        // calculate the offset for the next function
        *offset_ = *offset_ + call->args.size() + out_size;
      }
    }
    return post;
  }

 private:
  Map<GlobalVar, Array<Integer>>* output_map_;
  const IRModule& module_;
  size_t* offset_;
};

Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& module) {
  Map<GlobalVar, Array<Integer>> output_map;
  size_t offset = 0;
  auto glob_funcs = module->functions;
  for (const auto& pair : glob_funcs) {
    if (const auto* func = pair.second.as<FunctionNode>()) {
      if (pair.first->name_hint == "main") {
        OutputMapper output_mapper(&output_map, module, &offset);
        PostOrderRewrite(func->body, &output_mapper);
      }
    }
  }

  return output_map;
}

TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_module").set_body_typed([](IRModule mod) {
  return GetCalibrateModule(mod);
});

TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_output_map")
    .set_body_typed([](const IRModule& mod) { return GetCalibrateOutputMap(mod); });

The OutputMapper class also inherits from ExprRewriter. Its job is to generate the output mapping between the calibration data and each function. The key is the GlobalVar corresponding to each function; the value is an integer array whose size is always three: the first value is the offset where the function's slice starts, the second is the number of inputs, and the third is the number of outputs.
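Worked through on the example module, the offset advances by the input count plus the output count after each call, which is exactly how the map printed earlier comes about:

# g0 visited first: offset 0, 2 inputs (%x, %y), 2 outputs -> [0, 2, 2]
# offset becomes 0 + 2 + 2 = 4
# g1 visited next:  offset 4, 2 inputs (%3, %z), 1 output  -> [4, 2, 1]
# offset becomes 4 + 2 + 1 = 7, the length of run_mod's output tuple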

GetCalibrateOutputMap takes an IRModule and returns the mapping. It first creates an empty output_map and an initial offset of 0, then walks the module's global functions, processing only the one named "main". It runs OutputMapper over the function body; for every call to a Compiler-annotated function, OutputMapper records [offset, number of inputs, number of outputs] under the callee's GlobalVar and advances the offset by that input and output count. Note the ICHECK_EQ above: calling the same function more than once is not supported.

The registered global functions relay.analysis.get_calibrate_module and relay.analysis.get_calibrate_output_map expose GetCalibrateModule and GetCalibrateOutputMap, respectively; they back the two _ffi_api calls used earlier.