TorchScript 语言参考#

参考:jit_language_reference_v2

TorchScript 是 Python 的静态类型子集,可以直接编写(使用 @torch.jit.script 装饰器)或通过跟踪从 Python 代码自动生成。当使用跟踪时,通过只记录张量上的实际算子并简单地执行并丢弃其他周围的 Python 代码,代码将自动转换为 Python 的这个子集。

术语#

本文档使用以下术语:

模式

说明

::=

表示给定的符号被定义为。

" "

表示作为语法一部分的实际关键字和分隔符。

A | B

表示 A 或 B。

( )

表示分组。

[]

表示可选。

A+

表示正则表达式,其中术语 A 至少重复一次。

A*

表示正则表达式,其中术语 A 重复零次或多次。

当直接使用 @torch.jit.script 装饰器编写 TorchScript 时,程序员必须只使用 TorchScript 中支持的 Python 子集。

类型系统#

TorchScript 是 Python 的静态类型子集。TorchScript 和完整 Python 语言之间最大的区别在于,TorchScript 仅支持表达神经网络模型所需的一小部分类型。

TorchScript 类型#

TorchScript 类型系统由 TSTypeTSModuleType 组成,定义如下:

TSAllType ::= TSType | TSModuleType
TSType    ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType

TSType 表示 TorchScript 中大多数可组合的类型,这些类型可以在 TorchScript 类型注释中使用。TSType 指的是以下任何一种:

  • 元类型,例如 Any

  • 基本类型,例如 intfloatstr

  • 结构类型,例如 Optional[int]List[MyClass]

  • 名义类型(Python 类),例如 MyClass(用户定义)、torch.tensor(内置)

TSModuleType 表示 torch.nn.Module 及其子类。它与 TSType 的处理方式不同,因为其类型模式部分从对象实例推断,部分从类定义推断。因此,TSModuleType 的实例可能不会遵循相同的静态类型模式。出于类型安全考虑,TSModuleType 不能用作 TorchScript 类型注释,也不能与 TSType 组合。

元类型#

元类型非常抽象,它们更像类型约束而不是具体类型。目前,TorchScript 定义了一种元类型 Any,它表示任何 TorchScript 类型。

Any 类型#

Any 类型表示任何 TorchScript 类型。Any 不指定任何类型约束,因此对 Any 没有类型检查。因此,它可以绑定到任何 Python 或 TorchScript 数据类型(例如,int、TorchScript tuple 或未脚本的任意 Python 类)。

TSMetaType ::= "Any"

其中:

  • Any 是来自 typing 模块的 Python 类名。因此,要使用 Any 类型,必须从 typing 导入它(例如,from typing import Any)。

  • 由于 Any 可以表示任何 TorchScript 类型,因此允许对该类型的值进行操作的运算符集是有限的。

支持 Any 类型的运算符#

  • 分配给 Any 类型的数据。

  • 绑定到 Any 类型的参数或返回值。

  • x isx is not,其中 xAny 类型。

  • isinstance(x, Type),其中 xAny 类型。

  • Any 类型的数据是可打印的。

  • List[Any] 类型的数据在数据是同一类型 T 的值列表且 T 支持比较运算符时可能是可排序的。

与 Python 相比

Any 是 TorchScript 类型系统中最不受约束的类型。从这个意义上说,它与 Python 中的 Object 类非常相似。然而,Any 仅支持 Object 支持的运算符和方法的子集。

设计说明#

当脚本化 PyTorch 模块时,可能会遇到未参与脚本执行的数据。尽管如此,它必须由类型模式描述。不仅为未使用的数据描述静态类型(在脚本的上下文中)很繁琐,而且可能导致不必要的脚本化失败。Any 被引入以描述不需要精确静态类型的数据的类型。

示例 1

此示例说明了如何使用 Any 允许元组参数的第二个元素为任何类型。这是可能的,因为 x[1] 不参与任何需要知道其精确类型的计算。

import torch
from typing import Any

@torch.jit.export
def inc_first_element(x: tuple[int, Any]):
    return (x[0]+1, x[1])

m = torch.jit.script(inc_first_element)
print(m((1,2.0)))
print(m((1,(100,200))))
(2, 2.0)
(2, (100, 200))

