tvm_ffi.error source code

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Error handling."""

from __future__ import annotations

import ast
import re
import sys
import types
from typing import Any

from . import core


def _parse_backtrace(backtrace: str) -> list[tuple[str, int, str]]:
    """Parse the backtrace string into a list of (filename, lineno, func).

    Parameters
    ----------
    backtrace
        The backtrace string.

    Returns
    -------
    result
        The list of (filename, lineno, func)

    """
    pattern = r'File "(.+?)", line (\d+), in (.+)'
    result = []
    for line in backtrace.split("\n"):
        match = re.match(pattern, line.strip())
        if match:
            try:
                filename = match.group(1)
                lineno = int(match.group(2))
                func = match.group(3)
                result.append((filename, lineno, func))
            except ValueError:
                pass
    return result


class TracebackManager:
    """Helper to manage traceback generation."""

    def __init__(self) -> None:
        """Initialize the traceback manager and its cache."""
        self._code_cache: dict[tuple[str, int, str], types.CodeType] = {}

    def _get_cached_code_object(self, filename: str, lineno: int, func: str) -> types.CodeType:
        # Hack to create a code object that points to the correct
        # line number and function name
        key = (filename, lineno, func)
        # cache the code object to avoid re-creating it
        if key in self._code_cache:
            return self._code_cache[key]
        # Parse to AST and zero out column info
        # since column info are not accurate in original trace
        tree = ast.parse("_getframe()", filename=filename, mode="eval")
        for node in ast.walk(tree):
            if hasattr(node, "col_offset"):
                node.col_offset = 0
            if hasattr(node, "end_col_offset"):
                node.end_col_offset = 0
        # call into get frame, bt changes the context
        code_object = compile(tree, filename, "eval")
        # replace the function name and line number
        code_object = code_object.replace(co_name=func, co_firstlineno=lineno)
        self._code_cache[key] = code_object
        return code_object

    def _create_frame(self, filename: str, lineno: int, func: str) -> types.FrameType:
        """Create a frame object from the filename, lineno, and func."""
        code_object = self._get_cached_code_object(filename, lineno, func)
        # call into get frame, but changes the context so the code
        # points to the correct frame
        context = {"_getframe": sys._getframe}
        # pylint: disable=eval-used
        return eval(code_object, context, context)

    def append_traceback(
        self,
        tb: types.TracebackType | None,
        filename: str,
        lineno: int,
        func: str,
    ) -> types.TracebackType:
        """Append a traceback to the given traceback.

        Parameters
        ----------
        tb
            The traceback to append to.
        filename
            The filename of the traceback
        lineno
            The line number of the traceback
        func
            The function name of the traceback

        Returns
        -------
        new_tb
            The new traceback with the appended frame.

        """

        # This approach avoids binding the created frame object to a local variable
        # in `append_traceback`, which would create a reference cycle. By using a
        # nested function, the frame object is a temporary that is not held by
        # the locals of `append_traceback`. See the diagram in `_with_append_backtrace`
        # and PR #327 for more details.
        def create(
            tb: types.TracebackType | None, frame: types.FrameType, lineno: int
        ) -> types.TracebackType:
            return types.TracebackType(tb, frame, frame.f_lasti, lineno)

        return create(tb, self._create_frame(filename, lineno, func), lineno)


# Single module-level manager used by `_with_append_backtrace`, so the code
# objects memoized in its `_code_cache` are reused across calls.
_TRACEBACK_MANAGER = TracebackManager()


def _with_append_backtrace(py_error: BaseException, backtrace: str) -> BaseException:
    """Append the frames described by ``backtrace`` to the traceback of ``py_error``.

    Parameters
    ----------
    py_error
        The exception whose traceback is extended.
    backtrace
        Backtrace text. Every line recognized by ``_parse_backtrace``
        (``File "<filename>", line <lineno>, in <func>``) becomes one synthetic
        traceback entry.

    Returns
    -------
    py_error
        The same exception object, as returned by ``with_traceback`` after its
        traceback has been replaced by the extended one.

    Notes
    -----
    Each parsed frame is linked in *front* of the traceback built so far.
    ``append_traceback`` makes the previous traceback the new entry's
    ``tb_next``. As a result, the last parsed line becomes the first entry and
    the original ``py_error.__traceback__`` ends up at the tail.
    """
    # We manually delete py_error and tb to avoid a reference cycle. Without the
    # cycle, the objects reachable from this frame's locals can be freed sooner.
    # Please see pull request #327 for more details.
    #
    # Memory Cycle Diagram:
    #
    #         [Stack Frames]                            [Heap Objects]
    #     +-------------------+
    #     | outside functions | -----------------------> [ Tensor ]
    #     +-------------------+                   (Held by cycle, slow to free)
    #             ^
    #             | f_back
    #     +-------------------+  locals      py_error
    #     | py_error (this)   | -----+--------------> [ BaseException ]
    #     +-------------------+      |                       |
    #             ^                  |                       | (with_traceback)
    #             | f_back           |                       v
    #     +-------------------+      +--------------> [ Traceback Obj ]
    #     | append_traceback  |                   tb         |
    #     +-------------------+                              |
    #             ^                                          |
    #             | f_back                                   |
    #     +-------------------+                              |
    #     | _create_frame     |                              |
    #     +-------------------+                              |
    #             ^                                          |
    #             | f_back                                   |
    #     +-------------------+                              |
    #     | _getframe() frame | <----------------------------+
    #     +-------------------+      (Cycle closes here)
    #
    # The bottom box is the synthetic frame created by eval-ing the code object
    # in `_create_frame`. The traceback references it, and its f_back chain
    # leads back up to this function's frame.
    tb = py_error.__traceback__
    try:
        for filename, lineno, func in _parse_backtrace(backtrace):
            tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func)
        return py_error.with_traceback(tb)
    finally:
        # We explicitly break the reference cycle here. The `finally` block is
        # executed just before the function returns, after the `return` expression
        # in the `try` block has been evaluated. Deleting `py_error` and `tb`
        # here ensures they are not held by this function's frame's locals,
        # which resolves the cycle.
        del py_error, tb


def _traceback_to_backtrace_str(tb: types.TracebackType | None) -> str:
    """Convert the traceback to a string."""
    lines = []
    while tb is not None:
        frame = tb.tb_frame
        lineno = tb.tb_lineno
        filename = frame.f_code.co_filename
        funcname = frame.f_code.co_name
        lines.append(f'  File "{filename}", line {lineno}, in {funcname}\n')
        tb = tb.tb_next
    # needs to reverse the order of the lines so backtrace stores in
    # the reverse order of python traceback
    return "".join(reversed(lines))


# Expose the traceback helpers to the `core` module as module attributes.
# NOTE(review): presumably `core` calls these when converting errors between the
# FFI and Python (attaching a backtrace to a Python exception, and serializing a
# Python traceback back to backtrace text) — confirm in `core`.
core._WITH_APPEND_BACKTRACE = _with_append_backtrace
core._TRACEBACK_TO_BACKTRACE_STR = _traceback_to_backtrace_str


[文档] def register_error( name_or_cls: str | type | None = None, cls: type | None = None, ) -> Any: """Register an error class so it can be recognized by the ffi error handler. Parameters ---------- name_or_cls The name of the error class. cls The class to register. Returns ------- fregister Register function if f is not specified. Examples -------- .. code-block:: python import tvm_ffi # Register a custom Python exception so tvm_ffi.Error maps to it @tvm_ffi.error.register_error class MyError(RuntimeError): pass # Convert a Python exception to an FFI Error and back ffi_err = tvm_ffi.convert(MyError("boom")) py_err = ffi_err.py_error() assert isinstance(py_err, MyError) """ if callable(name_or_cls): cls = name_or_cls name_or_cls = cls.__name__ def register(mycls: type) -> type: """Register the error class name with the FFI core.""" err_name = name_or_cls if isinstance(name_or_cls, str) else mycls.__name__ core.ERROR_NAME_TO_TYPE[err_name] = mycls core.ERROR_TYPE_TO_NAME[mycls] = err_name return mycls if cls is None: return register return register(cls)
# Register the standard Python exception types under their own class names.
register_error("RuntimeError", RuntimeError)
register_error("ValueError", ValueError)
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)
register_error("IndexError", IndexError)
register_error("AssertionError", AssertionError)
register_error("MemoryError", MemoryError)