Ramp

Ramp#

%cd ..
import testing
/media/pc/data/lxw/ai/tvm-book/doc/read/tir

源码:tvm/include/tvm/tir/expr.h

import tvm
tvm.tir.Ramp??
Init signature:
tvm.tir.Ramp(
    base: tvm.ir.expr.PrimExpr,
    stride: tvm.ir.expr.PrimExpr,
    lanes: tvm.ir.expr.PrimExpr,
    span: Optional[tvm.ir.base.Span] = None,
) -> None
Source:        
@tvm._ffi.register_object("tir.Ramp")
class Ramp(PrimExprWithOp):
    """Ramp node.

    Parameters
    ----------
    base : PrimExpr
        The base expression.

    stride : PrimExpr
        The stride of the ramp.

    lanes : PrimExpr
        The lanes of the expression.

    span : Optional[Span]
        The location of this expression in the source code.
    """

    base: PrimExpr
    stride: PrimExpr
    lanes: PrimExpr

    def __init__(
        self, base: PrimExpr, stride: PrimExpr, lanes: PrimExpr, span: Optional[Span] = None
    ) -> None:
        self.__init_handle_by_constructor__(
            _ffi_api.Ramp, base, stride, lanes, span  # type: ignore
        )
File:           /media/pc/data/lxw/ai/tvm/python/tvm/tir/expr.py
Type:           type
Subclasses:     
/*!
 * \brief Construct a vector with lanes elements
 *        where its i-th element equals base + i * stride.
 *  This is useful to construct a index for a continuous vector load.
 *
 *  Examples:
 *  - ramp(0, 1, 3) = [0, 1, 2]
 *  - ramp(1, 2, 4) = [1, 3, 5, 7]
 */
class RampNode : public PrimExprNode {
 public:
  /*! \brief The base value. */
  PrimExpr base;
  /*! \brief The stride of each step. */
  PrimExpr stride;
  /*! \brief Total number of lanes. */
  PrimExpr lanes;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("dtype", &dtype);
    v->Visit("base", &base);
    v->Visit("stride", &stride);
    v->Visit("lanes", &lanes);
    v->Visit("span", &span);
  }

  bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
    return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) &&
           equal(lanes, other->lanes);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(dtype);
    hash_reduce(base);
    hash_reduce(stride);
    hash_reduce(lanes);
  }

  static constexpr const char* _type_key = "tir.Ramp";
  TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
};

RampNode 类,它继承自 PrimExprNode。这个类用于构建具有多个元素的向量,其中第 i 个元素等于 base + i * stride。这在构造连续向量加载的索引时非常有用。

类中有三个成员变量:

  1. base:表示向量的起始值。

  2. stride:表示每个步长的跨度。

  3. lanes:表示向量的总元素个数。

类中还包含了三个方法:

  1. VisitAttrs:用于访问类的属性。

  2. SEqualReduce:用于比较两个 RampNode 对象是否相等。

  3. SHashReduce:用于计算 RampNode 对象的哈希值。

此外,类中还定义了静态常量字符串 _type_key,用于表示类的类型,以及宏 TVM_DECLARE_FINAL_OBJECT_INFO,用于声明类的相关信息。