DPL Pattern Rewriting#
import tvm
from tvm import relax as rx
from tvm import tir
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *
A Simple Example#
The original expression is:
\[\begin{aligned}
x_2 &= x + x\\
x_4 &= x_2 + x_2
\end{aligned}\]
@R.function
def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
with R.dataflow():
x2 = R.add(x, x)
x4 = R.add(x2, x2)
R.output(x4)
return x4
main.show()
# from tvm.script import relax as R
@R.function
def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
with R.dataflow():
x2: R.Tensor((16, 16), dtype="float32") = R.add(x, x)
x4: R.Tensor((16, 16), dtype="float32") = R.add(x2, x2)
R.output(x4)
return x4
Construct the pattern:
x = wildcard()
pattern = is_op("relax.add")(x, x)
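Before rewriting, it can be handy to confirm that the pattern really matches a binding inside main; a minimal sketch, assuming DFPattern.match accepts an expression and an optional variable-to-value map as in TVM's DPL tests (this is also what get_var2val is imported for):
var2val = get_var2val(main)  # map each dataflow variable in `main` to its bound expression
add_call = main.body.blocks[0].bindings[0].value  # the binding x2 = R.add(x, x)
# The pattern should match that call; var2val lets the matcher see through variables.
assert pattern.match(add_call, var2val=var2val)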
Rewrite the matched calls:
def rewriter(_, matchings):
return R.multiply(matchings[x], R.const(2, "float32"))
rewritten = rewrite_call(pattern, rewriter, main)
rewritten.show()
# from tvm.script import relax as R
@R.function
def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
with R.dataflow():
x2: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2.0, "float32"))
x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x2, R.const(2.0, "float32"))
R.output(x4)
return x4
Both call sites are rewritten here: \(x + x\) becomes \(x \times 2\) for x2 and x4 alike. Going one step further, the whole expression can be simplified in a single rewrite:
add1 = is_op("relax.add")(x, x)
pattern = is_op("relax.add")(add1, add1)
def rewriter(_, matchings):
return R.multiply(matchings[x], R.const(4, "float32"))
rewritten = rewrite_call(pattern, rewriter, main)
rewritten.show()
# from tvm.script import relax as R
@R.function
def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
with R.dataflow():
x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4.0, "float32"))
R.output(x4)
return x4
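As a sanity check, the result can be compared against a hand-written expected function. The sketch below is for illustration only; the struct-info annotations are inferred by the parser, and the global_symbol attribute that @R.function attaches is aligned before comparing:
@R.function
def expected(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
    with R.dataflow():
        x4 = R.multiply(x, R.const(4.0, "float32"))
        R.output(x4)
    return x4

# The two functions were declared under different names, so align the
# global_symbol attribute before checking structural equality.
tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main"))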
To leave the function unchanged, have the rewriter return the original call node as-is:
def rewriter(orig, _):
return orig
rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, main)
Rewriting an Attention Module#
@R.function
def main(
Q: R.Tensor((2, 4096, 8, 40), "float32"),
K: R.Tensor((2, 4096, 8, 40), "float32"),
V: R.Tensor((2, 4096, 8, 40), "float32"),
) -> R.Tensor((2, 4096, 8, 40), "float32"):
with R.dataflow():
lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
lv3_1 = R.matmul(lv59, lv62_transposed)
lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
lv69 = R.nn.softmax(lv68, axis=-1)
lv_3 = R.matmul(lv69, lv65)
lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
R.output(lv72)
return lv72
main.show()
# from tvm.script import relax as R
@R.function
def main(Q: R.Tensor((2, 4096, 8, 40), dtype="float32"), K: R.Tensor((2, 4096, 8, 40), dtype="float32"), V: R.Tensor((2, 4096, 8, 40), dtype="float32")) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
with R.dataflow():
lv58: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.permute_dims(Q, axes=[0, 2, 1, 3])
lv59: R.Tensor((16, 4096, 40), dtype="float32") = R.reshape(lv58, R.shape([16, 4096, 40]))
lv61: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.permute_dims(K, axes=[0, 2, 1, 3])
lv62: R.Tensor((16, 4096, 40), dtype="float32") = R.reshape(lv61, R.shape([16, 4096, 40]))
lv64: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.permute_dims(V, axes=[0, 2, 1, 3])
lv65: R.Tensor((16, 4096, 40), dtype="float32") = R.reshape(lv64, R.shape([16, 4096, 40]))
lv62_transposed: R.Tensor((16, 40, 4096), dtype="float32") = R.permute_dims(lv62, axes=[0, 2, 1])
lv3_1: R.Tensor((16, 4096, 4096), dtype="float32") = R.matmul(lv59, lv62_transposed, out_dtype="void")
lv68: R.Tensor((16, 4096, 4096), dtype="float32") = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
lv69: R.Tensor((16, 4096, 4096), dtype="float32") = R.nn.softmax(lv68, axis=-1)
lv_3: R.Tensor((16, 4096, 40), dtype="float32") = R.matmul(lv69, lv65, out_dtype="void")
lv71: R.Tensor((2, 8, 4096, 40), dtype="float32") = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.permute_dims(lv71, axes=[0, 2, 1, 3])
R.output(lv72)
return lv72
Construct the pattern. Note that in the pattern below, V_3D feeds the transposed operand of the first matmul while K_3D feeds the second matmul, so the pattern variable V ends up bound to the function's K tensor and K to its V tensor; this is why the rewritten call below is printed as R.nn.attention(Q, V, K):
def BSNH_to_BSH(tensor):
return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard())
def BSH_to_BSNH(tensor):
return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard()))
Q = wildcard()
K = wildcard()
V = wildcard()
Q_3D = BSNH_to_BSH(Q)
V_3D = BSNH_to_BSH(V)
K_3D = BSNH_to_BSH(K)
matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
multiply = is_op("relax.multiply")(matmul1, is_const())
softmax = is_op("relax.nn.softmax")(multiply)
matmul2 = is_op("relax.matmul")(softmax, K_3D)
pattern = BSH_to_BSNH(matmul2)
def rewriter(_, matchings):
return R.nn.attention(matchings[Q], matchings[K], matchings[V])
rewritten = rewrite_call(pattern, rewriter, main)
rewritten.show()
# from tvm.script import relax as R
@R.function
def main(Q: R.Tensor((2, 4096, 8, 40), dtype="float32"), K: R.Tensor((2, 4096, 8, 40), dtype="float32"), V: R.Tensor((2, 4096, 8, 40), dtype="float32")) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
with R.dataflow():
lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K, scale=None, causal_mask=None, window_size=None)
R.output(lv72)
return lv72
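The scale printed above is None because the pattern discards the constant applied before the softmax. A possible extension is sketched below, under the assumption that R.nn.attention accepts the scale as a tir.FloatImm; it reuses the sub-patterns defined above and only swaps in a captured constant:
# Variant of the pattern above that also captures the scaling constant.
scale_const = is_const()
multiply_scaled = is_op("relax.multiply")(matmul1, scale_const)
softmax_scaled = is_op("relax.nn.softmax")(multiply_scaled)
matmul2_scaled = is_op("relax.matmul")(softmax_scaled, K_3D)
pattern_scaled = BSH_to_BSNH(matmul2_scaled)

def rewriter_scaled(_, matchings):
    # Read the scalar value out of the matched constant and forward it as the scale.
    scale = matchings[scale_const].data.numpy().item()
    return R.nn.attention(
        matchings[Q],
        matchings[K],
        matchings[V],
        scale=tir.FloatImm("float32", scale),
    )

rewritten_scaled = rewrite_call(pattern_scaled, rewriter_scaled, main)
rewritten_scaled.show()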
Testing Commutative Pattern Matching#
The pattern below places the constant as the second operand, yet it still matches R.add(R.const(1.0), y), where the constant comes first, because the matcher handles the addition commutatively. The rewriter then swaps the constant 1.0 for 2.0; since it returns the expression unchanged whenever the constant is not 1.0, the rewrite terminates after a single replacement.
@R.function(private=True)
def before(
x: R.Tensor((1024,)),
):
with R.dataflow():
y = R.add(x, x)
out = R.add(R.const(1.0), y)
R.output(out)
return out
before.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x: R.Tensor((1024,))) -> R.Tensor((1024,)):
with R.dataflow():
y: R.Tensor((1024,)) = R.add(x, x)
out: R.Tensor((1024,)) = R.add(R.const(1.0, "float32"), y)
R.output(out)
return out
pattern_add = is_op("relax.add")
pattern_mul = is_op("relax.multiply")
pattern_op = pattern_add | pattern_mul
pattern_arg = wildcard()
pattern_const = is_const()
pattern = pattern_op(pattern_arg, pattern_const)
def rewriter(expr, matches):
op = matches[pattern_op]
arg = matches[pattern_arg]
const = matches[pattern_const].data.numpy()
if const.shape == tuple() and const[()] == 1.0:
return rx.Call(op, [arg, rx.const(2.0)])
else:
return expr
after = rewrite_call(pattern, rewriter, before)
after.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x: R.Tensor((1024,))) -> R.Tensor((1024,)):
with R.dataflow():
y: R.Tensor((1024,)) = R.add(x, x)
out: R.Tensor((1024,)) = R.add(y, R.const(2.0, "float32"))
R.output(out)
return out
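Because pattern_op is the union of add and multiply, the same pattern and rewriter also apply to a multiplication by the constant 1.0. A small sketch, where the before_mul function is made up for illustration:
@R.function(private=True)
def before_mul(x: R.Tensor((1024,))):
    with R.dataflow():
        out = R.multiply(x, R.const(1.0))
        R.output(out)
    return out

# The Or-pattern matches relax.multiply as well, so the constant is
# rewritten from 1.0 to 2.0 just like in the addition case.
after_mul = rewrite_call(pattern, rewriter, before_mul)
after_mul.show()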
Testing Repeated Pattern Matching#
rewrite_call should keep iterating until it converges. The pattern below distributes a multiplication by a constant over an addition; because the result of one rewrite can itself match the pattern again, the rewrite is applied repeatedly until no match remains:
@R.function(private=True)
def before(
x: R.Tensor((1024,)),
y: R.Tensor((1024,)),
z: R.Tensor((1024,)),
):
with R.dataflow():
a = R.add(x, y)
b = R.add(a, z)
out = R.multiply(b, R.const(5.0))
R.output(out)
return out
before.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x: R.Tensor((1024,)), y: R.Tensor((1024,)), z: R.Tensor((1024,))) -> R.Tensor((1024,)):
with R.dataflow():
a: R.Tensor((1024,)) = R.add(x, y)
b: R.Tensor((1024,)) = R.add(a, z)
out: R.Tensor((1024,)) = R.multiply(b, R.const(5.0, "float32"))
R.output(out)
return out
pattern_add_lhs = wildcard()
pattern_add_rhs = wildcard()
pattern_add = is_op("relax.add")(pattern_add_lhs, pattern_add_rhs)
mul_const = is_const()
pattern_mul = is_op("relax.multiply")(pattern_add, mul_const)
pattern = pattern_mul
def rewriter(_expr, matches):
const = matches[mul_const]
return (matches[pattern_add_lhs] * const) + (matches[pattern_add_rhs] * const)
after = rewrite_call(pattern, rewriter, before)
after.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x: R.Tensor((1024,)), y: R.Tensor((1024,)), z: R.Tensor((1024,))) -> R.Tensor((1024,)):
with R.dataflow():
lv3: R.Tensor((1024,)) = R.multiply(x, R.const(5.0, "float32"))
lv4: R.Tensor((1024,)) = R.multiply(y, R.const(5.0, "float32"))
lv1: R.Tensor((1024,)) = R.add(lv3, lv4)
lv2: R.Tensor((1024,)) = R.multiply(z, R.const(5.0, "float32"))
out: R.Tensor((1024,)) = R.add(lv1, lv2)
R.output(out)
return out
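Once no binding matches the pattern any more, applying the same rewrite again should leave the function untouched; a quick check, assuming rewrite_call returns the function unchanged when nothing matches:
# A second pass finds no multiply-of-add-by-constant left, so it is a no-op.
again = rewrite_call(pattern, rewriter, after)
tvm.ir.assert_structural_equal(again, after)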