元组的第二个元素是 Any 类型,因此可以绑定到多种类型。例如,(1, 2.0) 将浮点类型绑定到 tuple[int, Any] 中的 Any,而 (1, (100, 200)) 在第二次调用中将元组绑定到 Any

示例 2

此示例说明了如何使用 isinstance 动态检查注释为 Any 类型的数据的类型:

import torch
from typing import Any

def f(a:Any):
    print(a)
    return (isinstance(a, torch.Tensor))

ones = torch.ones([2])
m = torch.jit.script(f)
print(m(ones))
 1
 1
[ CPUFloatType{2} ]
True

基本类型#

基本 TorchScript 类型是表示单一类型值的类型,并带有预定义的单一类型名称。

TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"

结构类型#

结构类型是无需用户定义名称(与名义类型不同)结构化定义的类型,例如 Future[int]。结构类型可以与任何 TSType 组合。

TSStructuralType ::=  TSTuple | TSNamedTuple | TSList | TSDict |
                      TSOptional | TSUnion | TSFuture | TSRRef | TSAwait

TSTuple          ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple     ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList           ::= "List" "[" TSType "]"
TSOptional       ::= "Optional" "[" TSType "]"
TSUnion          ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture         ::= "Future" "[" TSType "]"
TSRRef           ::= "RRef" "[" TSType "]"
TSAwait          ::= "Await" "[" TSType "]"
TSDict           ::= "Dict" "[" KeyType "," TSType "]"
KeyType          ::= "str" | "int" | "float" | "bool" | TensorType | "Any"

其中:

  • TupleListOptionalUnionFutureDict 表示在 typing 模块中定义的 Python 类型类名。要使用这些类型名称,必须从 typing 导入它们(例如,from typing import Tuple)。

  • namedtuple 表示 Python 类 collections.namedtupletyping.NamedTuple

  • FutureRRef 表示 Python 类 torch.futurestorch.distributed.rpc

  • Await 表示 Python 类 torch._awaits._Await

与 Python 相比

除了能够与 TorchScript 类型组合外,这些 TorchScript 结构类型通常支持其 Python 对应类型的运算符和方法的公共子集。

示例 1

此示例使用 typing.NamedTuple 语法定义元组:

import torch
from typing import NamedTuple

class MyTuple(NamedTuple):
    first: int
    second: int

def inc(x: MyTuple) -> tuple[int, int]:
    return (x.first+1, x.second+1)

t = MyTuple(first=1, second=2)
scripted_inc = torch.jit.script(inc)
print("TorchScript:", scripted_inc(t))
TorchScript: (2, 3)

示例 2

此示例使用 collections.namedtuple 语法定义元组:

import torch
from typing import NamedTuple
from collections import namedtuple

_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second'])

def inc(x: _AnnotatedNamedTuple) -> tuple[int, int]:
    return (x.first+1, x.second+1)

m = torch.jit.script(inc)
print(inc(_UnannotatedNamedTuple(1,2)))
(2, 3)

示例 3

此示例说明了注释结构类型时的常见错误,即未从 typing 模块导入复合类型类:

from typing import Tuple
import torch

# 错误:Tuple 未被识别,因为未从 typing 导入
@torch.jit.export
def inc(x: Tuple[int, int]):
    return (x[0]+1, x[1]+1)

m = torch.jit.script(inc)
print(m((1,2)))
(2, 3)

名义类型#

名义 TorchScript 类型是 Python 类。这些类型被称为名义类型,因为它们使用自定义名称声明,并使用类名进行比较。名义类进一步分为以下类别:

TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum

其中,TSCustomClassTSEnum 必须可编译为 TorchScript 中间表示(IR)。这是由类型检查器强制执行的。

内置类#

内置名义类型是 Python 类,其语义内置于 TorchScript 系统中(例如,张量类型)。TorchScript 定义了这些内置名义类型的语义,并且通常仅支持其 Python 类定义的方法或属性的子集。

TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
                   "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor       ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
                   "torch.nn.parameter.Parameter" | 以及 torch.Tensor 的子类

关于 torch.nn.ModuleListtorch.nn.ModuleDict 的特别说明#

