Relay 中的代数数据类型#

代数数据类型(ADTs)是函数式编程语言的核心特性,尤其是那些源自 ML 的语言,因为它们以一种易于推理的方式表达数据结构,特别是在编写递归计算时。由于递归旨在成为 Relay 中控制流的主要机制之一,因此 Relay 包含 ADT 非常重要,以便最好地表达必须使用递归实现的循环和其他控制流结构。

定义和匹配 ADT#

注意:ADT 目前在文本格式中不受支持。此处的语法是基于其他语言中的 ADT 推测的。"

ADT 可以理解为类似 C 语言中 enumstruct 类型的广义版本。与 C 的 struct 类似,ADT 实例是指定类型字段的容器,但类型系统允许同一类型以系统化的方式编码不同的字段组合,类似于 C 的 enum 类型,后者使用用户命名的有限可能值集定义。

具体来说,ADT 被定义为一组命名的构造函数,每个构造函数都是一个函数,它接受指定类型的值作为参数并返回命名 ADT 的实例。ADT 实例仅包含用于生成它的构造函数调用中传递的参数值。

ADT 值在被 解构 之前是不透明的,解构允许再次访问构造函数的参数并将其用于计算新值。由于特定的 ADT 可以具有不同签名的多个构造函数,因此通常需要根据不同的可能构造函数进行分支,从而产生 ADT 的 match 语法。因此,ADT 有时被称为“带标签的联合体”,因为 ADT 实例由用于生成它的构造函数的名称标记,并且以后可以根据标签进行检查。

由于每个 ADT 都有一组有限的构造函数,因此可以很容易地确定处理 ADT 实例的函数是否处理了所有可能的情况。特别是,类型系统可以确保在解构 ADT 实例时,在所有情况下都正确分配类型,这与 C 中的 union 类型形成对比。因此,通常很容易对 ADT 进行推理。

实现细节:Relay ADT 定义是全局的,并存储在模块中,类似于全局函数定义。实际上,ADT 名称是全局类型变量(就像全局函数名称是全局变量一样)。模块维护了 ADT 名称(全局类型变量)到该 ADT 构造函数列表的映射。

以下是定义 ADT 并通过 match 表达式在函数中使用它的简单示例:

# Defines an ADT named "Numbers"
data Numbers {
  Empty : () -> Numbers
  Single : (Tensor[(), int32]) -> Numbers
  Pair : (Tensor[(), int32], Tensor[(), int32]) -> Numbers
}
# A Numbers value can be produced using an Empty, Single, or Pair
# constructor, each with a signature given above

def @sum(%n : Numbers[]) -> Tensor[(), int32] {
   # The match expression branches on the constructor that was
   # used to produce %n. The variables in each case are bound
   # if the constructor matches that used for %n
   match(%n) {
     case Empty() { 0 }
     case Single(x) { x }
     case Pair(x, y) { x + y }
   }
}

@sum(Empty())    # evaluates to 0
@sum(Single(3))  # evaluates to 3
@sum(Pair(5, 6)) # evaluates to 11

请注意,ADT 通过名称进行标识,这意味着从类型检查器的角度来看,具有结构相同构造函数的两个 ADT 仍然是不同的数据类型。

# structurally identical constructors to Numbers
data Numbers2 {
  Empty2 : () -> Numbers2
  Single2 : (Tensor[(), int32]) -> Numbers2
  Pair2 : (Tensor[(), int32], Tensor[(), int32]) -> Numbers2
}

# the below results in a type error because Numbers2
# is a distinct type from Numbers
# fn() { @sum(Empty2()) }

类型检查 ADT 和多态性#

本节将更详细地介绍 ADT 的类型。涉及的大部分复杂性源于这样的事实,即与函数一样,ADT 可以是多态的并接受类型参数。

例如,函数式编程语言中常用的标准 ADT 之一是可选类型,定义如下:

# a is a type parameter
data Optional<a> {
  None : () -> Optional
  Some : (a) -> Optional
}

可选类型通常用作涉及查询数据结构的任何运算的返回类型(如果找到值则返回 Some(v),如果未找到则返回 None)。在定义中采用类型参数允许在多种情况下使用相同的可选类型,而不必为其中可能包含的每种不同类型定义唯一的 ADT。

然而,重要的是要确保内容类型不同的可选类型仍然可以被类型系统区分,因为如果期望包含 Tensor[(), int32] 的可选类型的函数接收到包含 Tensor[(3, 4), float32] 的可选类型,将违反类型安全性。正如这个例子可能暗示的那样,ADT 实例因此被赋予一个包含该实例的具体类型参数的类型,确保信息得以保留。让以下示例说明:

