Writing a Customized Pass
import tvm
from tvm import te
n = tvm.tir.const(128, "int32")
a = te.placeholder((n,), name="a")
b = te.placeholder((n,), name="b")
c = te.compute((n,), lambda i: a[i] + b[i], name="c")
sch = te.create_schedule(c.op)
ir = tvm.lower(sch, [a, b, c])
ir.show()
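# The lowered TIR contains a single serial loop of extent 128. Before
# customizing the lowering, it is worth confirming the kernel is correct.
# A minimal sketch, assuming an "llvm" target and numpy for reference data:
import numpy as np

func = tvm.build(sch, [a, b, c], "llvm")
dev = tvm.cpu(0)
a_np = np.random.uniform(size=128).astype("float32")
b_np = np.random.uniform(size=128).astype("float32")
a_nd = tvm.nd.array(a_np, dev)
b_nd = tvm.nd.array(b_np, dev)
c_nd = tvm.nd.array(np.zeros(128, dtype="float32"), dev)
func(a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_nd.numpy(), a_np + b_np, rtol=1e-5)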
loops = []
def find_width8(op):
    """Find all the `tir.For` nodes whose extent is divisible by 8."""
    if isinstance(op, tvm.tir.For):
        if isinstance(op.extent, tvm.tir.IntImm):
            if op.extent.value % 8 == 0:
                loops.append(op)
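# `find_width8` is a plain visitor callback, so it can be exercised on its
# own with `post_order_visit`, which walks a statement bottom-up. A quick
# sketch using the module lowered above (its entry function is "main"):
tvm.tir.stmt_functor.post_order_visit(ir["main"].body, find_width8)
print(len(loops))  # the single extent-128 loop qualifies
loops.clear()  # reset the global before the real pass runs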
def vectorize8(op):
    """Split and vectorize the loops found by `find_width8`."""
    if op in loops:
        extent = op.extent.value
        name = op.loop_var.name
        lo, li = te.var(name + ".outer"), te.var(name + ".inner")
        body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
        body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)
        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)
        return body
    return None
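# The core of the rewrite is `substitute`, which replaces every use of a
# variable inside a statement. A standalone sketch with throwaway names:
i = te.var("demo_i")
stmt = tvm.tir.Evaluate(i)  # trivial statement that reads `i`
io, ii = te.var("demo_i.outer"), te.var("demo_i.inner")
print(tvm.tir.stmt_functor.substitute(stmt, {i: io * 8 + ii}))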
@tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
    global loops

    tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)

    if not loops:
        return f

    # The last list argument indicates the kinds of nodes to be transformed.
    # Thus, in this case only `For` nodes will call `vectorize8`.
    return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ["tir.For"]))
vectorize.info
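# `info` carries the pass metadata registered by the decorator:
print(vectorize.info.name, vectorize.info.opt_level)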
with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, vectorize)]}) as ctx:
    print(ctx)
    tvm.lower(sch, [a, b, c]).show()
loops
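# Each entry is the original `tir.For` node captured while lowering under
# the custom PassContext; its fields can be inspected directly:
for op in loops:
    print(op.loop_var, op.extent)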
@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func
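# A usage sketch (everything below is illustrative): build a throwaway
# PrimFunc, then let the pass swap it in for every function of a module.
px = te.placeholder((8,), name="px")
py = te.compute((8,), lambda i: px[i] + 1.0, name="py")
new_func = tvm.lower(te.create_schedule(py.op), [px, py])["main"]
fpass = TestReplaceFunc(new_func)  # constructor args go to __init__
print(fpass(tvm.lower(sch, [a, b, c])))  # every PrimFunc becomes new_func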
@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
assert function_pass.info.opt_level == 2
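# The invocation below needs a module `m`; lowering the te schedule from
# earlier is a minimal (illustrative) way to obtain one:
m = tvm.lower(sch, [a, b, c])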
# Given a module m, the optimization could be invoked as follows:
updated_mod = function_pass(m)
# Now the pass has been applied to every PrimFunc in the provided
# module m, and the updated module is returned.
import numpy as np
import tvm
from tvm import relay
def example():
    shape = (1, 64, 54, 54)
    c_data = np.empty(shape).astype("float32")
    c = relay.const(c_data)
    weight = relay.var("weight", shape=(64, 64, 3, 3))
    x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
    conv = relay.nn.conv2d(x, weight, kernel_size=(3, 3))
    y = relay.add(c, c)
    y = relay.multiply(y, relay.const(2, "float32"))
    y = relay.add(conv, y)
    z = relay.add(y, c)
    z1 = relay.add(y, c)
    z2 = relay.add(z, z1)
    return relay.Function([x, weight], z2)
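# With the function in hand, passes from the infra can be applied to it.
# For instance, FoldConstant collapses the `add(c, c)`/`multiply` chain
# into a single precomputed constant:
f = example()
mod = tvm.IRModule.from_expr(f)
mod = relay.transform.FoldConstant()(mod)
print(mod)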