解读 EnvFunc

解读 EnvFunc

import tvm
import tvm.testing # ensures "attrs.TestAttrs" is registered (needed by tvm.ir.make_node further down)
# IPython/Jupyter `??` introspection: shows the signature and source of EnvFunc.
# This is not valid syntax in a plain .py module; it only runs inside IPython.
tvm.ir.EnvFunc??

Hide code cell output

Init signature: tvm.ir.EnvFunc(self, /, *args, **kwargs)
Source:        
@register_object("ir.EnvFunc")
class EnvFunc(Object):
    """Environment function.

    This is a global function object that can be serialized by its name.
    """

    def __call__(self, *args):
        """Call the underlying global function with positional ``args``.

        Delegates to ``_ffi_api.EnvFuncCall``. Only positional arguments
        are accepted; keyword arguments raise ``TypeError`` at the Python
        level because the signature is ``*args`` only.
        """
        return _ffi_api.EnvFuncCall(self, *args)  # type: ignore # pylint: disable=no-member

    @property
    def func(self):
        """The underlying global function wrapped by this EnvFunc.

        Looked up via ``_ffi_api.EnvFuncGetFunction``. The result can be
        called directly, bypassing the EnvFunc wrapper.
        """
        return _ffi_api.EnvFuncGetFunction(self)  # type: ignore # pylint: disable=no-member

    @staticmethod
    def get(name):
        """Get a static env function

        Parameters
        ----------
        name : str
            The name of the function.

        Returns
        -------
        EnvFunc
            Handle referring to the global function ``name``. Presumably the
            function must already be registered (e.g. via
            ``tvm.register_func``). TODO confirm the behavior for unknown names.
        """
        return _ffi_api.EnvFuncGet(name)  # type: ignore # pylint: disable=no-member
File:           /media/pc/data/lxw/ai/tvm/python/tvm/ir/base.py
Type:           type
Subclasses:     
# override=True keeps this notebook cell re-runnable. With the default
# override=False, TVM refuses to register "test.env_func" a second time and
# re-executing the cell fails with an "already registered" error.
@tvm.register_func("test.env_func", override=True)
def test(x):
    """Return ``x + 1``; registered as the global function "test.env_func"."""
    return x + 1
# Look up the same registered function two ways: as the raw global
# function (`f`) and wrapped in an EnvFunc (`x`), which records its name.
f = tvm.get_global_func("test.env_func")
x = tvm.ir.EnvFunc.get("test.env_func")
assert x.name == "test.env_func"
# EnvFunc serializes by name: after a JSON round trip the loaded object
# still refers to the registered function and can be called.
json_str = tvm.ir.save_json([x])
y = tvm.ir.load_json(json_str)[0]
assert y.name == x.name
assert y(1) == 2  # call through the EnvFunc wrapper
assert y.func(1) == 2  # call the underlying function directly

# An EnvFunc can be stored in a field of an attribute node
# ("attrs.TestAttrs" is made available by importing tvm.testing).
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4), func=y)
# NOTE(review): this check is disabled and the reason is not shown here.
# Confirm whether `name` is still readable on TestAttrs, then re-enable or delete it.
# assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10  # `axis` was not passed, so 10 is presumably its default; confirm
# The EnvFunc field survives serialization of the enclosing node.
x = tvm.ir.load_json(tvm.ir.save_json(x))
assert isinstance(x.func, tvm.ir.EnvFunc)
assert x.func(10) == 11
# Last expression of the cell, so Jupyter displays the three types.
type(y), type(f), type(y.func)
(tvm.ir.base.EnvFunc,
 tvm.runtime.packed_func.PackedFunc,
 tvm.runtime.packed_func.PackedFunc)
import tvm
import numpy as np
import pytest
from tvm import te
from tvm.ffi.access_path import AccessPath
from tvm.script import tir as T, ir as I


def consistent_equal(x, y, map_free_vars=False):
    """Check that structural equality and structural hashing agree on ``x`` and ``y``.

    ``x`` and ``y`` are compared with ``tvm.ir.structural_equal`` in both
    argument orders, and their ``tvm.ir.structural_hash`` values are compared.

    Parameters
    ----------
    x, y
        The objects to compare.
    map_free_vars : bool
        Forwarded unchanged to ``structural_equal`` and ``structural_hash``.

    Returns
    -------
    bool
        The result of ``structural_equal(x, y, map_free_vars)``.

    Raises
    ------
    ValueError
        If equality is not symmetric, or if the equality verdict disagrees
        with whether the two hashes are equal.
    """
    forward = tvm.ir.structural_equal(x, y, map_free_vars)
    backward = tvm.ir.structural_equal(y, x, map_free_vars)
    hash_x = tvm.ir.structural_hash(x, map_free_vars)
    hash_y = tvm.ir.structural_hash(y, map_free_vars)

    if forward != backward:
        raise ValueError(
            f"Non-commutative {x} vs {y}, sequal0={forward}, sequal1={backward}"
        )

    # Equal objects must hash equally. The converse can in principle fail on
    # a hash collision. That should be rare, and it does not happen for these
    # test cases.
    hashes_match = hash_x == hash_y
    if forward != hashes_match:
        raise ValueError(
            f"Inconsistent {x} vs {y}, sequal={forward}, xhash={hash_x}, yhash={hash_y}"
        )
    return forward
# override=True keeps this notebook cell re-runnable. With the default
# override=False, TVM refuses to register "test.sequal.env_func" a second
# time and re-executing the cell fails with an "already registered" error.
@tvm.register_func("test.sequal.env_func", override=True)
def test(x):
    """Return ``x + 1``; registered as the global function "test.sequal.env_func"."""
    return x + 1

# Two separate lookups of the same registered name must yield EnvFunc
# objects that are structurally equal (in both orders) with matching
# structural hashes; consistent_equal raises ValueError otherwise.
x = tvm.ir.EnvFunc.get("test.sequal.env_func")
y = tvm.ir.EnvFunc.get("test.sequal.env_func")
assert consistent_equal(y, x)