合法化 qdq

# 导入必要的模块和类
import tvm
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R, tir as T
import tvm.testing

测试从 float32 到 int8 的量化算子的合法化变换

@tvm.script.ir_module
class Quantize:
    @R.function
    def main(
        data: R.Tensor((2, 4), "float32"),
        scale: R.Tensor((2,), "float32"),
        zp: R.Tensor((2,), "int8"),
    ) -> R.Tensor((2, 4), "int8"):
        # Per-axis quantization along axis 0 to int8:
        # round(data / scale) + zp, clamped to the int8 range (see the lowered
        # TIR below, which clamps to [-128, 127]).
        out = R.quantize(data, scale, zp, axis=0, out_dtype="int8")
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Quantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def quantize(data: T.Buffer((T.int64(2), T.int64(4)), "float32"), scale: T.Buffer((T.int64(2),), "float32"), zp: T.Buffer((T.int64(2),), "int8"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("quantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1], scale[v_i0], zp[v_i0])
                T.writes(quantized[v_i0, v_i1])
                quantized[v_i0, v_i1] = T.Cast("int8", T.max(T.min(T.round(data[v_i0, v_i1] / scale[v_i0]) + T.Cast("float32", zp[v_i0]), T.float32(127.0)), T.float32(-128.0)))

    @R.function
    def main(data: R.Tensor((2, 4), dtype="float32"), scale: R.Tensor((2,), dtype="float32"), zp: R.Tensor((2,), dtype="int8")) -> R.Tensor((2, 4), dtype="int8"):
        cls = Module
        out = R.call_tir(cls.quantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="int8"))
        return out

测试从 float16 到 uint8 的量化算子的合法化变换

@tvm.script.ir_module
class Quantize:
    @R.function
    def main(
        data: R.Tensor((2, 4), "float16"),
        scale: R.Tensor((2,), "float16"),
        zp: R.Tensor((2,), "int8"),
    ) -> R.Tensor((2, 4), "uint8"):
        # Per-axis quantization along axis 0 to uint8; the lowered TIR below
        # clamps to the uint8 range [0, 255] using float16 arithmetic.
        out = R.quantize(data, scale, zp, axis=0, out_dtype="uint8")
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Quantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def quantize(data: T.Buffer((T.int64(2), T.int64(4)), "float16"), scale: T.Buffer((T.int64(2),), "float16"), zp: T.Buffer((T.int64(2),), "int8"), quantized: T.Buffer((T.int64(2), T.int64(4)), "uint8")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("quantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1], scale[v_i0], zp[v_i0])
                T.writes(quantized[v_i0, v_i1])
                quantized[v_i0, v_i1] = T.Cast("uint8", T.max(T.min(T.round(data[v_i0, v_i1] / scale[v_i0]) + T.Cast("float16", zp[v_i0]), T.float16(255.0)), T.float16(0.0)))

    @R.function
    def main(data: R.Tensor((2, 4), dtype="float16"), scale: R.Tensor((2,), dtype="float16"), zp: R.Tensor((2,), dtype="int8")) -> R.Tensor((2, 4), dtype="uint8"):
        cls = Module
        out = R.call_tir(cls.quantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="uint8"))
        return out

测试符号形状输入下,从 float32 到 int8 的量化算子的合法化变换

@tvm.script.ir_module
class Quantize:
    @R.function
    def main(
        data: R.Tensor((4, "n"), "float32"),
        scale: R.Tensor(("n",), "float32"),
        zp: R.Tensor(("n",), "int8"),
    ) -> R.Tensor((4, "n"), "int8"):
        # Quantize along the last axis (-1) to int8 with a symbolic dimension
        # "n"; legalization should produce a TIR PrimFunc with T.match_buffer
        # handles rather than static-shaped buffers.
        out = R.quantize(data, scale, zp, axis=-1, out_dtype="int8")
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Quantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def quantize(var_data: T.handle, var_scale: T.handle, var_zp: T.handle, var_quantized: T.handle):
        T.func_attr({"tir.noalias": True})
        n = T.int64()
        data = T.match_buffer(var_data, (T.int64(4), n))
        scale = T.match_buffer(var_scale, (n,))
        zp = T.match_buffer(var_zp, (n,), "int8")
        quantized = T.match_buffer(var_quantized, (T.int64(4), n), "int8")
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(4), n):
            with T.block("quantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1], scale[v_i1], zp[v_i1])
                T.writes(quantized[v_i0, v_i1])
                quantized[v_i0, v_i1] = T.Cast("int8", T.max(T.min(T.round(data[v_i0, v_i1] / scale[v_i1]) + T.Cast("float32", zp[v_i1]), T.float32(127.0)), T.float32(-128.0)))

    @R.function
    def main(data: R.Tensor((4, "n"), dtype="float32"), scale: R.Tensor(("n",), dtype="float32"), zp: R.Tensor(("n",), dtype="int8")) -> R.Tensor((4, "n"), dtype="int8"):
        n = T.int64()
        cls = Module
        out = R.call_tir(cls.quantize, (data, scale, zp), out_sinfo=R.Tensor((4, n), dtype="int8"))
        return out

测试使用标量参数(而非张量)时,从 float32 到 int8 的量化算子的合法化变换

@tvm.script.ir_module
class Quantize:
    @R.function
    def main(data: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "int8"):
        # Scalar constants for scale and zp; legalization folds them into the
        # TIR body (the output shows data * 0.5 + 1.0, i.e. the division by
        # scale=2.0 is folded to a multiply) instead of passing extra buffers.
        out = R.quantize(
            data, R.const(2.0, "float32"), R.const(1, "int8"), axis=-1, out_dtype="int8"
        )
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Quantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def quantize(data: T.Buffer((T.int64(2), T.int64(4)), "float32"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("quantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1])
                T.writes(quantized[v_i0, v_i1])
                quantized[v_i0, v_i1] = T.Cast("int8", T.max(T.min(T.round(data[v_i0, v_i1] * T.float32(0.5)) + T.float32(1.0), T.float32(127.0)), T.float32(-128.0)))

    @R.function
    def main(data: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="int8"):
        cls = Module
        out = R.call_tir(cls.quantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="int8"))
        return out

测试使用一维常量数组作为参数时,从 float32 到 int8 的量化算子的合法化变换

@tvm.script.ir_module
class Quantize:
    @R.function
    def main(data: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "int8"):
        # 1-D constant arrays for scale and zp; legalization keeps them as
        # constant tensor arguments (shown as metadata constants in the
        # legalized call_tir below) rather than folding them into the body.
        out = R.quantize(
            data,
            R.const([2.0, 1.0], "float32"),
            R.const([4, 5], "int8"),
            axis=0,
            out_dtype="int8",
        )
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Quantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def quantize(data: T.Buffer((T.int64(2), T.int64(4)), "float32"), B: T.Buffer((T.int64(2),), "float32"), C: T.Buffer((T.int64(2),), "int8"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("quantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1], B[v_i0], C[v_i0])
                T.writes(quantized[v_i0, v_i1])
                quantized[v_i0, v_i1] = T.Cast("int8", T.max(T.min(T.round(data[v_i0, v_i1] / B[v_i0]) + T.Cast("float32", C[v_i0]), T.float32(127.0)), T.float32(-128.0)))

    @R.function
    def main(data: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="int8"):
        cls = Module
        out = R.call_tir(cls.quantize, (data, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((2, 4), dtype="int8"))
        return out

# Metadata omitted. Use show_meta=True in script() method to show it.

测试使用标量参数时,从 float16 到 int8 的量化算子的合法化变换

@tvm.script.ir_module
class Quantize:
    @R.function
    def main(data: R.Tensor((2, 4), "float16")) -> R.Tensor((2, 4), "int8"):
        # Scalar constants with a float16 input; the legalized TIR below folds
        # them into the body and performs the clamp in float16.
        out = R.quantize(
            data, R.const(2.0, "float16"), R.const(1, "int8"), axis=-1, out_dtype="int8"
        )
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Quantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def quantize(data: T.Buffer((T.int64(2), T.int64(4)), "float16"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("quantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1])
                T.writes(quantized[v_i0, v_i1])
                quantized[v_i0, v_i1] = T.Cast("int8", T.max(T.min(T.round(data[v_i0, v_i1] * T.float16(0.5)) + T.float16(1.0), T.float16(127.0)), T.float16(-128.0)))

    @R.function
    def main(data: R.Tensor((2, 4), dtype="float16")) -> R.Tensor((2, 4), dtype="int8"):
        cls = Module
        out = R.call_tir(cls.quantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="int8"))
        return out

测试从 int8 到 float32 的反量化算子的合法化变换

@tvm.script.ir_module
class Dequantize:
    @R.function
    def main(
        data: R.Tensor((2, 4), "int8"),
        scale: R.Tensor((2,), "float32"),
        zp: R.Tensor((2,), "int8"),
    ) -> R.Tensor((2, 4), "float32"):
        # Per-axis dequantization along axis 0 to float32:
        # (int32(data) - int32(zp)) * scale, as shown in the lowered TIR below.
        out = R.dequantize(data, scale, zp, axis=0, out_dtype="float32")
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Dequantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def dequantize(data: T.Buffer((T.int64(2), T.int64(4)), "int8"), scale: T.Buffer((T.int64(2),), "float32"), zp: T.Buffer((T.int64(2),), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("dequantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1], zp[v_i0], scale[v_i0])
                T.writes(dequantized[v_i0, v_i1])
                dequantized[v_i0, v_i1] = T.Cast("float32", T.Cast("int32", data[v_i0, v_i1]) - T.Cast("int32", zp[v_i0])) * scale[v_i0]

    @R.function
    def main(data: R.Tensor((2, 4), dtype="int8"), scale: R.Tensor((2,), dtype="float32"), zp: R.Tensor((2,), dtype="int8")) -> R.Tensor((2, 4), dtype="float32"):
        cls = Module
        out = R.call_tir(cls.dequantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="float32"))
        return out

测试使用标量参数时,从 int8 到 float32 的反量化算子的合法化变换

@tvm.script.ir_module
class Dequantize:
    @R.function
    def main(data: R.Tensor((2, 4), "int8")) -> R.Tensor((2, 4), "float32"):
        # Scalar constants for scale and zp; legalization folds them directly
        # into the TIR body ((int32(data) - 1) * 2.0 in the output below).
        out = R.dequantize(
            data, R.const(2.0, "float32"), R.const(1, "int8"), axis=0, out_dtype="float32"
        )
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Dequantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def dequantize(data: T.Buffer((T.int64(2), T.int64(4)), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("dequantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1])
                T.writes(dequantized[v_i0, v_i1])
                dequantized[v_i0, v_i1] = T.Cast("float32", T.Cast("int32", data[v_i0, v_i1]) - 1) * T.float32(2.0)

    @R.function
    def main(data: R.Tensor((2, 4), dtype="int8")) -> R.Tensor((2, 4), dtype="float32"):
        cls = Module
        out = R.call_tir(cls.dequantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="float32"))
        return out

测试符号形状输入下,从 int8 到 float32 的反量化算子的合法化变换

@tvm.script.ir_module
class Dequantize:
    @R.function
    def main(
        data: R.Tensor((2, "n"), "int8"),
        scale: R.Tensor(("n",), "float32"),
        zp: R.Tensor(("n",), "int8"),
    ) -> R.Tensor((2, "n"), "float32"):
        # Dequantize along the last axis (-1) to float32 with a symbolic
        # dimension "n"; the legalized PrimFunc below uses T.match_buffer
        # handles to carry the symbolic shape.
        out = R.dequantize(data, scale, zp, axis=-1, out_dtype="float32")
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Dequantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def dequantize(var_data: T.handle, var_scale: T.handle, var_zp: T.handle, var_dequantized: T.handle):
        T.func_attr({"tir.noalias": True})
        n = T.int64()
        data = T.match_buffer(var_data, (T.int64(2), n), "int8")
        scale = T.match_buffer(var_scale, (n,))
        zp = T.match_buffer(var_zp, (n,), "int8")
        dequantized = T.match_buffer(var_dequantized, (T.int64(2), n))
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), n):
            with T.block("dequantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1], zp[v_i1], scale[v_i1])
                T.writes(dequantized[v_i0, v_i1])
                dequantized[v_i0, v_i1] = T.Cast("float32", T.Cast("int32", data[v_i0, v_i1]) - T.Cast("int32", zp[v_i1])) * scale[v_i1]

    @R.function
    def main(data: R.Tensor((2, "n"), dtype="int8"), scale: R.Tensor(("n",), dtype="float32"), zp: R.Tensor(("n",), dtype="int8")) -> R.Tensor((2, "n"), dtype="float32"):
        n = T.int64()
        cls = Module
        out = R.call_tir(cls.dequantize, (data, scale, zp), out_sinfo=R.Tensor((2, n), dtype="float32"))
        return out

测试从 int8 到 float16 的反量化算子的合法化变换

@tvm.script.ir_module
class Dequantize:
    @R.function
    def main(
        data: R.Tensor((2, 4), "int8"),
        scale: R.Tensor((2,), "float16"),
        zp: R.Tensor((2,), "int8"),
    ) -> R.Tensor((2, 4), "float16"):
        # Per-axis dequantization along axis 0 to float16; the lowered TIR
        # below computes in float32 and clamps to the finite float16 range
        # (+/-65504) before the final cast.
        out = R.dequantize(data, scale, zp, axis=0, out_dtype="float16")
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Dequantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def dequantize(data: T.Buffer((T.int64(2), T.int64(4)), "int8"), scale: T.Buffer((T.int64(2),), "float16"), zp: T.Buffer((T.int64(2),), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float16")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("dequantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1], zp[v_i0], scale[v_i0])
                T.writes(dequantized[v_i0, v_i1])
                dequantized[v_i0, v_i1] = T.Cast("float16", T.max(T.min(T.Cast("float32", T.Cast("int32", data[v_i0, v_i1]) - T.Cast("int32", zp[v_i0])) * T.Cast("float32", scale[v_i0]), T.float32(65504.0)), T.float32(-65504.0)))

    @R.function
    def main(data: R.Tensor((2, 4), dtype="int8"), scale: R.Tensor((2,), dtype="float16"), zp: R.Tensor((2,), dtype="int8")) -> R.Tensor((2, 4), dtype="float16"):
        cls = Module
        out = R.call_tir(cls.dequantize, (data, scale, zp), out_sinfo=R.Tensor((2, 4), dtype="float16"))
        return out

测试使用标量参数时,从 int8 到 float16 的反量化算子的合法化变换

@tvm.script.ir_module
class Dequantize:
    @R.function
    def main(data: R.Tensor((2, 4), "int8")) -> R.Tensor((2, 4), "float16"):
        # Scalar constants with a float16 output; legalization folds the
        # constants into the body and clamps to the finite float16 range
        # before casting, as shown in the output below.
        out = R.dequantize(
            data, R.const(2.0, "float16"), R.const(1, "int8"), axis=0, out_dtype="float16"
        )
        return out
# Apply the LegalizeOps pass and print the lowered module to verify the result.
mod = LegalizeOps()(Dequantize)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def dequantize(data: T.Buffer((T.int64(2), T.int64(4)), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float16")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("dequantized"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(data[v_i0, v_i1])
                T.writes(dequantized[v_i0, v_i1])
                dequantized[v_i0, v_i1] = T.Cast("float16", T.max(T.min(T.Cast("float32", T.Cast("int32", data[v_i0, v_i1]) - 1) * T.float32(2.0), T.float32(65504.0)), T.float32(-65504.0)))

    @R.function
    def main(data: R.Tensor((2, 4), dtype="int8")) -> R.Tensor((2, 4), dtype="float16"):
        cls = Module
        out = R.call_tir(cls.dequantize, (data,), out_sinfo=R.Tensor((2, 4), dtype="float16"))
        return out