TVM Automatic Quantization Calibration
Reference: tvm/src/relay/quantize/calibrate.cc
and tvm/python/tvm/relay/quantize/_calibrate.py
// KL divergence minimization code is adapted from MXNet.
// The original one is in incubator-mxnet/src/operator/quantization/calibrate.cc
static std::vector<float> SmoothDistribution(const std::vector<float>& p,
                                             const float eps = 0.0001) {
  std::vector<size_t> is_zeros(p.size());
  std::vector<size_t> is_nonzeros(p.size());
  {
    auto it = p.begin();
    std::generate(is_zeros.begin(), is_zeros.end(),
                  [&it]() { return static_cast<size_t>(*(it++) == 0.f); });
  }
  {
    auto it = p.begin();
    std::generate(is_nonzeros.begin(), is_nonzeros.end(),
                  [&it]() { return static_cast<size_t>(*(it++) != 0.f); });
  }
  size_t n_zeros = std::accumulate(is_zeros.begin(), is_zeros.end(), 0);
  size_t n_nonzeros = p.size() - n_zeros;
  if (!n_nonzeros) {
    // The discrete probability distribution is malformed. All entries are 0.
    return std::vector<float>();
  }
  float eps1 = eps * static_cast<float>(n_zeros) / static_cast<float>(n_nonzeros);
  if (eps1 >= 1.0) return std::vector<float>();
  auto ret = p;
  for (size_t i = 0; i < p.size(); i++) {
    ret[i] += eps * is_zeros[i] - eps1 * is_nonzeros[i];
  }
  return ret;
}
This code implements SmoothDistribution, the distribution-smoothing helper used by the KL-divergence minimization. It takes a float vector p representing a discrete distribution and returns a smoothed copy in which zero entries receive a small amount of probability mass.

The implementation works as follows:

1. Two vectors of size p.size(), is_zeros and is_nonzeros, record for each element whether p[i] is zero or non-zero; both are filled with std::generate, where is_zeros[i] indicates p[i] == 0 and is_nonzeros[i] indicates p[i] != 0.
2. The number of zeros n_zeros and non-zeros n_nonzeros = p.size() - n_zeros are counted. If n_nonzeros is 0, the discrete distribution is malformed (every entry is zero) and an empty vector is returned.
3. eps1 = eps * n_zeros / n_nonzeros is the mass each non-zero entry must give up so that every zero entry can receive eps. If eps1 >= 1.0, smoothing would wipe out the non-zero entries, so an empty vector is returned as well.
4. p is copied into ret and every element is updated as ret[i] += eps * is_zeros[i] - eps1 * is_nonzeros[i]: each zero bin gains eps while each non-zero bin loses eps1, leaving the total mass unchanged.
5. The smoothed vector ret is returned.
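To make the update rule concrete, here is a minimal NumPy sketch of the same smoothing step (the name smooth_distribution and the example values are illustrative, not part of the TVM API):

import numpy as np

def smooth_distribution(p, eps=1e-4):
    """Sketch of SmoothDistribution: move a little mass into the zero bins."""
    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = int(is_zeros.sum())
    n_nonzeros = p.size - n_zeros
    if n_nonzeros == 0:
        return None  # malformed: every entry is zero
    eps1 = eps * n_zeros / n_nonzeros
    if eps1 >= 1.0:
        return None  # smoothing would drain the non-zero bins
    # each zero bin gains eps; each non-zero bin pays eps1
    return p + eps * is_zeros - eps1 * is_nonzeros

p = np.array([0.0, 0.5, 0.0, 0.5], dtype=np.float32)
print(smooth_distribution(p))  # -> approximately [0.0001, 0.4999, 0.0001, 0.4999]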
static float ComputeEntropy(float* p, float* q, size_t size) {
  float p_sum = std::accumulate(p, p + size, 0.f);
  float q_sum = std::accumulate(q, q + size, 0.f);
  float ret = 0;
  for (size_t i = 0; i < size; i++) {
    ICHECK(p[i] > 0 && q[i] > 0);
    p[i] /= p_sum;
    q[i] /= q_sum;
    if (p[i] && q[i]) ret += p[i] * std::log(p[i] / q[i]);
  }
  return ret;
}
Despite its name, ComputeEntropy computes the KL divergence (relative entropy) between two distributions. It takes two float arrays p and q of length size, normalizes each in place by its sum (after checking via ICHECK that all entries are positive), and accumulates p[i] * log(p[i] / q[i]) over all entries where both are non-zero, returning KL(p || q).
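A small NumPy equivalent (compute_entropy is a hypothetical helper name; unlike the C++ version it does not modify its inputs in place):

import numpy as np

def compute_entropy(p, q):
    """Sketch of ComputeEntropy: KL(p || q) after normalizing both inputs."""
    p = p / p.sum()
    q = q / q.sum()
    mask = (p > 0) & (q > 0)
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

p = np.array([1.0, 2.0, 3.0])
q = np.array([1.0, 1.0, 1.0])
print(compute_entropy(p, q))  # KL divergence of normalized p from the uniform q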
float MinimizeKL(const std::vector<int>& hist, const std::vector<float>& hist_edges, int num_bins,
                 int num_quantized_bins) {
  const int zero_bin_idx = num_bins / 2;
  const int num_half_quantized_bins = num_quantized_bins / 2;
  std::vector<float> thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f);
  std::vector<float> divergence(thresholds.size(), 0.f);
  std::vector<float> quantized_bins(num_quantized_bins, 0);
  for (int i = num_quantized_bins / 2; i < zero_bin_idx + 1; ++i) {
    const int p_bin_idx_start = zero_bin_idx - i;
    const int p_bin_idx_stop = zero_bin_idx + i + 1;
    thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop];

    std::vector<int> sliced_nd_hist(p_bin_idx_stop - p_bin_idx_start);
    std::vector<float> p(sliced_nd_hist.size());
    p[0] = 0;
    p.back() = 0;
    for (int j = 0; j < num_bins; j++) {
      if (j <= p_bin_idx_start) {
        p[0] += hist[j];
      } else if (j >= p_bin_idx_stop) {
        p.back() += hist[j];
      } else {
        sliced_nd_hist[j - p_bin_idx_start] = hist[j];
        p[j - p_bin_idx_start] = hist[j];
      }
    }
    // calculate how many bins should be merged to generate quantized distribution q
    const auto num_merged_bins = sliced_nd_hist.size() / num_quantized_bins;
    for (int j = 0; j < num_quantized_bins; j++) {
      const int start = j * num_merged_bins;
      const int stop = (j + 1) * num_merged_bins;
      quantized_bins[j] =
          std::accumulate(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, 0);
    }
    quantized_bins.back() += std::accumulate(
        sliced_nd_hist.begin() + static_cast<int>(num_quantized_bins * num_merged_bins),
        sliced_nd_hist.end(), 0);
    // expand quantized_bins into p.size bins
    std::vector<float> q(sliced_nd_hist.size(), 0);
    for (int j = 0; j < num_quantized_bins; j++) {
      const int start = j * num_merged_bins;
      const int stop = (j == num_quantized_bins - 1) ? q.size() : ((j + 1) * num_merged_bins);
      int norm = std::count_if(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop,
                               [](size_t i) { return i != 0; });
      if (norm) {
        for (int k = start; k < stop; k++) {
          if (p[k]) q[k] = quantized_bins[j] / norm;
        }
      }
    }
    p = SmoothDistribution(p);
    q = SmoothDistribution(q);

    if (!q.size()) {
      divergence[i - num_half_quantized_bins] = std::numeric_limits<float>::infinity();
    } else {
      divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size());
    }
  }
  auto min_divergence_idx =
      std::distance(divergence.begin(), std::min_element(divergence.begin(), divergence.end()));
  return thresholds[min_divergence_idx];
}
MinimizeKL searches for the clipping threshold that minimizes the KL divergence between the original activation distribution and its quantized approximation. Its inputs are an integer histogram hist, the bin edges hist_edges, and the two counts num_bins and num_quantized_bins. It first sets up the zero-bin index zero_bin_idx, the half count num_half_quantized_bins, the candidate-threshold vector thresholds, the divergence vector divergence, and the merged-bin buffer quantized_bins. For each candidate window centered on zero it then:

1. records the window's upper edge hist_edges[p_bin_idx_stop] as the candidate threshold and builds the reference distribution p from the sliced histogram, folding the clipped outliers into the two edge bins;
2. merges the sliced bins into num_quantized_bins groups and expands them back into a quantized distribution q of the same length as p, spreading each group's mass uniformly over its non-zero bins;
3. smooths both distributions with SmoothDistribution and records KL(p || q) via ComputeEntropy, or infinity if q could not be smoothed.

Finally, it returns the threshold whose divergence is smallest.
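Putting the helpers together, a condensed NumPy sketch of the threshold search might look as follows (minimize_kl is a hypothetical name; it reuses the smooth_distribution and compute_entropy sketches above and glosses over some edge-bin details of the C++ version):

import numpy as np

def minimize_kl(hist, hist_edges, num_bins, num_quantized_bins):
    """Sketch of MinimizeKL: pick the threshold with minimal KL divergence."""
    zero_bin_idx = num_bins // 2
    num_half = num_quantized_bins // 2
    thresholds = np.zeros(num_bins // 2 + 1 - num_half)
    divergence = np.full(thresholds.shape, np.inf)
    for i in range(num_half, zero_bin_idx + 1):
        start, stop = zero_bin_idx - i, zero_bin_idx + i + 1
        thresholds[i - num_half] = hist_edges[stop]
        sliced = hist[start:stop].astype(np.float64)
        # reference distribution p: fold clipped outliers into the edge bins
        p = sliced.copy()
        p[0] += hist[:start].sum()
        p[-1] += hist[stop:].sum()
        # quantized distribution q: merge bins, then spread each group's mass
        # uniformly over its non-zero members
        num_merged = sliced.size // num_quantized_bins
        q = np.zeros_like(sliced)
        for j in range(num_quantized_bins):
            lo = j * num_merged
            hi = sliced.size if j == num_quantized_bins - 1 else (j + 1) * num_merged
            norm = np.count_nonzero(sliced[lo:hi])
            if norm:
                q[lo:hi][p[lo:hi] != 0] = sliced[lo:hi].sum() / norm
        p, q = smooth_distribution(p), smooth_distribution(q)
        if p is not None and q is not None:
            divergence[i - num_half] = compute_entropy(p, q)
    return thresholds[int(np.argmin(divergence))]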
class StatsCollector : private ExprMutator {
 public:
  StatsCollector() : simulated_quantize_op_(Op::Get("relay.op.annotation.simulated_quantize")) {}

  Expr Collect(const Expr& expr) {
    auto new_e = this->Mutate(expr);
    const FunctionNode* func = new_e.as<FunctionNode>();
    ICHECK(func) << "Input should be Function";
    Expr new_body = Tuple(std::move(profile_data_));
    Function ret_func = WithFields(GetRef<Function>(func), FreeVars(new_body), new_body);

    // We are changing the function's ret_type to an empty type. Unfortunately, Optional<Type>() is
    // indistinguishable from NullValue<Type>(), so we can't express "update to nullptr" in
    // WithFields.
    ret_func.CopyOnWrite()->ret_type = NullValue<Type>();
    return std::move(ret_func);
  }

 private:
  Array<Expr> profile_data_;
  const Op& simulated_quantize_op_;

  Expr VisitExpr_(const CallNode* call) {
    Expr new_e = ExprMutator::VisitExpr_(call);
    const CallNode* new_call = new_e.as<CallNode>();
    ICHECK(new_call);
    if (new_call->op == simulated_quantize_op_) {
      auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
      // rewrite the annotation
      auto new_attrs = make_object<SimulatedQuantizeAttrs>();
      const Expr& quantize_input = new_call->args[0];                  // expression being quantized
      auto placeholder = MakeConstantScalar(DataType::Float(32), 0.);  // unused argument
      Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
      new_attrs->kind = QAnnotateKind::kQIdentity;
      new_attrs->sign = attrs->sign;
      new_attrs->rounding = attrs->rounding;
      Expr identity_quantize = Call(new_call->op, new_args, Attrs{new_attrs}, {});

      // add non-const expressions to profile data
      if (attrs->kind != QAnnotateKind::kQWeight) {
        ICHECK(!quantize_input.as<ConstantNode>());
        profile_data_.push_back(identity_quantize);
      }
      return identity_quantize;
    } else {
      return new_e;
    }
  }
};
This code defines the StatsCollector class, which inherits from ExprMutator. Its job is to collect the expressions whose quantization statistics must be profiled and store them in the profile_data_ array.

In Collect, the input expression is first traversed and rewritten by Mutate, then cast to FunctionNode, with a check that the cast succeeded. The profile_data_ array is packed into a Tuple that becomes the new function body, and the function's return type is reset to NullValue<Type>() so it can be re-inferred later.

In VisitExpr_, ExprMutator::VisitExpr_ first processes the CallNode. If the call's operator is simulated_quantize_op_, its attributes are read and a new SimulatedQuantizeAttrs object is created with kind set to kQIdentity, so the rewritten op passes its input through unchanged. The call's first argument (the expression being quantized) is combined with placeholder constants and the new attributes into a new call, identity_quantize. If the original kind is not kQWeight (i.e. the input is not a constant weight), the expression is appended to profile_data_. The method then returns identity_quantize; for any other operator it simply returns new_e.
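The same mutate-and-collect pattern can be illustrated from Python with relay's ExprMutator. This is only a sketch of the traversal idea (collecting calls to a target op while rebuilding the graph), not the actual C++ pass; CallCollector is a hypothetical name:

from tvm import relay
from tvm.relay.expr_functor import ExprMutator

class CallCollector(ExprMutator):
    """Collect calls to one operator while rewriting the expression,
    mirroring how StatsCollector gathers simulated_quantize calls."""

    def __init__(self, op_name):
        super().__init__()
        self.target_op = relay.op.get(op_name)
        self.collected = []

    def visit_call(self, call):
        new_call = super().visit_call(call)  # rebuild the call bottom-up
        if new_call.op == self.target_op:
            self.collected.append(new_call)
        return new_call

x = relay.var("x", shape=(4,))
func = relay.Function([x], relay.nn.relu(x + x))
collector = CallCollector("nn.relu")
collector.visit(func)
print(len(collector.collected))  # -> 1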
/*
 * \brief Given an annotated graph, create a profile graph to collect profile data from the
 * calibration dataset.
 *
 * This pass collects simulated_quantize op into a tuple. Simulated_quantize ops are rewritten to
 * identity mode. The tuple is the output of the profile graph. Both input and output of this pass
 * are relay::Function.
 *
 * \param expr The simulation graph after annotation.
 * \return The profile graph.
 */
Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); }

TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector").set_body_typed(CreateStatsCollector);

TVM_REGISTER_GLOBAL("relay._quantize.FindScaleByKLMinimization")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int* hist_ptr = static_cast<int*>(static_cast<void*>(args[0]));
      float* hist_edges_ptr = static_cast<float*>(static_cast<void*>(args[1]));
      int num_bins = args[2];
      int num_quantized_bins = args[3];
      std::vector<int> hist(hist_ptr, hist_ptr + num_bins);
      std::vector<float> hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1);
      ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins);
    });
In addition, two global functions are registered: relay._quantize.CreateStatsCollector, which creates the statistics collector, and relay._quantize.FindScaleByKLMinimization, which finds the scale threshold via KL minimization.
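Both registered names are reachable from Python through the FFI; a quick check (assuming a TVM build with the quantization module compiled in):

import tvm

# the names come from the TVM_REGISTER_GLOBAL calls above
create_stats_collector = tvm.get_global_func("relay._quantize.CreateStatsCollector")
find_scale = tvm.get_global_func("relay._quantize.FindScaleByKLMinimization")
print(create_stats_collector, find_scale)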
from tvm.relay.quantize.kl_divergence import _find_scale_by_kl
_find_scale_by_kl??
Signature:
_find_scale_by_kl(
    arr,
    quantized_dtype='int8',
    num_bins=8001,
    num_quantized_bins=255,
)
Source:
def _find_scale_by_kl(arr, quantized_dtype="int8", num_bins=8001, num_quantized_bins=255):
    """Given a tensor, find the optimal threshold for quantizing it.
    The reference distribution is `q`, and the candidate distribution is `p`.
    `q` is a truncated version of the original distribution.

    Ref:
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
    """
    assert isinstance(arr, np.ndarray)

    min_val = np.min(arr)
    max_val = np.max(arr)
    thres = max(abs(min_val), abs(max_val))

    if min_val >= 0 and quantized_dtype in ["uint8"]:
        # We need to move negative bins to positive bins to fit uint8 range.
        num_quantized_bins = num_quantized_bins * 2 + 1

    def get_pointer(arr, ctypes_type):
        ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes_type))
        return ctypes.cast(ptr, ctypes.c_void_p)

    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-thres, thres))
    hist_ptr = get_pointer(hist.astype(np.int32), ctypes.c_int)
    hist_edges_ptr = get_pointer(hist_edges, ctypes.c_float)

    return _quantize.FindScaleByKLMinimization(
        hist_ptr, hist_edges_ptr, num_bins, num_quantized_bins
    )
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/kl_divergence.py
Type: function
Example: building a histogram with np.histogram, as _find_scale_by_kl does internally:
import numpy as np

# generate random data
data = np.random.randn(1000)
# define the bin edges
bin_edges = np.linspace(-5, 5, 21)
# compute the histogram
hist, bin_edges = np.histogram(data, bins=bin_edges)
print("counts:", hist)
print("bin edges:", bin_edges)
counts: [  0   0   0   3   2  22  50  81 165 174 188 148 101  45  12   8   1   0
   0   0]
bin edges: [-5.  -4.5 -4.  -3.5 -3.  -2.5 -2.  -1.5 -1.  -0.5  0.   0.5  1.   1.5
  2.   2.5  3.   3.5  4.   4.5  5. ]
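Finally, the entry point itself can be exercised on random data, feeding a histogram like the one above through the C++ KL search (the printed threshold will vary with the random input):

import numpy as np
from tvm.relay.quantize.kl_divergence import _find_scale_by_kl

data = np.random.randn(10000).astype("float32")
threshold = _find_scale_by_kl(data, quantized_dtype="int8")
print("KL-optimal threshold:", threshold)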