初识 Relay 变换
先从 Relay 变换入手，了解 TVM 的 FFI 机制。
研读源码可以看出，`tvm/src/relay/transforms/` 目录下定义了大量 Relay 变换的实现。下面挑选 `tvm/src/relay/transforms/div_to_mul.cc` 中的 `DivToMul` Pass，以了解 Relay 变换是如何生效的。
namespace tvm {
namespace relay {
namespace transform {

/*!
 * \brief Builds the function-level pass that rewrites division by a constant
 *        into multiplication by the constant's reciprocal.
 *
 * The actual rewriting is delegated to DivToMulRewrite. The pass is created
 * with optimization level 0 and lists "InferType" and "FoldConstant" as the
 * passes it requires.
 *
 * \return The DivToMul pass object.
 */
Pass DivToMul() {
  // The lambda needs no captured state; the module and pass context are unused.
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> rewrite_fn =
      [](Function func, IRModule /*mod*/, PassContext /*ctx*/) {
        Expr rewritten = DivToMulRewrite().Mutate(func);
        return Downcast<Function>(rewritten);
      };
  const int opt_level = 0;
  return CreateFunctionPass(rewrite_fn, opt_level, "DivToMul", {"InferType", "FoldConstant"});
}

// Expose the pass factory through the global FFI registry as "relay._transform.DivToMul".
TVM_REGISTER_GLOBAL("relay._transform.DivToMul").set_body_typed(DivToMul);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
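这里在命名空间 `tvm::relay::transform` 下定义了变换函数 `DivToMul()`，并通过 `TVM_REGISTER_GLOBAL` 以 `relay._transform.DivToMul` 为名将其注册到全局。真正的改写逻辑由下面的 `DivToMulRewrite` 类完成（在实际源文件中，它需要定义在 `DivToMul()` 之前才能被引用）：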
namespace tvm {
namespace relay {
class DivToMulRewrite : public MixedModeMutator {
Expr Rewrite_(const CallNode* pre, const Expr& post) final {
if (const CallNode* call_node = post.as<CallNode>()) {
if (call_node->op == Op::Get("divide")) {
auto rhs = call_node->args[1].as<ConstantNode>();
if (rhs != nullptr) {
auto inv =
runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);
std::string dtype = DLDataType2String(rhs->data.DataType());
if (dtype == "float32") {
float rhs_val = static_cast<float*>(rhs->data->data)[0];
// Check for division by zero
if (rhs_val == 0.) {
return post;
}
static_cast<float*>(inv->data)[0] = 1. / rhs_val;
} else if (dtype == "float64") {
double rhs_val = static_cast<double*>(rhs->data->data)[0];
// Check for division by zero
if (rhs_val == 0.) {
return post;
}
static_cast<double*>(inv->data)[0] = 1. / rhs_val;
} else if (dtype == "float16") {
// Do f16 math in f32
float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
// Check for division by zero
if (rhs_val == 0.) {
return post;
}
static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);
} else {
// Cannot do 1/int because it will truncate
return post;
}
return Multiply(call_node->args[0], Constant(inv));
}
}
}
return post;
}
};
}
}
若想在 Python 端使用该 Pass，需要在 `tvm/python/tvm/relay/transform/transform.py` 中定义：
def DivToMul():
    """Create the pass that turns division by a constant into multiplication.

    The constant divisor is replaced by its reciprocal, so ``x / c`` becomes
    ``x * (1 / c)``. The pass object is constructed on the C++ side and is
    reached through the FFI function registered as
    ``relay._transform.DivToMul``.

    Returns
    -------
    ret : tvm.transform.Pass
        The DivToMul pass.
    """
    created_pass = _ffi_api.DivToMul()
    return created_pass
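关键在于 `_ffi_api` 模块，即 `tvm/python/tvm/relay/transform/_ffi_api.py` 中的下面这行代码。它会把 C++ 端以 `relay._transform` 为前缀注册的全局函数（例如上文的 `relay._transform.DivToMul`）绑定为本模块的属性，因此 Python 端可以直接调用 `_ffi_api.DivToMul()`：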
# Bind the global functions registered on the C++ side under the
# "relay._transform" prefix (e.g. "relay._transform.DivToMul", registered via
# TVM_REGISTER_GLOBAL) as attributes of this module, so that callers can write
# _ffi_api.DivToMul().
tvm._ffi._init_api("relay._transform", __name__)
Python 端的测试代码见《除法转乘法》一文。