尽管 torch.nn.ModuleListtorch.nn.ModuleDict 在 Python 中定义为列表和字典,但它们在 TorchScript 中的行为更像元组:

  • 在 TorchScript 中,torch.nn.ModuleListtorch.nn.ModuleDict 的实例是不可变的。

  • 遍历 torch.nn.ModuleListtorch.nn.ModuleDict 的代码是完全展开的,因此 torch.nn.ModuleList 的元素或 torch.nn.ModuleDict 的键可以是 torch.nn.Module 的不同子类。

示例

以下示例突出了一些内置 TorchScript 类(torch.*)的使用:

import torch

@torch.jit.script
class A:
    def __init__(self):
        self.x = torch.rand(3)

    def f(self, y: torch.device):
        return self.x.to(device=y)

def g():
    a = A()
    return a.f(torch.device("cpu"))

script_g = torch.jit.script(g)
print(script_g.graph)

自定义类#

与内置类不同,自定义类的语义由用户定义,整个类定义必须可编译为 TorchScript IR,并受 TorchScript 类型检查规则的约束。

TSClassDef ::= [ "@torch.jit.script" ]
                 "class" ClassName [ "(object)" ]  ":"
                    MethodDefinition |
                [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
                    MethodDefinition

其中:

  • 类必须是新式类。Python 3 仅支持新式类。在 Python 2.x 中,新式类通过从对象子类化来指定。

  • 实例数据属性是静态类型的,实例属性必须在 __init__() 方法中声明。

  • 不支持方法重载(即不能有多个同名方法)。

  • MethodDefinition 必须可编译为 TorchScript IR,并遵守 TorchScript 的类型检查规则(即所有方法必须是有效的 TorchScript 函数,类属性定义必须是有效的 TorchScript 语句)。

  • torch.jit.ignoretorch.jit.unused 可用于忽略不完全可脚本化的方法或函数,或应被编译器忽略的方法或函数。

与 Python 相比

TorchScript 自定义类与其 Python 对应类相比非常有限。TorchScript 自定义类:

  • 不支持类属性。

  • 不支持子类化,除了子类化接口类型或对象。

  • 不支持方法重载。

  • 必须在 __init__() 中初始化所有实例属性;这是因为 TorchScript 通过推断 __init__() 中的属性类型来构造类的静态模式。

  • 必须仅包含满足 TorchScript 类型检查规则并可编译为 TorchScript IR 的方法。

示例 1

Python 类可以通过使用 @torch.jit.script 注释来在 TorchScript 中使用,类似于声明 TorchScript 函数的方式:

import torch
@torch.jit.script
class MyClass:
    def __init__(self, x: int):
        self.x = x

    def inc(self, val: int):
        self.x += val

示例 2

TorchScript 自定义类类型必须在 __init__() 中“声明”所有实例属性。如果实例属性未在 __init__() 中定义但在类的其他方法中访问,则该类不能编译为 TorchScript 类,如下例所示:

import torch

@torch.jit.script
class foo:
    def __init__(self):
        self.y = 1

# 错误:self.x 未在 __init__ 中定义
def assign_x(self):
    self.x = torch.rand(2, 3)

该类将无法编译并发出以下错误:

RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
    self.x = torch.rand(2, 3)
    ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

示例 3

在此示例中,TorchScript 自定义类定义了类变量名称,这是不允许的:

import torch

@torch.jit.script
class MyClass(object):
    name = "MyClass"
    def __init__(self, x: int):
        self.x = x

def fn(a: MyClass):
    return a.name

它会导致以下编译时错误:

RuntimeError:
'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
    File "test-class2.py", line 10
def fn(a: MyClass):
    return a.name
        ~~~~~~ <--- HERE

枚举类型#

与自定义类一样,枚举类型的语义是由用户定义的,整个类定义必须可编译为 TorchScript IR,并遵循 TorchScript 的类型检查规则。


    TSEnumDef ::= "class" 标识符 "(enum.Enum | TSEnumType)" ":"
                   ( 成员标识符 "="  )+
                   ( 方法定义 )*

其中:

  • 值必须是类型为 intfloatstr 的 TorchScript 字面量,并且必须是相同的 TorchScript 类型。

  • TSEnumType 是 TorchScript 枚举类型的名称。与 Python 枚举类似,TorchScript 允许受限的 Enum 子类化,即只有当枚举不定义任何成员时才允许子类化。

与 Python 相比

  • TorchScript 仅支持 enum.Enum。它不支持其他变体,如 enum.IntEnumenum.Flagenum.IntFlagenum.auto

  • TorchScript 枚举成员的值必须是相同类型,并且只能是 intfloatstr 类型,而 Python 枚举成员可以是任何类型。

  • 包含方法的枚举在 TorchScript 中被忽略。

示例 1

以下示例将类 Color 定义为 Enum 类型:


    import torch
    from enum import Enum

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_fn(x: Color, y: Color) -> bool:
        if x == Color.RED:
            return True
        return x == y

    m = torch.jit.script(enum_fn)

    print("Eager: ", enum_fn(Color.RED, Color.GREEN))
    print("TorchScript: ", m(Color.RED, Color.GREEN))

示例 2

以下示例展示了受限枚举子类化的情况,其中 BaseColor 没有定义任何成员,因此可以被 Color 子类化:


    import torch
    from enum import Enum

    class BaseColor(Enum):
        def foo(self):
            pass

    class Color(BaseColor):
        RED = 1
        GREEN = 2

    def enum_fn(x: Color, y: Color) -> bool:
        if x == Color.RED:
            return True
        return x == y

    m = torch.jit.script(enum_fn)

    print("TorchScript: ", m(Color.RED, Color.GREEN))
    print("Eager: ", enum_fn(Color.RED, Color.GREEN))

TorchScript 模块类#

TSModuleType 是一种特殊的类类型,它从在 TorchScript 外部创建的对象实例中推断出来。TSModuleType 的名称由对象实例的 Python 类命名。Python 类的 __init__() 方法不被视为 TorchScript 方法,因此它不必遵守 TorchScript 的类型检查规则。

模块实例类的类型模式直接从实例对象(在 TorchScript 范围之外创建)构建,而不是像自定义类那样从 __init__() 推断。同一实例类类型的两个对象可能遵循两种不同的类型模式。

从这个意义上说,TSModuleType 并不是真正的静态类型。因此,出于类型安全考虑,TSModuleType 不能用于 TorchScript 类型注解,也不能与 TSType 组合。

模块实例类#

TorchScript 模块类型表示用户定义的 PyTorch 模块实例的类型模式。当脚本化一个 PyTorch 模块时,模块对象总是在 TorchScript 外部创建(即作为参数传递给 forward)。Python 模块类被视为模块实例类,因此 Python 模块类的 __init__() 方法不受 TorchScript 类型检查规则的约束。


    TSModuleType ::= "class" 标识符 "(torch.nn.Module)" ":"
                        类体定义

其中:

  • forward() 和其他用 @torch.jit.export 装饰的方法必须可编译为 TorchScript IR,并受 TorchScript 的类型检查规则约束。

与自定义类不同,只有模块类型的 forward 方法和其他用 @torch.jit.export 装饰的方法需要可编译。最值得注意的是,__init__() 不被视为 TorchScript 方法。因此,模块类型构造函数不能在 TorchScript 范围内调用。相反,TorchScript 模块对象总是在外部构造并传递给 torch.jit.script(ModuleObj)

示例 1

此示例说明了一些模块类型的特性:

  • TestModule 实例在 TorchScript 范围之外创建(即在调用 torch.jit.script 之前)。

  • __init__() 不被视为 TorchScript 方法,因此它不必注解并且可以包含任意 Python 代码。此外,实例类的 __init__() 方法不能在 TorchScript 代码中调用。因为 TestModule 实例在 Python 中实例化,在这个示例中,TestModule(2.0)TestModule(2) 创建了两个具有不同类型数据属性的实例。self.x 对于 TestModule(2.0)float 类型,而 self.y 对于 TestModule(2.0)int 类型。

  • TorchScript 自动编译其他方法(例如 mul()),这些方法由通过 @torch.jit.exportforward() 方法注解的方法调用。

  • TorchScript 程序的入口点是模块类型的 forward()、注解为 torch.jit.script 的函数或注解为 torch.jit.export 的方法。

    import torch

    class TestModule(torch.nn.Module):
        def __init__(self, v):
            super().__init__()
            self.x = v

        def forward(self, inc: int):
            return self.x + inc

    m = torch.jit.script(TestModule(1))
    print(f"First instance: {m(3)}")

    m = torch.jit.script(TestModule(torch.ones([5])))
    print(f"Second instance: {m(3)}")

上面的示例产生以下输出:

    First instance: 4
    Second instance: tensor([4., 4., 4., 4., 4.])

示例 2

以下示例展示了模块类型的错误用法。具体来说,此示例在 TorchScript 范围内调用了 TestModule 的构造函数:


    import torch

    class TestModule(torch.nn.Module):
        def __init__(self, v):
            super().__init__()
            self.x = v

        def forward(self, x: int):
            return self.x + x

    class MyModel:
        def __init__(self, v: int):
            self.val = v

        @torch.jit.export
        def doSomething(self, val: int) -> int:
            # 错误:不应在 TorchScript 范围内调用模块类型的构造函数
            myModel = TestModule(self.val)
            return myModel(val)

    # m = torch.jit.script(MyModel(2)) # 结果是以下 RuntimeError
    # RuntimeError: Could not get name of python class object

类型注解#

由于 TorchScript 是静态类型的,程序员需要在 TorchScript 代码的 关键点 进行类型注解,以便每个局部变量或实例数据属性都有静态类型,并且每个函数和方法都有静态类型的签名。

何时注解类型#

一般来说,类型注解仅在静态类型无法自动推断的地方(例如,参数或有时是方法或函数的返回类型)才需要。局部变量和数据属性的类型通常从它们的赋值语句中自动推断。有时,推断的类型可能过于严格,例如,通过赋值 x = Nonex 被推断为 NoneType,而 x 实际上被用作 Optional。在这种情况下,可能需要类型注解来覆盖自动推断,例如 x: Optional[int] = None。请注意,即使局部变量或数据属性的类型可以自动推断,也始终可以安全地进行类型注解。注解的类型必须与 TorchScript 的类型检查一致。

当参数、局部变量或数据属性未进行类型注解且其类型无法自动推断时,TorchScript 假设其为默认类型 TensorTypeList[TensorType]Dict[str, TensorType]

注解函数签名#

由于参数可能无法从函数体(包括函数和方法)中自动推断,因此需要进行类型注解。否则,它们将假定为默认类型 TensorType

TorchScript 支持两种方法和函数签名类型注解的风格:

  • Python3 风格 直接在签名上注解类型。因此,它允许个别参数不注解(其类型将是默认类型 TensorType),或允许返回类型不注解(其类型将自动推断)。


    Python3Annotation ::= "def" 标识符 [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
                                FuncOrMethodBody
    ParamAnnot        ::= 标识符 [ ":" TSType ] ","
    ReturnAnnot       ::= "->" TSType

请注意,使用 Python3 风格时,类型 self 是自动推断的,不应注解。

  • Mypy 风格 在函数/方法声明下方作为注释注解类型。在 Mypy 风格中,由于参数名称不出现在注解中,所有参数都必须注解。


    MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
    ParamAnnot     ::= TSType ","
    ReturnAnnot    ::= "->" TSType

示例 1

在此示例中:

  • a 未注解,假定为默认类型 TensorType

  • b 注解为类型 int

  • 返回类型未注解,自动推断为类型 TensorType(基于返回值的类型)。

import torch

def f(a, b: int):
    return a+b

m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
TorchScript: tensor([101., 101., 101., 101., 101., 101.])

示例 2

以下示例使用 Mypy 风格注解。请注意,即使某些参数假定默认类型,参数或返回值也必须注解。

import torch

def f(a, b):
    # type: (torch.Tensor, int) → torch.Tensor
    return a+b

m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))