# the signature for option indicates the type argument
def @inc_scalar(%opt : Optional[Tensor[(), int32]]) -> Tensor[(), int32] {
  match(%opt) {
    case None() { 1 }
    case Some(%s) { %s + 1 }
  }
}

def @main() {
  let %one : Optional[Tensor[(), int32]] = Some(1);
  let %big : Optional[Tensor[(10, 10), float32]]
    = Some(Constant(1, (10, 10), float32));
  let %two = inc_scalar(%one);
  # let %bigger = inc_scalar(%big); # type system rejects
  # None does not take an argument so it can always implicitly
  # be given the correct type arguments
  let %z = inc_scalar(None());
  ()
}

上述示例中带注释的类型参数(例如 Optional[Tensor[(), int32]])的语法称为“类型调用”,将多态 ADT 定义视为类型级函数(接受类型参数并返回类型,即 ADT)。出现在类型注释或函数签名中的任何 ADT 都必须用类型参数进行注释(非多态 ADT 必须位于没有参数的类型调用中)。

因此,可以说,如果接受类型为 T1, ..., Tn 的参数的构造函数 C 是接受类型参数 v1, ..., vn 的 ADT D 的构造函数(其中 T1, ..., Tn 可能包含任何 v1, ..., vn),那么 C 的类型为 fun<v1, ..., vn>(T1, ..., Tn) -> D[v1, ..., vn]。这意味着构造函数的类型与普通函数类似,因此出现在调用节点内部,并且可以传递给其他函数或由其他函数返回。特别是,上面的 Some 示例具有签名 fun<a>(a) -> Optional[a],而 None 具有签名 fun<a>() -> Optional[a]

使用 ADT 进行递归#

ADT 定义允许递归,即名为 D 的 ADT 的定义可以假设类型 D 的存在并将其用作构造函数的参数。递归允许 ADT 表示复杂结构,例如列表或树;它是 ADT 在函数式编程中强大功能的来源,因为适当设计的数据结构可以很容易地用递归函数简洁地表达计算。

许多常用的 ADT 都涉及递归;其中一些在 Common ADT Uses 中给出。作为例子,将检查函数式语言中无处不在的列表 ADT:

data List<a> {
   Nil : () -> List
   Cons : (a, List[a]) -> List
}

(请注意,即使在构造函数中,对 List 的递归引用也包装在类型调用中。)

上述定义意味着特定类型值的列表可以通过嵌套 Cons 构造函数来表示,直到到达列表的末尾,可以用 Nil 表示(表示空列表)。

以这种方式表示的列表可以很容易地递归处理。例如,以下函数对整数列表求和:

def @list_sum(%l : List[Tensor[(), int32]]) -> Tensor[(), int32] {
  match(%l) {
    case Nil() { 0 }
    # add the head of the list to the sum of the tail
    case Cons(%h, %t) { %h + @list_sum(%t) }
  }
}

事实上,许多像上面给出的那样的列表递归函数共享可以分解为通用的、易于使用的函数的结构,这些函数将在 Common ADT Uses 下讨论。

匹配表达式中的模式匹配#

与其他函数式语言一样,Relay 中的匹配表达式能够进行比仅为被解构值的类型的每个构造函数提供一个 case 更通用的模式匹配。

特别是,匹配 case 中的模式可以递归构建:

  • 构造函数模式匹配特定的 ADT 构造函数。如果值与构造函数匹配,则构造函数的每个参数将与嵌套模式匹配。

  • 通配符模式将匹配任何值,并且不会绑定到变量。

  • 变量模式将匹配任何值并将其绑定到局部变量,作用域为匹配子句。

在上面简单的 @list_sum 例子中,第一个匹配 case 有一个 Nil 构造函数模式(没有嵌套参数),第二个有一个 Cons 构造函数模式,该模式对 Cons 的每个参数使用变量模式。

以下示例使用通配符模式忽略 Cons 的一个参数:"

def @first<a>(%l : List[a]) -> Optional[a] {
  match(%l) {
    case Nil() { None() }
    case Cons(%h, _) { Some(%h) } # list tail is unused and ignored
  }
}

在这里,构造函数模式嵌套在另一个构造函数模式中,以避免列表选项的嵌套匹配表达式。还使用顶级通配符模式来处理与第一个子句不匹配的所有情况:"

def @second_opt<a>(%ll : Optional[List[a]]) -> Optional[a] {
  match(%ll) {
    # we only need the second member of the list if there is one
    case Some(Cons(_, Cons(%s, _))) { Some(%s) }
    case _ { None() }
  }
}

