# 解读 `tvm/script/parser/core/diagnostics.py`
## 解析 `Source`
import set_env
import pytest
import inspect
import tvm.testing
from tvm.script.parser.core.diagnostics import Source
from tvm.script.parser.core import doc_core as doc
from tvm.script import tir as T
Source?
Init signature: Source(program: Union[str, tvm.script.parser.core.doc_core.AST])
Docstring:
Source code class for TVMScript.
It is constructed by source code str or doc AST tree.
Parameters
----------
source_name : str
The filename of the file where the source code locates.
start_line : int
The first line number of the source code.
start_column : int
The first column number of the first line of the source code.
source : str
The source code str of source code.
full_source : str
The complete source code of the file where the source code locates.
File: /media/pc/data/lxw/ai/tvm/python/tvm/script/parser/core/diagnostics.py
Type: type
Subclasses:
# Sample TVMScript-style function that is only handed to ``Source`` in this file.
# It is not decorated, so it is never run or turned into a PrimFunc here.
# The tests below depend on its exact shape:
#   - a top-level ``def`` at column 0;
#   - three buffer assignments followed by a single loop.
# Do NOT add a docstring (it would become an extra body statement) or otherwise
# restructure this function.
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Bind each handle to a buffer of shape [128, 128].
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    # Iterate over the full 128 x 128 x 128 index space.
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            # "SSR" presumably marks i, j as spatial axes and k as the reduction
            # axis (C accumulates over vk below).
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # Accumulate A[vi, vk] * B[vj, vk]; note B is indexed [vj, vk], i.e. A @ B^T.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def test_source_base():
    """Check the location and text metadata that ``Source`` records for ``matmul``.

    The test verifies that:

    * ``source_name`` is the file ``inspect`` reports for ``matmul``;
    * the function text starts at column 0 and equals ``inspect.getsource``;
    * ``full_source`` is the text of the whole file that defines ``matmul``.
    """
    source = Source(matmul)
    # Separate asserts so a failure points at the exact mismatching field.
    assert source.source_name == inspect.getsourcefile(matmul)
    assert source.start_line is not None
    assert source.start_column == 0
    assert source.source == inspect.getsource(matmul)

    # When run in a notebook, ``matmul`` lives in ``<module '__main__'>``, which
    # has no ``__file__``. Calling ``inspect.getsource`` on that module raises
    # ``TypeError: <module '__main__'> is a built-in module``.
    # In that case, fall back to all lines of the file ``inspect`` associates with
    # ``matmul`` itself (for a notebook, the cell's source).
    try:
        expected_full_source = inspect.getsource(inspect.getmodule(matmul))
    except TypeError:
        lines, _ = inspect.findsource(matmul)
        expected_full_source = "".join(lines)
    assert source.full_source == expected_full_source


