添加算子到 Relay#

在本文档中,将介绍在 Relay 中注册新的 TVM 算子所需的步骤。我们将遵循这个 PR,它增加了 cumulative product 作为例子。PR 本身建立在另一个 PR 的基础上,后者添加了 cumulative sum 算子。

注册新的算子需要几个步骤:

  1. 添加属性节点,声明在编译时已知的固定参数

  2. 为集成到 Relay 类型系统中的运算编写类型关系。

  3. 使用 C++ 中的 RELAY_REGISTER_OP 宏为编译器注册算子的属性、类型和其他提示

  4. 编写算子的计算方式

  5. 注册 Relay 算子的 compute, schedule

  6. 定义 C++ 函数,为算子生成 call 节点,并为该函数注册 Python API 钩子

  7. 将上面的 Python API 钩子包装在更整洁的接口中

  8. 为新的 Relay 算子编写测试

1. 定义属性节点#

属性是固定的参数,应该在编译时就知道。卷积算子的 stride 和 expand 是属于卷积算子属性节点的字段的一个适当的例子。

属性应该定义在 include/tvm/relay/attrs/ 文件夹下的文件中

最终希望创建一个算子,它的接口可以在最终的 python 接口中清楚地看到:

def cumprod(data, axis=None, dtype=None, exclusive=None):
    """Numpy style cumprod op. Return the cumulative inclusive product of the elements along
    a given axis.
    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.
    axis : int, optional
        Axis along which the cumulative product is computed. The default (None) is to compute
        the cumprod over the flattened array.
    dtype : string, optional
        Type of the returned array and of the accumulator in which the elements are multiplied.
        If dtype is not specified, it defaults to the dtype of data.
    exclusive : bool, optional
        If true will return exclusive product in which the first element is not
        included. In other terms, if true, the j-th output element would be
        the product of the first (j-1) elements. Otherwise, it would be the product of
        the first j elements. The product of zero elements will be 1.
    Returns
    -------
    result : relay.Expr
        The result has the same size as data, and the same shape as data if axis is not None.
        If axis is None, the result is a 1-d array.
    """

实现 cumsum() 类似的接口。

因此,当在 include/tvm/relay/attrs/transform.h 中定义属性时,选择算子的 axis、累积 dtype 和 exclusivity 作为结构的适当字段。

/*! \brief Attributes used in cumsum and cumprod operator */
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
  Integer axis;
  DataType dtype;
  Bool exclusive = Bool(false);
  TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs") {
    TVM_ATTR_FIELD(axis).describe("The axis to operate over").set_default(NullValue<Integer>());
    TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
    TVM_ATTR_FIELD(exclusive)
        .describe("The first element is not included")
        .set_default(Bool(false));
  }
};

2. 编写类型关系#

为了在 Relay 中实现算子的灵活注册以及更丰富的类型表达和粒度,算子使用输入类型和输出类型之间的关系进行类型化。这些关系表示为接受输入类型列表和输出类型列表(其中任何一种类型都可以是不完整的)并返回满足关系的输入和输出类型的函数。这包括可以在编译时静态确定的形状信息。基本上,算子的关系可以除了计算输出类型之外,还可以强制执行所有必要的类型规则(即通过检查输入类型)。

累积乘积和累积加法算子运的类型关系可在 src/relay/op/tensor/transform.cc 中查找:

TVM_REGISTER_NODE_TYPE(ScanopAttrs);
bool ScanopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) {
    // types: [data, output]
    ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
    const auto* data = types[0].as<TensorTypeNode>();
    if (data == nullptr) {
        ICHECK(types[0].as<IncompleteTypeNode>())
        << "Scanop: expect input type to be TensorType but get " << types[0];
        return false;
    }

    const auto* param = attrs.as<ScanopAttrs>();

    auto dtype = param->dtype;
    if (dtype.is_void()) {
        dtype = data->dtype;
    }

    if (param->axis.defined()) {
        reporter->Assign(types[1], TensorType(data->shape, dtype));
    } else {
        auto prod = data->shape[0];
        for (size_t i = 1; i < data->shape.size(); ++i) {
            prod = prod * data->shape[i];
        }
        reporter->Assign(types[1], TensorType({prod}, dtype));
    }

    return true;
}

3. 将 Arity 和 Attributes 与运算关联起来#

然后注册新 ops 的名称,并用调用接口进行注解。C++ 中的 RELAY_REGISTER_OP 宏允许开发人员指定有关 Relay 中算子的以下信息:

  • Arity(参数数量)

  • 位置参数的名称和描述

  • 支持级别(1 表示内部 intrinsic;较高的数字表示较少 integral 或外部支持的算子)

  • 该算子的类型关系

  • 当优化运算时,其他注解也很有用。

再次将其添加到 src/relay/op/tensor/transform.cc 中:

RELAY_REGISTER_OP("cumsum")
    .describe(
        R"doc(Return the cumulative sum of the elements along a given axis.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(3)
    .add_type_rel("Cumsum", ScanopRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);

RELAY_REGISTER_OP("cumprod")
    .describe(
        R"doc(Return the cumulative product of the elements along a given axis.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(3)
    .add_type_rel("Cumprod", ScanopRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);

在这种情况下,TOpPattern 是向编译器提供的关于算子计算模式的提示,这可能对融合算子有用。kOpaque 告诉 TVM 不要试图融合这个算子。

4. 定义运算的计算#

虽然已经定义了运算的接口,但仍然需要定义如何执行累积加法以及累积乘积的实际计算。

编写这段代码超出了本教程的范围。现在,假设已经为运算的计算实现了经过良好测试的实现。更多关于如何做到这一点的细节,我们建议查阅以下教程:tensor expressions,并查看在 python/tvm/topi/scan.py 中找到的累积求和和乘积实现示例以及在 python/tvm/topi/cuda/scan.py 中找到的GPU版本。对于我们的累积求和和乘积运算,直接在 TIR 中编写,这是张量表达式和 topi 将7到的表示形式。

5. 将计算和策略与 Relay 勾连起来#

在您实现了计算函数后,我们现在需要将其粘合到我们的 Relay 运算中。在 TVM 中,这不仅仅是定义计算,还包括运算的调度。策略是一种方法,它选择要使用的计算和调度。例如,对于 2D 卷积,可能会识别出我们正在进行深度卷积,并将调度分派给更高效的计算和调度。然而,在我们的情况下,除了在 CPU 和 GPU 实现之间进行调度之外,没有这样的需求。在 python/tvm/relay/op/strategy/generic.pypython/tvm/relay/op/strategy/cuda.py 中,添加以下策略:

def wrap_compute_scanop(topi_compute):
    """Wrap scanop style topi compute"""

    def _compute_scanop(attrs, inputs, _):
        return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]

    return _compute_scanop


@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
    """cumsum generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scanop(topi.cumsum),
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="cumsum.generic",
    )
    return strategy


@override_native_generic_func("cumprod_strategy")
def cumprod_strategy(attrs, inputs, out_type, target):
    """cumprod generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scanop(topi.cumprod),
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="cumprod.generic",
    )
    return strategy

@cumsum_strategy.register(["cuda", "gpu"])
def cumsum_strategy_cuda(attrs, inputs, out_type, target):
    """cumsum cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scanop(topi.cuda.cumsum),
        wrap_topi_schedule(topi.cuda.schedule_scan),
        name="cumsum.cuda",
    )
    return strategy


@cumprod_strategy.register(["cuda", "gpu"])
def cumprod_strategy_cuda(attrs, inputs, out_type, target):
    """cumprod cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scanop(topi.cuda.cumprod),
        wrap_topi_schedule(topi.cuda.schedule_scan),
        name="cumprod.cuda",
    )
    return strategy

在每个策略中,定义了我们在 add_implementation() 中使用的计算和调度。最后,在 python/tvm/relay/op/_transform.py 中将策略和计算与定义的 Relay 运算链接起来:

# cumsum
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
    """Compute definition of cumsum"""
    return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]


_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)

# cumprod
@_reg.register_compute("cumprod")
def compute_cumprod(attrs, inputs, output_type):
    """Compute definition of cumprod"""
    return [topi.cumprod(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]


_reg.register_strategy("cumprod", strategy.cumprod_strategy)
_reg.register_shape_func("cumprod", False, elemwise_shape_func)

形状函数用于确定给定动态形状张量的输出形状。在这种情况下,告诉 TVM 输出形状将与输入形状相同。

6. 创建 Relay Call 节点并公开 Python 钩子#

现在有可以工作的运算,现在只需要通过 Relay Call 节点正确地调用它。这个步骤需要简单地编写函数,该函数接受算子的参数(作为 Relay 表达式),并返回对算子的调用节点(即应该放入 Relay AST 中的节点,以进行算子的调用)。

目前,调用属性和类型参数(最后两个字段)不受支持,因此只需使用 Op::Get 从算子注册表中获取算子的信息,并将参数传递给调用节点即可。在 src/relay/op/tensor/transform.cc 中,如下所示:

Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) {
    auto attrs = make_object<ScanopAttrs>();
    attrs->dtype = dtype;
    attrs->axis = axis;
    attrs->exclusive = exclusive;
    static const Op& op = Op::Get("cumsum");
    return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum);