# @second_opt(Some(Cons(1, Nil()))) evaluates to None()
# @second_opt(Some(Cons(1, Cons(2, Nil())))) evaluates to Some(2)
# @second_opt(Some(Nil())) evaluates to None()
# @second_opt(None()) evaluates to None()

请注意,匹配表达式按照 case 列出的顺序检查其模式:第一个模式与输入值匹配的子句是被评估的子句。在这里,顶级变量模式绑定整个输入值:

def @match_order_beware<a>(%l : List[a]) -> List[a] {
  match(%l) {
    case %v { %v }
    # the above matches everything so neither of these runs
    case Cons(%h, %t) { Cons(%h, @match_order_beware(%t)) }
    case Nil() { Nil() }
  }
}

常见的 ADT 用法#

在函数式编程语言中,某些 ADT 为编写常见程序提供了有用的功能。参数多态性和高阶函数使这些 ADT 易于重用,并且通用函数可以在常见情况下操作它们。Relay 包含某些预定义 ADT 及其函数的“Prelude”,对应于其他语言中不可或缺的 ADT。

Type-Checking ADTs and Polymorphism 下定义的可选类型就是这样一种 ADT,每当函数在某些情况下仅返回值有意义时使用。拥有可选类型允许类型系统跟踪哪些函数总是返回某个类型的值,而不是返回该类型的可选值,确保任何可选值总是被显式检查(与返回空指针或抛出异常作为解决该问题的其他方法形成对比)。

列表(在 Recursion with ADTs 中定义)可以通过通用函数以类似于 Python 中的列表推导和某些库函数的方式操作。以下是遍历列表的非常常见的函数,它们包含在 Relay 的 Prelude 中。(这些函数在函数式编程文献中已经广泛描述,不会在本文档中尝试重现这些工作。)"

# Map: for [h1, h2, ..., hn] returns [f(h1), f(h2), ..., f(hn)]
def @map<a, b>(%f : fn(a) -> b, %l : List[a]) -> List[b] {
  match(%l) {
    case Nil() { Nil() }
    case Cons(%h, %t) { Cons(%f(%h), @map(%f, %t)) }
  }
}

# Left fold: for [h1, h2, ..., hn] returns f(...(f(f(z, h1), h2)...), hn)
def @foldl<a, b>(%f : fn(b, a) -> b, %z : b, %l : List[a]) -> b {
  match(%l) {
    case Nil() { %z }
    case Cons(%h, %t) { @foldl(%f, %f(%z, %h), %t) }
  }
}