test_source_base()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 20
12 source = Source(matmul)
13 assert (
14 source.source_name == inspect.getsourcefile(matmul)
15 and source.start_line is not None
(...)
18 and source.full_source == inspect.getsource(inspect.getmodule(matmul))
19 )
---> 20 test_source_base()
Cell In[4], line 18, in test_source_base()
11 def test_source_base():
12 source = Source(matmul)
13 assert (
14 source.source_name == inspect.getsourcefile(matmul)
15 and source.start_line is not None
16 and source.start_column == 0
17 and source.source == inspect.getsource(matmul)
---> 18 and source.full_source == inspect.getsource(inspect.getmodule(matmul))
19 )
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:1282, in getsource(object)
1276 def getsource(object):
1277 """Return the text of the source code for an object.
1278
1279 The argument may be a module, class, method, function, traceback, frame,
1280 or code object. The source code is returned as a single string. An
1281 OSError is raised if the source code cannot be retrieved."""
-> 1282 lines, lnum = getsourcelines(object)
1283 return ''.join(lines)
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:1264, in getsourcelines(object)
1256 """Return a list of source lines and starting line number for an object.
1257
1258 The argument may be a module, class, method, function, traceback, frame,
(...)
1261 original source file the first line of code was found. An OSError is
1262 raised if the source code cannot be retrieved."""
1263 object = unwrap(object)
-> 1264 lines, lnum = findsource(object)
1266 if istraceback(object):
1267 object = object.tb_frame
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:1075, in findsource(object)
1067 def findsource(object):
1068 """Return the entire source file and starting line number for an object.
1069
1070 The argument may be a module, class, method, function, traceback, frame,
1071 or code object. The source code is returned as a list of all the lines
1072 in the file and the line number indexes a line in that list. An OSError
1073 is raised if the source code cannot be retrieved."""
-> 1075 file = getsourcefile(object)
1076 if file:
1077 # Invalidate cache if needed.
1078 linecache.checkcache(file)
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:952, in getsourcefile(object)
948 def getsourcefile(object):
949 """Return the filename that can be used to locate an object's source.
950 Return None if no way can be identified to get the source.
951 """
--> 952 filename = getfile(object)
953 all_bytecode_suffixes = importlib.machinery.DEBUG_BYTECODE_SUFFIXES[:]
954 all_bytecode_suffixes += importlib.machinery.OPTIMIZED_BYTECODE_SUFFIXES[:]
File /media/pc/data/lxw/ai/tvm/python/tvm/script/parser/core/diagnostics.py:110, in _patched_inspect_getfile(obj)
108 """Work out which source or compiled file an object was defined in."""
109 if not inspect.isclass(obj):
--> 110 return _getfile(obj)
111 mod = getattr(obj, "__module__", None)
112 if mod is not None:
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/inspect.py:913, in getfile(object)
911 if getattr(object, '__file__', None):
912 return object.__file__
--> 913 raise TypeError('{!r} is a built-in module'.format(object))
914 if isclass(object):
915 if hasattr(object, '__module__'):
TypeError: <module '__main__'> is a built-in module
def test_source_ast():
    """Check that ``Source(matmul).as_ast()`` yields the expected doc-AST structure."""
    source = Source(matmul)
    mod = source.as_ast()
    assert isinstance(mod, doc.Module)

    # The parsed module starts with the function definition ``matmul(a, b, c)``.
    func_def = mod.body[0]
    assert isinstance(func_def, doc.FunctionDef)
    assert func_def.name == "matmul"
    # Same as checking len == 3 and each name in order.
    assert [arg.arg for arg in func_def.args.args] == ["a", "b", "c"]

    # Body: three buffer assignments (A, B, C) followed by the grid loop.
    func_body = func_def.body
    assert len(func_body) == 4
    for stmt, expected_name in zip(func_body[:3], ["A", "B", "C"]):
        assert isinstance(stmt, doc.Assign)
        assert stmt.targets[0].id == expected_name

    # ``for i, j, k in T.grid(...)``. Check the node type first, so that a wrong
    # node fails the assertion instead of raising AttributeError on ``.target``.
    func_for = func_body[3]
    assert isinstance(func_for, doc.For)
    assert [target.id for target in func_for.target.elts] == ["i", "j", "k"]

    # The loop body is a single ``with T.block(...)`` holding two statements:
    # the axis remap and the update of C.
    for_body = func_for.body
    assert len(for_body) == 1
    for_block = for_body[0]
    assert isinstance(for_block, doc.With)
    assert len(for_block.body) == 2
def test_nesting_parsing():
    """Build a TVMScript IRModule from a class nested inside a function and a loop.

    This exercises ``@tvm.script.ir_module`` on a class whose source is indented
    inside this function body and a ``for`` loop. There are no asserts: the test
    passes as long as defining ``Module`` does not raise.
    """

    # An unrelated local class defined before ``Module``. Presumably it checks
    # that locating the nested ``Module``'s source is not confused by other
    # classes in the same function. NOTE(review): confirm intent.
    class dummy:
        pass

    # The loop only adds another level of nesting; ``i`` is unused.
    for i in range(1):

        @tvm.script.ir_module
        class Module:
            @T.prim_func
            def impl(
                A: T.Buffer((12, 196, 64), "float32"),
            ) -> None:
                T.evaluate(0)