DPL Pattern Rewriting#

import tvm
from tvm import relax as rx
from tvm.script import relax as R
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *

A Simple Example#

The original expression is:

\[\begin{split}\begin{aligned} x_2 &= x + x\\ x_4 &= x_2 + x_2 \end{aligned}\end{split}\]
@R.function
def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
    with R.dataflow():
        x2 = R.add(x, x)
        x4 = R.add(x2, x2)
        R.output(x4)
    return x4
main.show()
# from tvm.script import relax as R

@R.function
def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
    with R.dataflow():
        x2: R.Tensor((16, 16), dtype="float32") = R.add(x, x)
        x4: R.Tensor((16, 16), dtype="float32") = R.add(x2, x2)
        R.output(x4)
    return x4

Build the pattern. The same wildcard x is used for both arguments, so the pattern only matches additions whose two operands are the same expression:

x = wildcard()
pattern = is_op("relax.add")(x, x)
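
Before rewriting, the pattern can be tested directly. A minimal sketch, assuming DFPattern.match accepts the optional var2val mapping produced by get_var2val (imported above), which lets the matcher look through variable bindings:

var2val = get_var2val(main)
# The first binding in the dataflow block is `x2 = R.add(x, x)`.
add_call = main.body.blocks[0].bindings[0].value
assert pattern.match(add_call, var2val=var2val)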

Define the rewriter and apply it:

def rewriter(_, matchings):
    # `matchings` maps each pattern object to the sub-expression it matched;
    # the returned expression replaces the matched call.
    return R.multiply(matchings[x], R.const(2, "float32"))
rewritten = rewrite_call(pattern, rewriter, main)
rewritten.show()
# from tvm.script import relax as R

@R.function
def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
    with R.dataflow():
        x2: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2.0, "float32"))
        x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x2, R.const(2.0, "float32"))
        R.output(x4)
    return x4

This rewrites \(x + x\) as \(x \times 2\); note that rewrite_call replaces every match in the function, which is why both additions were rewritten. Going one step further, since \((x + x) + (x + x) = 4x\), the whole expression can be collapsed in a single rewrite:

add1 = is_op("relax.add")(x, x)
pattern = is_op("relax.add")(add1, add1)

def rewriter(_, matchings):
    return R.multiply(matchings[x], R.const(4, "float32"))

rewritten = rewrite_call(pattern, rewriter, main)
rewritten.show()
# from tvm.script import relax as R

@R.function
def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
    with R.dataflow():
        x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4.0, "float32"))
        R.output(x4)
    return x4

Without rewriting, the rewriter can return the original call node as-is; this is the fall-through for conditional rewrites (see the commutative example below):

def rewriter(orig, _):
    return orig

rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, main)

Rewriting an Attention Module#

@R.function
def main(
    Q: R.Tensor((2, 4096, 8, 40), "float32"),
    K: R.Tensor((2, 4096, 8, 40), "float32"),
    V: R.Tensor((2, 4096, 8, 40), "float32"),
) -> R.Tensor((2, 4096, 8, 40), "float32"):
    with R.dataflow():
        lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
        lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))

        lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
        lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))

        lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
        lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))

        lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
        lv3_1 = R.matmul(lv59, lv62_transposed)
        lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))  # ≈ 1/sqrt(40)
        lv69 = R.nn.softmax(lv68, axis=-1)
        lv_3 = R.matmul(lv69, lv65)

        lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
        lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
        R.output(lv72)

    return lv72
main.show()
# from tvm.script import relax as R

@R.function
def main(Q: R.Tensor((2, 4096, 8, 40), dtype="float32"), K: R.Tensor((2, 4096, 8, 40), dtype="float32"), V: R.Tensor((2, 4096, 8, 40), dtype="float32")) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
    with R.dataflow():
        lv58: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.permute_dims(Q, axes=[0, 2, 1, 3])
        lv59: R.Tensor((16, 4096, 40), dtype="float32") = R.reshape(lv58, R.shape([16, 4096, 40]))
        lv61: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.permute_dims(K, axes=[0, 2, 1, 3])
        lv62: R.Tensor((16, 4096, 40), dtype="float32") = R.reshape(lv61, R.shape([16, 4096, 40]))
        lv64: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.permute_dims(V, axes=[0, 2, 1, 3])
        lv65: R.Tensor((16, 4096, 40), dtype="float32") = R.reshape(lv64, R.shape([16, 4096, 40]))
        lv62_transposed: R.Tensor((16, 40, 4096), dtype="float32") = R.permute_dims(lv62, axes=[0, 2, 1])
        lv3_1: R.Tensor((16, 4096, 4096), dtype="float32") = R.matmul(lv59, lv62_transposed, out_dtype="void")
        lv68: R.Tensor((16, 4096, 4096), dtype="float32") = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
        lv69: R.Tensor((16, 4096, 4096), dtype="float32") = R.nn.softmax(lv68, axis=-1)
        lv_3: R.Tensor((16, 4096, 40), dtype="float32") = R.matmul(lv69, lv65, out_dtype="void")
        lv71: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
        lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.permute_dims(lv71, axes=[0, 2, 1, 3])
        R.output(lv72)
    return lv72
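
In standard notation, the block computes scaled dot-product attention over the 16 = 2 × 8 flattened heads of size \(d = 40\), which is where the constant above comes from:

\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\mathsf{T}}}{\sqrt{d}}\right) V, \qquad \frac{1}{\sqrt{40}} \approx 0.1581\]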

Build the pattern. BSNH_to_BSH matches the permute_dims + reshape pair that flattens the head axis into the batch axis ((B, S, N, H) → (B·N, S, H)), and BSH_to_BSNH matches the inverse:

def BSNH_to_BSH(tensor):
    return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard())

def BSH_to_BSNH(tensor):
    return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard()))

Q = wildcard()
K = wildcard()
V = wildcard()

Q_3D = BSNH_to_BSH(Q)
V_3D = BSNH_to_BSH(V)
K_3D = BSNH_to_BSH(K)

matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
multiply = is_op("relax.multiply")(matmul1, is_const())
softmax = is_op("relax.nn.softmax")(multiply)
matmul2 = is_op("relax.matmul")(softmax, K_3D)
pattern = BSH_to_BSNH(matmul2)
def rewriter(_, matchings):
    # Replace the whole matched subgraph with a single fused attention op.
    return R.nn.attention(matchings[Q], matchings[K], matchings[V])

rewritten = rewrite_call(pattern, rewriter, main)
rewritten.show()
# from tvm.script import relax as R

@R.function
def main(Q: R.Tensor((2, 4096, 8, 40), dtype="float32"), K: R.Tensor((2, 4096, 8, 40), dtype="float32"), V: R.Tensor((2, 4096, 8, 40), dtype="float32")) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
    with R.dataflow():
        lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K, scale=None, causal_mask=None, window_size=None)
        R.output(lv72)
    return lv72
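
Note the argument order in the rewritten call. The wildcard names are only labels: the pattern routes V_3D into the transposed operand of the first matmul (the position K occupies in the IR) and K_3D into the second matmul (the position V occupies), so matchings[V] binds to K, matchings[K] binds to V, and the rewritten call prints as R.nn.attention(Q, V, K).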

Testing Commutative Pattern Matching#

@R.function(private=True)
def before(
    x: R.Tensor((1024,)),
):
    with R.dataflow():
        y = R.add(x, x)
        out = R.add(R.const(1.0), y)
        R.output(out)
    return out
before.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x: R.Tensor((1024,))) -> R.Tensor((1024,)):
    with R.dataflow():
        y: R.Tensor((1024,)) = R.add(x, x)
        out: R.Tensor((1024,)) = R.add(R.const(1.0, "float32"), y)
        R.output(out)
    return out
pattern_add = is_op("relax.add")
pattern_mul = is_op("relax.multiply")
pattern_op = pattern_add | pattern_mul
pattern_arg = wildcard()
pattern_const = is_const()

pattern = pattern_op(pattern_arg, pattern_const)
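
In the IR above the constant is the first operand of R.add, while the pattern lists it second. The matcher treats relax.add and relax.multiply as commutative, so the pattern matches anyway. A minimal sanity check, assuming DFPattern.match and the tvm.relax expression builders used here behave as sketched:

y = rx.Var("y", rx.TensorStructInfo((1024,), "float32"))
# Constant on the left, mirroring the IR above; the (arg, const) pattern
# should still match thanks to commutative handling.
expr = rx.op.add(rx.const(1.0, "float32"), y)
assert pattern.match(expr)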

def rewriter(expr, matches):
    op = matches[pattern_op]    # the matched operator: relax.add or relax.multiply
    arg = matches[pattern_arg]  # the non-constant operand
    const = matches[pattern_const].data.numpy()
    # Rewrite only when the constant is the scalar 1.0; otherwise keep the
    # original expression unchanged.
    if const.shape == tuple() and const[()] == 1.0:
        return rx.Call(op, [arg, rx.const(2.0)])
    else:
        return expr

after = rewrite_call(pattern, rewriter, before)
after.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x: R.Tensor((1024,))) -> R.Tensor((1024,)):
    with R.dataflow():
        y: R.Tensor((1024,)) = R.add(x, x)
        out: R.Tensor((1024,)) = R.add(y, R.const(2.0, "float32"))
        R.output(out)
    return out

Testing Repeated Pattern Matching#

rewrite_call applies the rewrite repeatedly until the result converges: distributing the constant over the outer addition exposes a new multiply(add(x, y), const) that matches the pattern again:

@R.function(private=True)
def before(
    x: R.Tensor((1024,)),
    y: R.Tensor((1024,)),
    z: R.Tensor((1024,)),
):
    with R.dataflow():
        a = R.add(x, y)
        b = R.add(a, z)
        out = R.multiply(b, R.const(5.0))
        R.output(out)
    return out
before.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x: R.Tensor((1024,)), y: R.Tensor((1024,)), z: R.Tensor((1024,))) -> R.Tensor((1024,)):
    with R.dataflow():
        a: R.Tensor((1024,)) = R.add(x, y)
        b: R.Tensor((1024,)) = R.add(a, z)
        out: R.Tensor((1024,)) = R.multiply(b, R.const(5.0, "float32"))
        R.output(out)
    return out
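
The rewriter below distributes the constant across the addition:

\[(a + b) \cdot c = a \cdot c + b \cdot c\]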
pattern_add_lhs = wildcard()
pattern_add_rhs = wildcard()
pattern_add = is_op("relax.add")(pattern_add_lhs, pattern_add_rhs)

mul_const = is_const()
pattern_mul = is_op("relax.multiply")(pattern_add, mul_const)

pattern = pattern_mul

def rewriter(_expr, matches):
    const = matches[mul_const]
    # `*` and `+` on relax expressions are overloaded and emit the
    # corresponding relax.multiply / relax.add calls.
    return (matches[pattern_add_lhs] * const) + (matches[pattern_add_rhs] * const)

after = rewrite_call(pattern, rewriter, before)
after.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x: R.Tensor((1024,)), y: R.Tensor((1024,)), z: R.Tensor((1024,))) -> R.Tensor((1024,)):
    with R.dataflow():
        lv3: R.Tensor((1024,)) = R.multiply(x, R.const(5.0, "float32"))
        lv4: R.Tensor((1024,)) = R.multiply(y, R.const(5.0, "float32"))
        lv1: R.Tensor((1024,)) = R.add(lv3, lv4)
        lv2: R.Tensor((1024,)) = R.multiply(z, R.const(5.0, "float32"))
        out: R.Tensor((1024,)) = R.add(lv1, lv2)
        R.output(out)
    return out