# Source code of tvm.tir.block_scope
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope."""
from enum import IntEnum
from typing import List, Optional, Union
from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.tir import Block, For
from . import _ffi_api
[文档]
@register_object("tir.StmtSRef")
class StmtSRef(Object):
    """An object that refers to schedulable elements in the TensorIR, aka "sref".

    Glossary
    --------
    - Block sref: A StmtSRef that points to a TensorIR block.
    - Loop sref: A StmtSRef that points to a TensorIR for loop.
    - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
      schedulable statement of its ancestors on the TensorIR AST.
    - Root sref: Sref to the root block. Every sref has exactly one parent sref
      except for root sref.
    - Sref tree: The parent-children-relationship of srefs that forms a tree,
      uniquely determined by the TensorIR AST.
    """

    # Field reflected from the underlying C++ object through the FFI.
    # NOTE(review): presumably the position of the referenced stmt inside its
    # parent SeqStmt (or a sentinel when it is not in one) — confirm against
    # the C++ StmtSRefNode definition.
    seq_index: int

    @property
    def stmt(self) -> Optional[Union[Block, For]]:
        """The block/for stmt the object refers to.

        Returns
        -------
        stmt : Optional[Union[Block, For]]
            The referenced statement. NOTE(review): expected to be None for the
            special marks produced by `inline_mark` and `root_mark`, which do not
            point to any stmt in the AST — confirm on the C++ side.
        """
        return _ffi_api.StmtSRefStmt(self)  # type: ignore # pylint: disable=no-member

    @property
    def parent(self) -> Optional["StmtSRef"]:
        """The parent sref.

        Returns
        -------
        parent : Optional[StmtSRef]
            The parent sref in the sref tree. NOTE(review): presumably None for the
            root sref, which (per the glossary) is the only sref without a parent.
        """
        return _ffi_api.StmtSRefParent(self)  # type: ignore # pylint: disable=no-member

    @staticmethod
    def inline_mark() -> "StmtSRef":
        """Return a special StmtSRef that doesn't point to any stmt in the AST.

        It only serves as a "mark" to hint compute-at to do the work of
        compute-inline.
        """
        return _ffi_api.StmtSRefInlineMark()  # type: ignore # pylint: disable=no-member

    @staticmethod
    def root_mark() -> "StmtSRef":
        """Return a special StmtSRef that doesn't point to any stmt in the AST.

        It only serves as a "mark" to hint compute-at to do nothing.
        """
        return _ffi_api.StmtSRefRootMark()  # type: ignore # pylint: disable=no-member
[文档]
class DepKind(IntEnum):
    """Kind of a dependency relation between two blocks.

    Attributes
    ----------
    RAW : int = 0
        Read-after-write dependency.
    WAW : int = 1
        Write-after-write dependency.
    WAR : int = 2
        Write-after-read dependency. Not supported in TensorIR for now.
    OPAQUE : int = 3
        Opaque dependency.
    """

    # The integer values are significant and must not be renumbered.
    # NOTE(review): presumably they mirror the C++ DepKind enum exposed
    # through the FFI — keep both sides in sync.
    RAW = 0  # read-after-write
    WAW = 1  # write-after-write
    WAR = 2  # write-after-read (not supported in TensorIR for now)
    OPAQUE = 3  # opaque
[文档]
@register_object("tir.Dependency")
class Dependency(Object):
    """A tuple (src, dst, kind) representing certain types of dependency.

    For example, (A, B, DepKind.RAW) means block B depends on block A, and the
    dependency kind is read-after-write, which means block B reads the result
    written by block A.

    Attributes
    ----------
    src : StmtSRef
        The source of the dependency relation
    dst : StmtSRef
        The destination of the dependency relation
    kind : DepKind
        The dependency kind
    """

    # These annotations describe fields reflected from the C++ object via the
    # FFI; no Python constructor is defined here.
    src: StmtSRef
    dst: StmtSRef
    # NOTE(review): the FFI may hand this back as a plain int rather than a
    # DepKind member. DepKind is an IntEnum, so `dep.kind == DepKind.RAW`
    # compares correctly either way.
    kind: DepKind
[文档]
@register_object("tir.BlockScope")
class BlockScope(Object):
    """An object that corresponds to each block sref in the sref tree, which
    tracks the producer-consumer dependency between blocks.

    Glossary
    --------
    - Block scope: A contiguous subtree of the sref tree, rooted at
      each block sref, whose components are:

      - scope root: a block sref
      - internal srefs: loop srefs
      - scope leaves: block srefs

    - Child block: The scope leaf blocks under the scope root or a specific internal sref
    """

    def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]:
        """Get all dependencies whose `src` is the target `block`.

        Parameters
        ----------
        block: StmtSRef
            The queried block

        Returns
        -------
        deps: List[Dependency]
            The dependencies. NOTE(review): the container is produced by the FFI
            and may be a TVM runtime Array rather than a Python list — confirm.
        """
        return _ffi_api.BlockScopeGetDepsBySrc(self, block)  # type: ignore # pylint: disable=no-member

    def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]:
        """Get all dependencies whose `dst` is the target `block`.

        Parameters
        ----------
        block: StmtSRef
            The queried block

        Returns
        -------
        deps: List[Dependency]
            The dependencies. NOTE(review): the container is produced by the FFI
            and may be a TVM runtime Array rather than a Python list — confirm.
        """
        return _ffi_api.BlockScopeGetDepsByDst(self, block)  # type: ignore # pylint: disable=no-member