解读 tvm/script/parser/core/diagnostics.py

目录

解读 tvm/script/parser/core/diagnostics.py

解析 Source

import set_env
import pytest
import inspect
import tvm.testing
from tvm.script.parser.core.diagnostics import Source
from tvm.script.parser.core import doc_core as doc
from tvm.script import tir as T
Source?
Init signature: Source(program: Union[str, tvm.script.parser.core.doc_core.AST])
Docstring:     
Source code class for TVMScript.

It is constructed by source code str or doc AST tree.

Parameters
----------
source_name : str
    The filename of the file where the source code locates.

start_line : int
    The first line number of the source code.

start_column : int
    The first column number of the first line of the source code.

source : str
    The source code str of source code.

full_source : str
    The complete source code of the file where the source code locates.
File:           /media/pc/data/lxw/ai/tvm/python/tvm/script/parser/core/diagnostics.py
Type:           type
Subclasses:
# TVMScript-style kernel used as the parsing fixture by the tests in this file.
# NOTE: document this function with `#` comments only. A docstring would add an
# extra statement to the function body and break the `len(func_body) == 4`
# check in `test_source_ast`; the local names A/B/C and i/j/k are asserted there too.
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Bind each of the three handles to a 128x128 buffer.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    # Single loop nest over a 128 x 128 x 128 iteration grid.
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            # Remap the loop variables to block axes; the "SSR" string presumably
            # marks vi/vj as spatial and vk as a reduction axis — confirm against TVM docs.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # B is indexed as [vj, vk], so this accumulates A times B-transposed into C.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


def test_source_base():
    """Check the basic attributes ``Source`` derives from the ``matmul`` function.

    Verifies the source file name, the start line/column, the function's own
    source text and the full source of the place where ``matmul`` is defined.

    Each attribute is asserted separately so a failure points at the exact
    attribute that is wrong instead of at one large boolean expression.
    """
    source = Source(matmul)
    assert source.source_name == inspect.getsourcefile(matmul)
    assert source.start_line is not None
    assert source.start_column == 0
    assert source.source == inspect.getsource(matmul)

    try:
        expected_full_source = inspect.getsource(inspect.getmodule(matmul))
    except TypeError:
        # In an interactive session (Jupyter/IPython) ``matmul`` is defined in
        # ``__main__``, which has no ``__file__``; ``inspect.getsource`` then
        # raises ``TypeError: <module '__main__'> is a built-in module``.
        # Fall back to the lines ``inspect`` can locate through the function
        # itself. NOTE(review): this presumably mirrors the fallback ``Source``
        # uses for ``full_source`` — confirm against diagnostics.py.
        lines, _ = inspect.findsource(matmul)
        expected_full_source = "".join(lines)
    assert source.full_source == expected_full_source
test_source_base()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 20
     12     source = Source(matmul)
     13     assert (
     14         source.source_name == inspect.getsourcefile(matmul)
     15         and source.start_line is not None
   (...)
     18         and source.full_source == inspect.getsource(inspect.getmodule(matmul))
     19     )
---> 20 test_source_base()

Cell In[4], line 18, in test_source_base()
     11 def test_source_base():
     12     source = Source(matmul)
     13     assert (
     14         source.source_name == inspect.getsourcefile(matmul)
     15         and source.start_line is not None
     16         and source.start_column == 0
     17         and source.source == inspect.getsource(matmul)
---> 18         and source.full_source == inspect.getsource(inspect.getmodule(matmul))
     19     )

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:1282, in getsource(object)
   1276 def getsource(object):
   1277     """Return the text of the source code for an object.
   1278 
   1279     The argument may be a module, class, method, function, traceback, frame,
   1280     or code object.  The source code is returned as a single string.  An
   1281     OSError is raised if the source code cannot be retrieved."""
