合法化 qdq#
# 导入必要的模块和类
import tvm
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R, tir as T
import tvm.testing
测试从 float32 到 int8 的量化算子的合法化变换#
@tvm.script.ir_module
class Quantize:
@R.function
def main(
data: R.Tensor((2, 4), "float32"),
scale: R.Tensor((2,), "float32"),
zp: R.Tensor((2,), "int8"),
) -> R.Tensor((2, 4), "int8"):
# 调用量化操作,沿轴0量化到int8
out = R.quantize(data, scale, zp, axis=0, out_dtype="int8")
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Quantize)
mod.show()
测试从 float16 到 uint8 的量化算子的合法化变换#
@tvm.script.ir_module
class Quantize:
@R.function
def main(
data: R.Tensor((2, 4), "float16"),
scale: R.Tensor((2,), "float16"),
zp: R.Tensor((2,), "int8"),
) -> R.Tensor((2, 4), "uint8"):
# 调用量化操作,沿轴0量化到uint8
out = R.quantize(data, scale, zp, axis=0, out_dtype="uint8")
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Quantize)
mod.show()
测试符号形状输入下,从float32到int8的量化算子的合法化变换#
@tvm.script.ir_module
class Quantize:
@R.function
def main(
data: R.Tensor((4, "n"), "float32"),
scale: R.Tensor(("n",), "float32"),
zp: R.Tensor(("n",), "int8"),
) -> R.Tensor((4, "n"), "int8"):
# 调用量化操作,沿最后一个轴(-1)量化到int8
out = R.quantize(data, scale, zp, axis=-1, out_dtype="int8")
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Quantize)
mod.show()
测试使用标量参数(而非张量)时,从float32到int8的量化算子的合法化变换#
@tvm.script.ir_module
class Quantize:
@R.function
def main(data: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "int8"):
# 使用标量值作为scale和zp参数
out = R.quantize(
data, R.const(2.0, "float32"), R.const(1, "int8"), axis=-1, out_dtype="int8"
)
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Quantize)
mod.show()
测试使用一维常量数组作为参数时,从float32到int8的量化算子的合法化变换#
@tvm.script.ir_module
class Quantize:
@R.function
def main(data: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "int8"):
# 使用一维常量数组作为scale和zp参数
out = R.quantize(
data,
R.const([2.0, 1.0], "float32"),
R.const([4, 5], "int8"),
axis=0,
out_dtype="int8",
)
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Quantize)
mod.show()
测试使用标量参数时,从float16到int8的量化算子的合法化变换#
@tvm.script.ir_module
class Quantize:
@R.function
def main(data: R.Tensor((2, 4), "float16")) -> R.Tensor((2, 4), "int8"):
# 使用标量值作为scale和zp参数
out = R.quantize(
data, R.const(2.0, "float16"), R.const(1, "int8"), axis=-1, out_dtype="int8"
)
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Quantize)
mod.show()
测试从int8到float32的反量化算子的合法化变换#
@tvm.script.ir_module
class Dequantize:
@R.function
def main(
data: R.Tensor((2, 4), "int8"),
scale: R.Tensor((2,), "float32"),
zp: R.Tensor((2,), "int8"),
) -> R.Tensor((2, 4), "float32"):
# 调用反量化算子,沿轴0反量化到float32
out = R.dequantize(data, scale, zp, axis=0, out_dtype="float32")
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Dequantize)
mod.show()
测试使用标量参数时,从int8到float32的反量化算子的合法化变换#
@tvm.script.ir_module
class Dequantize:
@R.function
def main(data: R.Tensor((2, 4), "int8")) -> R.Tensor((2, 4), "float32"):
# 使用标量值作为scale和zp参数
out = R.dequantize(
data, R.const(2.0, "float32"), R.const(1, "int8"), axis=0, out_dtype="float32"
)
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Dequantize)
mod.show()
测试符号形状输入下,从int8到float32的反量化算子的合法化变换#
@tvm.script.ir_module
class Dequantize:
@R.function
def main(
data: R.Tensor((2, "n"), "int8"),
scale: R.Tensor(("n",), "float32"),
zp: R.Tensor(("n",), "int8"),
) -> R.Tensor((2, "n"), "float32"):
# 调用反量化操作,沿最后一个轴(-1)反量化到float32
out = R.dequantize(data, scale, zp, axis=-1, out_dtype="float32")
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Dequantize)
mod.show()
测试从int8到float16的反量化算子的合法化变换#
@tvm.script.ir_module
class Dequantize:
@R.function
def main(
data: R.Tensor((2, 4), "int8"),
scale: R.Tensor((2,), "float16"),
zp: R.Tensor((2,), "int8"),
) -> R.Tensor((2, 4), "float16"):
# 调用反量化操作,沿轴0反量化到float16
out = R.dequantize(data, scale, zp, axis=0, out_dtype="float16")
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Dequantize)
mod.show()
测试使用标量参数时,从int8到float16的反量化算子的合法化变换#
@tvm.script.ir_module
class Dequantize:
@R.function
def main(data: R.Tensor((2, 4), "int8")) -> R.Tensor((2, 4), "float16"):
# 使用标量值作为scale和zp参数
out = R.dequantize(
data, R.const(2.0, "float16"), R.const(1, "int8"), axis=0, out_dtype="float16"
)
return out
# 应用LegalizeOps转换并验证结果是否符合预期
mod = LegalizeOps()(Dequantize)
mod.show()