# Right fold: for [h1, h2, ..., hn] returns f(h1, f(h2, f(..., (f(hn, z)...)
def @foldr<a, b>(%f : fn(a, b) -> b, %z : b, %l : List[a] -> b {
  match(%l) {
    case Nil() { %z }
    case Cons(%h, %t) { %f(%h, @foldr(%f, %z, %t)) }
  }
}

使用这些迭代结构,可以简洁地表达列表上的许多常见操作。例如,以下 map 将列表的所有成员加倍:"

# directly written
def @double(%l : List[Tensor[(), int32]]) -> List[Tensor[(), int32]] {
  match(%l) {
    case Nil() { Nil() }
    case Cons(%h, %t) { Cons(%h * 2, @double(%t)) }
  }
}

# map takes care of the recursion
@map(fn(%i) { %i * 2 }, %l)

以下右折叠连接两个列表:

# directly written
def @concat<a>(%l1 : List[a], %l2 : List[a]) -> List[a] {
  match(%l1) {
    case Nil() { %l2 }
    case Cons(%h, %t) { Cons(%h, @concat(%t, %l2) }
  }
}

# foldr takes care of the recursion
@foldr(fn(%h, %z) { Cons(%h, %z) }, %l2, %l1)

以下左折叠展平列表的列表(使用连接):

# directly written
def @flatten<a>(%ll : List[List[a]]) -> List[a] {
  match(%ll) {
    case Cons(%h, %t) { @concat(%h, @flatten(%t)) }
    case Nil() { Nil() }
  }

# foldl takes care of the recursion
@foldl(@concat, Nil(), %ll)

请注意,这些迭代结构可以直接在 Relay 的源语言中实现,并且可以轻松定义更多(以及更多数据类型,如树),而不是内置到语言中的结构(例如,MXNet 中的 "foreach")。ADT 及其可扩展性允许在 Relay 中表达广泛的迭代和数据结构,并由类型系统支持,而无需修改语言实现。

使用 ADT 实现神经网络#

这篇 2015 年的博客文章 中,Christopher Olah 指出,许多神经网络可以很容易地使用常见的函数式编程结构表达。Relay 的 ADT 允许这些示例直接在 TVM 中实现。

首先让假设有对应于训练好的递归神经网络(RNN)单元的函数,它接受过去的状态和输入值并返回新的状态和输出值。在 Relay 中,这将具有以下签名:

@cell : fn<state_type, in_type, out_type>(state_type, in_type) -> (state_type, out_type)

可以将 ReLU 单元作为简单的具体示例,下面是训练好的版本:

def @linear(%x, %w, %b) { %w*%x + %b }

def @relu_cell(%w, # weights
               %b, # offsets
               %s, # state
               %x  # input
) {
  let %x2 = @linear(%x, %w.0, %b.0);
  let %s2 = @linear(%s, %w.1, %b.1);
  # doesn't change the state
  (%s, nn.relu(%x2 + %s2))
}

# this is a higher-order function because it returns a closure
def @trained_cell(%w, %b) {
  fn(%x, %h) { @relu_cell(%w, %b, %x, %h) }
}

按照 Olah 的例子,可以使用以下左折叠对输入序列(列表)进行编码:

def @encode<state_type, in_type, out_type>(%cell, %input : List[in_type], %init : state_type) -> state_type {
  # not using the output
  @foldl(fn(%state, %in) { %cell(%state, %in).0 }, %init, %input)
}

使用 unfold 迭代器(来自 Haskell 的标准库),相同的单元可以用于制作生成器网络(它接受单个输入并产生一系列输出):

# included in Relay's Prelude
def @unfoldr<a, b>(%f : fn(b) -> Optional[(a, b)], %z : b) -> List[a] {
  match(%f(%z)) {
    case Some(%pair) { Cons(%pair.0, @unfoldr(%f, %pair.1)) }
    case None() { Nil() }
  }
}

# we need some way of generating an input to the cell function given only a state
def @gen_func<state_type, in_type, out_type>(%state : state_type) : Optional[(out_type, state_type)] {
  let %in : Optional[in_type] = @generate_input(%state);
  match(%in) {
    case Some(%n) {
      let %cell_out = @cell(%n, %state);
      Some((%cell_out.1, %cell_out.0)) # pair of output and state
    }
    case None() { None() }
  }
}

def @generator<state_type, in_type, out_type>(%cell, %init : state_type) -> List[out_type] {
  @unfoldr(fn(%state) { @gen_func(%cell, %state) }, %init)
}

累积映射(同时更新累加器值和输出列表的折叠)可以用于编写通用 RNN(每个输入都有输出):"

def @map_accumr<a, b, c>(%f : fn(a, b) -> (a, c), %acc : a, %l : List[b]) -> (a, List[c]) {
  match(%l) {
    case Nil() { (%acc, Nil()) }
    case Cons(%b, %t) {
      let %update = %f(%acc, %b);
      let %rest = @map_accumr(%f, %update.0, %t));
      (%rest.0, Cons(%update.1, %rest.1))
    }
  }
}

# can also be implemented as a right fold
# (this version is included in Relay's Prelude)
def @map_accumr_fold(%f, %acc, %l) {
  @foldr(fn(%b, %p) {
    let %f_out = %f(%p.0, %b);
    (%f_out.0, Cons(%f_out.1, %p.1))
  },
  (%acc, Nil()), %l)
}

def @general_rnn<state_type, in_type, out_type>(%cell, %init : state_type, %input : List[in_type])
  -> (state_type, List[out_type]) {
  @map_accumr(%cell, %init, %input)
}

Olah 还给出了双向神经网络的例子,其中两组单元(可能具有不同的权重)在两个方向上处理输入并产生一组输出。以下是该示例的 Relay 实现:

# creates a list of tuples from two lists
# included in Relay's Prelude
def @zip<a, b>(%l : List[a], %m : List[b]) -> List[(a, b)] {
  match(%l) {
    case Nil() { Nil() }
    case Cons(%a, %t1) {
      match(%m) {
        case Nil() { Nil() }
        case Cons(%b, %t2) { Cons((%a, %b), @zip(%t1, %t2)) }
      }
    }
  }
}

# analogous to map_accumr
# included in Relay's Prelude
def @map_accmul(%f, %acc, %l) {
  @foldl(fn(%p, %b){
    let %f_out = %f(%p.0, %b);
    (%f_out.0, Cons(%f_out.1, %p.1))
  }, (%acc, Nil()), %l)
}

def @bidirectional_rnn<state1_type, state2_type, in_type, out1_type, out2_type>
  (%cell1, %cell2, %state1 : state1_type, %state2 : state2_type, %input : List[in_type])
  -> List[(out1_type, out2_type)] {
  @zip(@map_accumr(%cell1, %state1, %input).1, @map_accuml(%cell2, %state2, %input).1)
}