注解变量和数据属性#

一般来说,数据属性(包括类和实例数据属性)和局部变量的类型可以从赋值语句中自动推断。然而,有时如果变量或属性与不同类型的值相关联(例如,作为 NoneTensorType),则可能需要显式注解为更宽泛的类型,如 Optional[int]Any

局部变量#

局部变量可以根据 Python3 类型模块注解规则进行注解,即:


    LocalVarAnnotation ::= 标识符 [":" TSType] "=" Expr

一般来说,局部变量的类型可以自动推断。然而,在某些情况下,您可能需要为可能与不同具体类型相关联的局部变量注解多类型。典型的多类型包括 Optional[T]Any

示例

import torch

def f(a, setVal: bool):
    value: torch.Tensor|None = None
    if setVal:
        value = a
    return value

ones = torch.ones([6])
m = torch.jit.script(f)
print("TorchScript:", m(ones, True), m(ones, False))
TorchScript: tensor([1., 1., 1., 1., 1., 1.]) None

实例数据属性#

对于 ModuleType 类,实例数据属性可以根据 Python3 类型模块注解规则进行注解。实例数据属性可以通过 Final 可选地注解为最终属性。


    "class" ClassIdentifier "(torch.nn.Module):"
    InstanceAttrIdentifier ":" ["Final("] TSType [")"]
    ...

