# Match subgraphs
import numpy as np
import tvm
from tvm import relax
from tvm.relax.backend.cuda.cublas import partition_for_cublas
from tvm.relax.backend.cuda.cutlass import partition_for_cutlass
from tvm.relax.dpl.pattern import (
is_op,
is_tuple_get_item,
make_fused_bias_activation_pattern,
wildcard,
)
from tvm.relax.transform import PatternCheckContext
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
# Relax function: int32 matmul + bias add, then clip to the int8 value range.
# This is the graph the fusion pattern below is expected to match.
@R.function
def func(
    x: R.Tensor((32, 8), dtype="int32"),
    y: R.Tensor((8, 8), dtype="int32"),
    bias: R.Tensor((8,), dtype="int32"),
) -> R.Tensor((32, 8), dtype="int32"):
    # Export this function under the symbol "main".
    R.func_attr({"global_symbol": "main"})
    with R.dataflow():
        lv0 = R.matmul(x, y, out_dtype="int32")
        # bias has shape (8,) and broadcasts across the 32 rows of lv0
        lv1 = R.add(lv0, bias)
        # clamp to [-128, 127], i.e. the int8 range
        lv2 = R.clip(lv1, -128, 127)
        R.output(lv2)
    return lv2
# Wrap the Relax function into an IRModule keyed by its global symbol,
# then print the resulting TVMScript.
main_func_map = {"main": func}
mod = tvm.IRModule(main_func_map)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
# Reference listing: this is the TVMScript that `mod.show()` prints for the
# module built above (matmul -> add -> clip, all int32).
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((32, 8), dtype="int32"), y: R.Tensor((8, 8), dtype="int32"), bias: R.Tensor((8,), dtype="int32")) -> R.Tensor((32, 8), dtype="int32"):
        with R.dataflow():
            lv0: R.Tensor((32, 8), dtype="int32") = R.matmul(x, y, out_dtype="int32")
            lv1: R.Tensor((32, 8), dtype="int32") = R.add(lv0, bias)
            lv2: R.Tensor((32, 8), dtype="int32") = R.clip(lv1, R.prim_value(-128), R.prim_value(127))
            R.output(lv2)
        return lv2
# Build a dataflow pattern that matches matmul -> add, optionally followed
# by a clip, then partition the module with FuseOpsByPattern and verify
# that the fully fused function was created.
matmul_pat = is_op("relax.matmul")(wildcard(), wildcard())
matmul_add_pat = is_op("relax.add")(matmul_pat, wildcard())
clipped_pat = is_op("relax.clip")(matmul_add_pat, wildcard(), wildcard())
fused_pattern = matmul_add_pat | clipped_pat
partitioned = relax.transform.FuseOpsByPattern([("orclip", fused_pattern)])(mod)
# Collect the name hints of every global function in the partitioned module.
partition_names = {gvar.name_hint for gvar, _ in partitioned.functions.items()}
assert "fused_relax_matmul_relax_add_relax_clip" in partition_names