-> 1282     lines, lnum = getsourcelines(object)
   1283     return ''.join(lines)

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:1264, in getsourcelines(object)
   1256 """Return a list of source lines and starting line number for an object.
   1257 
   1258 The argument may be a module, class, method, function, traceback, frame,
   (...)
   1261 original source file the first line of code was found.  An OSError is
   1262 raised if the source code cannot be retrieved."""
   1263 object = unwrap(object)
-> 1264 lines, lnum = findsource(object)
   1266 if istraceback(object):
   1267     object = object.tb_frame

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:1075, in findsource(object)
   1067 def findsource(object):
   1068     """Return the entire source file and starting line number for an object.
   1069 
   1070     The argument may be a module, class, method, function, traceback, frame,
   1071     or code object.  The source code is returned as a list of all the lines
   1072     in the file and the line number indexes a line in that list.  An OSError
   1073     is raised if the source code cannot be retrieved."""
-> 1075     file = getsourcefile(object)
   1076     if file:
   1077         # Invalidate cache if needed.
   1078         linecache.checkcache(file)

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:952, in getsourcefile(object)
    948 def getsourcefile(object):
    949     """Return the filename that can be used to locate an object's source.
    950     Return None if no way can be identified to get the source.
    951     """
--> 952     filename = getfile(object)
    953     all_bytecode_suffixes = importlib.machinery.DEBUG_BYTECODE_SUFFIXES[:]
    954     all_bytecode_suffixes += importlib.machinery.OPTIMIZED_BYTECODE_SUFFIXES[:]

File /media/pc/data/lxw/ai/tvm/python/tvm/script/parser/core/diagnostics.py:110, in _patched_inspect_getfile(obj)
    108 """Work out which source or compiled file an object was defined in."""
    109 if not inspect.isclass(obj):
--> 110     return _getfile(obj)
    111 mod = getattr(obj, "__module__", None)
    112 if mod is not None:

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:913, in getfile(object)
    911     if getattr(object, '__file__', None):
    912         return object.__file__
--> 913     raise TypeError('{!r} is a built-in module'.format(object))
    914 if isclass(object):
    915     if hasattr(object, '__module__'):

TypeError: <module '__main__'> is a built-in module
def test_source_ast():
    """Check the doc AST that ``Source(matmul).as_ast()`` produces.

    Walks the tree top-down: module -> function definition -> arguments ->
    the three buffer assignments -> the loop -> the single ``with`` block
    inside the loop, asserting the expected node types, names and counts.
    """
    module_ast = Source(matmul).as_ast()
    assert isinstance(module_ast, doc.Module)

    # The first top-level statement must be the ``matmul`` definition.
    fn = module_ast.body[0]
    assert isinstance(fn, doc.FunctionDef)
    assert fn.name == "matmul"

    # Exactly three positional parameters, in declaration order.
    params = fn.args.args
    assert len(params) == 3
    assert [param.arg for param in params] == ["a", "b", "c"]

    # Body layout: three assignments followed by one loop.
    statements = fn.body
    assert len(statements) == 4
    for stmt, buffer_name in zip(statements[:3], ("A", "B", "C")):
        assert isinstance(stmt, doc.Assign)
        assert stmt.targets[0].id == buffer_name

    # The loop target is a 3-element tuple ``i, j, k``.
    loop = statements[3]
    loop_vars = loop.target.elts
    assert len(loop_vars) == 3
    assert [var.id for var in loop_vars] == ["i", "j", "k"]

    # The loop holds a single ``with`` statement containing two statements.
    loop_body = loop.body
    assert len(loop_body) == 1
    block = loop_body[0]
    assert isinstance(block, doc.With)
    assert len(block.body) == 2


def test_nesting_parsing():
    class dummy:
        pass

    for i in range(1):

        @tvm.script.ir_module
        class Module:
            @T.prim_func
            def impl(
                A: T.Buffer((12, 196, 64), "float32"),
            ) -> None:
                T.evaluate(0)