其中:

  • InstanceAttrIdentifier 是实例属性的名称。

  • Final 表示该属性不能在 __init__ 外部重新赋值或在子类中覆盖。

示例


    import torch

    class MyModule(torch.nn.Module):
        offset_: int

    def __init__(self, offset):
        self.offset_ = offset

    ...

类型注解 API#

torch.jit.annotate(T, expr)#

此 API 将类型 T 注解到表达式 expr。这通常用于当表达式的默认类型不是程序员预期的类型时。例如,空列表(字典)的默认类型是 List[TensorType]Dict[TensorType, TensorType]),但有时它可能用于初始化其他类型的列表。另一个常见用例是为 tensor.tolist() 的返回类型进行注解。然而,它不能用于注解 __init__ 中的模块属性;应改用 torch.jit.Attribute

示例

在此示例中,[] 通过 torch.jit.annotate 声明为整数列表(而不是假定 [] 为默认类型 List[TensorType]):

import torch
from typing import List

def g(l: List[int], val: int):
    l.append(val)
    return l

def f(val: int):
    l = g(torch.jit.annotate(List[int], []), val)
    return l

m = torch.jit.script(f)
print("Eager:", f(3))
print("TorchScript:", m(3))
Eager: [3]
TorchScript: [3]

更多信息请参见 torch.jit.annotate()

