# 量化 LayerNorm (Quantized LayerNorm)
import torch
from torch.export import export
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import from_exported_program
class LayerNorm(torch.nn.Module):
    """Minimal LayerNorm module mirroring ``torch.nn.LayerNorm``.

    Normalizes the trailing ``shape`` dimensions of the input, then applies a
    learnable elementwise affine transform (weight initialized to ones, bias
    to zeros), exactly as ``torch.nn.functional.layer_norm`` does.

    Parameters
    ----------
    shape : tuple[int, ...]
        The trailing dimensions to normalize over.
    eps : float, optional
        Numerical-stability constant added to the variance. Defaults to
        ``1e-5``, the value previously hard-coded in ``forward``.
    """

    def __init__(self, shape, eps=1e-5):
        super().__init__()
        self.shape = shape
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(shape))
        self.bias = torch.nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        # Normalize over the last len(self.shape) dims, then scale and shift.
        return torch.nn.functional.layer_norm(
            x, self.shape, self.weight, self.bias, self.eps
        )
# Trace the eager model with torch.export, translate the resulting
# ExportedProgram into a TVM Relax IRModule, and print its TVMScript.
# The input example fixes the graph's input shape to (1, 3, 10, 10).
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
model = LayerNorm((10, 10))
exported_program = export(model, args=example_args,)
mod = from_exported_program(exported_program)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
# NOTE(review): this block is the TVMScript printed by mod.show() above,
# pasted verbatim. The `metadata["relax.expr.Constant"][...]` references are
# the module's embedded weight/bias constants, kept outside the printed
# script (see the "Metadata omitted" note below) — so this class documents
# the IR but is not directly executable as-is.
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
        with R.dataflow():
            # Single fused layer_norm over the last two axes (the (10, 10)
            # normalized shape), epsilon matching the PyTorch module.
            lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(x, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], axes=[-2, -1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
import numpy as np
import tvm
from tvm import relax
# Compile the Relax module for the host CPU ("llvm" target) and create a
# VirtualMachine to execute the compiled artifact on the CPU device.
exe = tvm.compile(mod, "llvm")
vm = relax.VirtualMachine(exe, tvm.cpu())
# Per-(name, phase) call counters and per-name captured return values,
# populated by the VM instrumentation hook below.
hit_count = {}
ret_vals = {}


def instrument(func, name, before_run, ret_val, *args):
    """VM instrumentation hook: count packed-function calls and log results.

    The Relax VM invokes this hook twice around every packed-function call:
    once before the call (``before_run=True``, ``ret_val`` is ``None``) and
    once after (``before_run=False``, ``ret_val`` holds the return value).

    Side effects: increments ``hit_count[(name, before_run)]``, prints the
    function name on every invocation (plus ``ret_val``/argc for the
    ``layer_norm`` kernel), and appends post-run return values to
    ``ret_vals[name]``.

    Parameters
    ----------
    func : callable
        The packed function being invoked.
    name : str
        The registered name of the packed function.
    before_run : bool
        True for the pre-call invocation, False for the post-call one.
    ret_val
        ``None`` before the call; the call's return value after it.
    *args
        The arguments passed to the packed function.
    """
    key = (name, before_run)
    hit_count[key] = hit_count.get(key, 0) + 1
    assert callable(func)
    if before_run:
        assert ret_val is None
    if name == "layer_norm":
        print(ret_val, len(args))
    print(name)
    if not before_run:
        # Append in place instead of rebinding a freshly concatenated list.
        ret_vals.setdefault(name, []).append(ret_val)
# Random input matching the exported input shape; attach the instrumentation
# hook so every packed-function call inside `main` is traced, then execute.
data_np = np.random.normal(size=(1, 3, 10, 10)).astype("float32")
vm.set_instrument(instrument)
output = vm["main"](tvm.nd.array(data_np))
vm.builtin.check_tensor_info
vm.builtin.check_tensor_info
vm.builtin.match_shape
vm.builtin.match_shape
vm.builtin.alloc_storage
vm.builtin.alloc_storage
vm.builtin.alloc_tensor
vm.builtin.alloc_tensor
vm.builtin.null_value
vm.builtin.null_value
None 4
layer_norm
None 4
layer_norm
vm.builtin.make_tuple
vm.builtin.make_tuple
vm.builtin.null_value
vm.builtin.null_value
/tmp/ipykernel_1947156/487285595.py:24: UserWarning: Returning type `vm.Storage` which is not registered via register_object, fallback to Object
output = vm["main"](tvm.nd.array(data_np))
ret_vals.keys()
dict_keys(['vm.builtin.check_tensor_info', 'vm.builtin.match_shape', 'vm.builtin.alloc_storage', 'vm.builtin.alloc_tensor', 'vm.builtin.null_value', 'layer_norm', 'vm.builtin.make_tuple'])
[len(v) for k, v in ret_vals.items()]
[1, 1, 1, 1, 2, 1, 1]
exe
<tvm.relax.vm_build.VMExecutable at 0x7ff42813ec60>
print(exe.as_python())
ib = rx.Builder()
with ib.function("main", num_inputs=1):
ib.emit_call("vm.builtin.check_tensor_info", args=[ib.r(0), ib.imm(4), ib.c(0), ib.c(1)])
ib.emit_call("vm.builtin.match_shape", args=[ib.r(0), ib.r(18014398509481984), ib.imm(4), ib.imm(0), ib.imm(1), ib.imm(0), ib.imm(3), ib.imm(0), ib.imm(10), ib.imm(0), ib.imm(10), ib.c(1)])
ib.emit_call("vm.builtin.alloc_storage", args=[ib.r(vm), ib.c(2), ib.imm(0), ib.c(3), ib.c(4)], dst=ib.r(1))
ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), ib.imm(0), ib.c(5), ib.c(0)], dst=ib.r(2))
ib.emit_call("vm.builtin.null_value", args=[], dst=ib.r(1))
ib.emit_call("layer_norm", args=[ib.r(0), ib.c(6), ib.c(7), ib.r(2)])
ib.emit_call("vm.builtin.make_tuple", args=[ib.r(2)], dst=ib.r(3))
ib.emit_call("vm.builtin.null_value", args=[], dst=ib.r(2))
ib.emit_ret(ib.r(3))