Expr MakeCumprod(Expr data, Integer axis, DataType dtype, Bool exclusive) {
    auto attrs = make_object<ScanopAttrs>();
    attrs->dtype = dtype;
    attrs->axis = axis;
    attrs->exclusive = exclusive;
    static const Op& op = Op::Get("cumprod");
    return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.cumprod").set_body_typed(MakeCumprod);

Where TVM_REGISTER_GLOBAL exposes the MakeCumsum and MakeCumprod functions in Python via relay.op._make.cumsum(...) and relay.op._make.cumprod(...).

7. 包含更干净的 Python API 钩子#

在 Relay 中,通常的约定是,通过 TVM_REGISTER_GLOBAL 导出的函数应该包装在单独的 Python 函数中,而不是直接在 Python 中调用。对于我们的算子,我们在 python/tvm/relay/op/transform.py 中公开了这个更干净的接口。

def cumsum(data, axis=None, dtype=None, exclusive=None):
    return _make.cumsum(data, axis, dtype, exclusive)

def cumprod(data, axis=None, dtype=None, exclusive=None):
    return _make.cumprod(data, axis, dtype, exclusive)

请注意,这些 Python 包装器也可能会提供更易于算子使用的接口。例如,concat 算子注册为只接受一个算子,即要连接的张量组成的元组,但是 Python 包装器将张量作为参数,并在产生调用节点之前将它们组合成一个元组:

def concat(*args):
    """Concatenate the input tensors along the zero axis.

    Parameters
    ----------
    args: list of Tensor

    Returns
    -------
    tensor: The concatenated tensor.
    """
    tup = Tuple(list(args))
    return _make.concat(tup)

8. 编写单元测试!#

这很容易理解!我们可以在 tests/python/relay/test_op_level3.py 中找到一些示例单元测试,用于我们的累积和与乘积运算。

其他主题#

Gradient 算子#

梯度算子在 Relay 中编写可微分程序时非常重要。尽管 Relay 的自动微分算法可以对一等语言构造进行微分,但算子是不透明的。由于 Relay 无法查看实现细节,因此必须提供明确的微分规则。

Python 和 C++ 都可以用于编写梯度算子,但我们的示例主要集中在 Python 上,因为它更常使用。

在 Python 中添加梯度#

可以在 python/tvm/relay/op/_tensor_grad.py 中找到 Python 梯度算子的集合。我们将通过两个代表性的例子进行说明:sigmoidmultiply

@register_gradient("sigmoid")
def sigmoid_grad(orig, grad):
    """Returns [grad * sigmoid(x) * (1 - sigmoid(x))]."""
    return [grad * orig * (ones_like(orig) - orig)]

这里的输入是原始算子 orig 和要累积的梯度 grad。我们返回的是一个列表,其中第 i 个元素的导数是相对于算子的第 i 个输入的算子。一般来说,梯度将返回具有与基础算子相同数量输入的元素的列表。

在我们进一步分析这个定义之前,首先我们应该回顾一下 sigmoid 函数的导数:\(\frac{\partial \sigma}{\partial x} = \sigma(x)(1 - \sigma(x))\)。上述定义看起来与数学定义类似,但有一个重要的添加项,我们将在下面描述。

术语 orig * (ones_like(orig) - orig) 直接匹配于导数,因为这里的 orig 是 sigmoid 函数,但我们不仅仅对如何计算这个函数的梯度感兴趣。我们感兴趣的是将这个梯度与其他梯度组合起来,以便在整个程序中累积梯度。这就是 grad 项的来源。在表达式 grad * orig * (ones_like(orig) - orig) 中,乘以 grad 指定了如何将到目前为止的导数与梯度进行组合。

现在,我们考虑稍微更有趣的例子:multiply

@register_gradient("multiply")
def multiply_grad(orig, grad):
    """Returns [grad * y, grad * x]"""
    x, y = orig.args
    return [collapse_sum_like(grad * y, x),
            collapse_sum_like(grad * x, y)]

在这个例子中,返回的列表中有两个元素,因为 multiply 是一个二元运算符。回想一下,如果 \(f(x, y) = xy\),偏导数为 \(\frac{\partial f}{\partial x} = y\)\(\frac{\partial f}{\partial y} = x\)

对于 multiply,有一个不是必需的步骤,是因为 multiply 具有广播语义。由于 grad 的形状可能与输入的形状不匹配,我们使用 collapse_sum_like 来获取 grad * <var> 项的内容,并使形状与我们正在求导的输入的形状相匹配。

在 C++ 中添加梯度#

在 C++ 中添加梯度与在 Python 中类似,但是注册的接口略有不同。

首先,确保包含 src/relay/transforms/pattern_utils.h。它提供了在 Relay AST 中创建节点的辅助函数。然后,以与 Python 示例类似的方式定义梯度:

tvm::Array<Expr> MultiplyGrad(const Expr& orig_call, const Expr& output_grad) {
    const Call& call = orig_call.Downcast<Call>();
    return { CollapseSumLike(Multiply(output_grad, call.args[1]), call.args[0]),
             CollapseSumLike(Multiply(output_grad, call.args[0]), call.args[1]) };
}

请注意,在 C++ 中,我们不能使用与 Python 中相同的算子重载,我们需要进行向下转换,因此实现更为冗长。即便如此,我们可以轻松验证这个定义与之前在 Python 中的示例相呼应。

现在,我们不再使用 Python 装饰器,而是需要将 set_attr 调用添加到基本运算符注册的末尾,以注册梯度。

RELAY_REGISTER_OP("multiply")
    // ...
    // Set other attributes
    // ...
    .set_attr<FPrimalGradient>("FPrimalGradient", MultiplyGrad);