类型注解附录#

TorchScript 类型系统定义#


    TSAllType       ::= TSType | TSModuleType
    TSType          ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType

    TSMetaType      ::= "Any"
    TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"

    TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
                         TSUnion | TSFuture | TSRRef | TSAwait
    TSTuple         ::= "Tuple" "[" (TSType ",")* TSType "]"
    TSNamedTuple    ::= "namedtuple" "(" (TSType ",")* TSType ")"
    TSList          ::= "List" "[" TSType "]"
    TSOptional      ::= "Optional" "[" TSType "]"
    TSUnion         ::= "Union" "[" (TSType ",")* TSType "]"
    TSFuture        ::= "Future" "[" TSType "]"
    TSRRef          ::= "RRef" "[" TSType "]"
    TSAwait         ::= "Await" "[" TSType "]"
    TSDict          ::= "Dict" "[" KeyType "," TSType "]"
    KeyType         ::= "str" | "int" | "float" | "bool" | TensorType | "Any"

    TSNominalType   ::= TSBuiltinClasses | TSCustomClass | TSEnum
    TSBuiltinClass  ::= TSTensor | "torch.device" | "torch.stream"|
                        "torch.dtype" | "torch.nn.ModuleList" |
                        "torch.nn.ModuleDict" | ...
    TSTensor        ::= "torch.tensor" 及其子类

不支持的类型构造#

TorchScript 不支持 Python3 typing <https://docs.python.org/3/library/typing.html#module-typing>_ 模块的所有功能和类型。本文档中未明确指定的 typing <https://docs.python.org/3/library/typing.html#module-typing>_ 模块的任何功能都不受支持。下表总结了在 TorchScript 中不受支持或受限制的 typing 构造。

项目

描述

typing.Any

开发中

typing.NoReturn

不支持

typing.Callable

不支持

typing.Literal

不支持

typing.ClassVar

不支持

typing.Final

支持模块属性、类属性和注解,但不支持函数。

typing.AnyStr

不支持

typing.overload

开发中

类型别名

不支持

名义类型

开发中

结构类型

不支持

NewType

不支持

泛型

不支持

torch.* API#

远程过程调用#

TorchScript 支持一部分 RPC API,这些 API 支持在指定的远程工作线程上运行函数,而不是在本地运行。

具体来说,以下 API 完全支持:

  • torch.distributed.rpc.rpc_sync()

    • rpc_sync() 发起一个阻塞的 RPC 调用,在远程工作线程上运行一个函数。RPC 消息的发送和接收与 Python 代码的执行并行进行。

    • 更多关于其用法和示例的详细信息可以在 rpc_sync() 中找到。

  • torch.distributed.rpc.rpc_async()

    • rpc_async() 发起一个非阻塞的 RPC 调用,在远程工作线程上运行一个函数。RPC 消息的发送和接收与 Python 代码的执行并行进行。

    • 更多关于其用法和示例的详细信息可以在 rpc_async() 中找到。

  • torch.distributed.rpc.remote()

    • remote() 在远程工作线程上执行一个远程调用,并获得一个远程引用 RRef 作为返回值。

    • 更多关于其用法和示例的详细信息可以在 remote() 中找到。

异步执行#

TorchScript 使您能够创建异步计算任务,以更好地利用计算资源。这是通过支持一系列仅在 TorchScript 中可用的 API 来实现的:

  • torch.jit.fork()

    • 创建一个异步任务执行 func,并返回对该任务结果值的引用。Fork 将立即返回。

    • torch.jit._fork() 同义,后者仅出于向后兼容的原因保留。

    • 更多关于其用法和示例的详细信息可以在 fork() 中找到。

  • torch.jit.wait()

    • 强制完成 torch.jit.Future[T] 异步任务,并返回任务的结果。

    • torch.jit._wait() 同义,后者仅出于向后兼容的原因保留。

    • 更多关于其用法和示例的详细信息可以在 wait() 中找到。

类型注解#

TorchScript 是静态类型的。它提供并支持一组实用工具来帮助注解变量和属性:

  • torch.jit.annotate()

    • 在 Python 3 风格类型提示无法正常工作的地方提供类型提示给 TorchScript。

    • 一个常见的例子是为 [] 这样的表达式注解类型。[] 默认被视为 List[torch.Tensor]。当需要不同类型时,可以使用此代码向 TorchScript 提示:torch.jit.annotate(List[int], [])

    • 更多详细信息可以在 :meth:~torch.jit.annotate 中找到。

  • torch.jit.Attribute

    • 常见的用例包括为 torch.nn.Module 属性提供类型提示。因为它们的 __init__ 方法不会被 TorchScript 解析,所以在模块的 __init__ 方法中应使用 torch.jit.Attribute 而不是 torch.jit.annotate

    • 更多详细信息可以在 Attribute 中找到。

  • torch.jit.Final

    • 是 Python 的 typing.Final 的别名。torch.jit.Final 仅出于向后兼容的原因保留。

元编程#

TorchScript 提供了一组实用工具来促进元编程:

  • torch.jit.is_scripting()

    • 返回布尔值,指示当前程序是否由 torch.jit.script 编译。

    • 当在 assertif 语句中使用时,torch.jit.is_scripting() 评估为 False 的作用域或分支不会被编译。

    • 其值可以在编译时静态评估,因此通常在 if 语句中使用,以阻止 TorchScript 编译其中一个分支。

    • 更多详细信息和示例可以在 is_scripting() 中找到。

  • torch.jit.is_tracing()

    • 返回布尔值,指示当前程序是否由 torch.jit.trace / torch.jit.trace_module 跟踪。

    • 更多详细信息可以在 is_tracing() 中找到。

  • @torch.jit.ignore

    • 此装饰器指示编译器应忽略该函数或方法,并将其保留为 Python 函数。

    • 这允许您在模型中保留尚未与 TorchScript 兼容的代码。

    • 如果从 TorchScript 调用由 @torch.jit.ignore 装饰的函数,忽略的函数将把调用分派给 Python 解释器。

    • 带有忽略函数的模型无法导出。

    • 更多详细信息和示例可以在 ignore() 中找到。

  • @torch.jit.unused

    • 此装饰器指示编译器应忽略该函数或方法,并用引发异常替换。

    • 这允许您在模型中保留尚未与 TorchScript 兼容的代码,并仍然导出您的模型。

    • 如果从 TorchScript 调用由 @torch.jit.unused 装饰的函数,将引发运行时错误。

    • 更多详细信息和示例可以在 unused() 中找到。

类型细化#

  • torch.jit.isinstance()

    • 返回一个布尔值,指示变量是否为指定类型。

    • 更多关于其用法和示例的详细信息可以在 isinstance() 中找到。