TorchScript Language Reference
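TorchScript is a statically typed subset of Python. It can be written directly (using the @torch.jit.script decorator) or generated automatically from Python code via tracing. With tracing, the code is converted into this subset automatically: only the operators actually executed on tensors are recorded, while the surrounding Python code is simply run and discarded.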
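One practical consequence is that tracing bakes in whichever Python branch was taken for the example input, while scripting keeps the control flow. A minimal sketch; the function `f` and the inputs are made up for illustration:

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent Python control flow
    if x.sum() > 0:
        return x * 2
    return x + 10

# Tracing runs f once on the example input, so only the `x * 2` branch is recorded
# (PyTorch emits a TracerWarning about converting a tensor to a Python bool)
traced = torch.jit.trace(f, torch.ones(3))

# Scripting compiles the function body, so the `if` is preserved
scripted = torch.jit.script(f)

neg = -torch.ones(3)
print(traced(neg))    # tensor([-2., -2., -2.]): still the recorded `x * 2` path
print(scripted(neg))  # tensor([9., 9., 9.]): the `x + 10` path is taken
```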
When writing TorchScript directly with the @torch.jit.script decorator, the programmer must restrict themselves to the subset of Python that TorchScript supports.
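As an illustration, exception handling with try/except is, as far as we know, one of the Python constructs outside this subset, so scripting a function that uses it fails at compile time while the same function still runs in eager Python. The function name `safe_div` is hypothetical:

```python
import torch

def safe_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # try/except is ordinary Python, but not part of the TorchScript subset
    try:
        return x / y
    except RuntimeError:
        return torch.zeros_like(x)

compiled = True
try:
    torch.jit.script(safe_div)
except Exception as e:  # the exact exception class may differ between PyTorch versions
    compiled = False
    print(type(e).__name__)

print(compiled)  # False
print(safe_div(torch.ones(2), torch.full((2,), 2.0)))  # eager Python still works
```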
Unlike Python, every variable in a TorchScript function must have a single static type. This makes TorchScript functions easier to optimize.
For example, the following function contains a type mismatch, and compiling it raises the error shown after it:
import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
RuntimeError:
Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
File "/tmp/ipykernel_351805/3025638472.py", line 6
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
File "/tmp/ipykernel_351805/3025638472.py", line 10
    else:
        r = 4
    return r
           ~ <--- HERE
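One way to resolve the mismatch is to make both branches produce a Tensor, so that r has a single static type. A sketch; the name `no_error` and the constant tensor are illustrative choices:

```python
import torch

@torch.jit.script
def no_error(x):
    # x is unannotated, so TorchScript treats it as a Tensor
    if x:
        r = torch.rand(1)
    else:
        # A one-element Tensor instead of the int 4: r is a Tensor in both branches
        r = torch.tensor([4.0])
    return r

print(no_error(torch.tensor(False)))  # tensor([4.])
```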
Default Types
By default, all parameters of a TorchScript function are assumed to be Tensors. To specify that a parameter has a different type, use MyPy-style type annotations.
import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))
tensor([3.5251, 3.9520, 4.4217])
Types can also be annotated with Python 3 type hints from the typing module.
import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))
tensor([4.4262, 4.1869, 3.5746])
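To see the Tensor default at work, one can inspect the compiled signature of an unannotated function and try calling it with a plain Python int. A minimal sketch; `add_one` is a made-up name:

```python
import torch

@torch.jit.script
def add_one(x):
    # No annotation: x is compiled as a Tensor
    return x + 1

print(add_one.code)  # the signature reads `def add_one(x: Tensor) -> Tensor:`
print(add_one(torch.tensor([1.0, 2.0])))  # tensor([2., 3.])

# A Python int does not match the Tensor parameter and is rejected at call time
rejected_int = False
try:
    add_one(3)
except Exception as e:
    rejected_int = True
    print(type(e).__name__, e)
```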
An empty list is assumed to be list[Tensor] and an empty dict dict[str, Tensor]. To instantiate an empty list or dict of another type, use Python 3 type annotations.
For example:
import torch
from torch import nn

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> tuple[list[tuple[int, float]], dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: list[tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())
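The effect of the annotation is visible in the compiled code. A sketch comparing an unannotated empty list with an annotated one; both function names are made up for illustration:

```python
import torch
from typing import List

@torch.jit.script
def default_empty():
    xs = []  # no annotation: TorchScript assumes List[Tensor]
    return xs

@torch.jit.script
def int_list(n: int) -> List[int]:
    ys: List[int] = []  # the annotation is what lets this list hold ints
    for i in range(n):
        ys.append(i)
    return ys

print(default_empty.code)  # the return type is shown as List[Tensor]
print(int_list(3))          # [0, 1, 2]
```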
Optional Type Refinement
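TorchScript refines the type of a variable of type Optional[T] when it is compared to None in the condition of an if statement or checked in an assert. The compiler can reason about multiple None checks combined with and, or, and not. Refinement also applies to the else block of an if statement, even when that else is not written explicitly.
The None check must appear within the condition of the if statement: assigning the result of a None check to a variable and then using that variable in the condition does not refine the types of the variables in the check. Only local variables are refined; an attribute such as self.x is not refined and must first be assigned to a local variable.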
Example of refining the types of parameters and local variables:
import torch
from torch import nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))
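The rules about implicit else blocks and about checks stored in a variable can be seen side by side. A sketch with hypothetical function names; the second function is expected to fail to compile because the None check is not inside the if condition:

```python
import torch
from typing import Optional

@torch.jit.script
def value_or_zero(x: Optional[int]) -> int:
    if x is None:
        return 0
    # Implicit else: after the early return, x is refined to int
    return x + 1

def not_refined(x: Optional[int]) -> int:
    is_none = x is None  # the None check is stored in a variable...
    if is_none:
        return 0
    return x + 1  # ...so x is still Optional[int] here and `+` does not type-check

compiled = True
try:
    torch.jit.script(not_refined)
except Exception as e:
    compiled = False
    print(type(e).__name__)

print(value_or_zero(None), value_or_zero(4))  # 0 5
print(compiled)  # False
```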
More Examples
Scripting a Function
The @torch.jit.script decorator constructs a ScriptFunction by compiling the body of the function.
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
<class 'torch.jit.ScriptFunction'>
def foo(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.max(x), torch.max(y)))
  if _0:
    r = x
  else:
    r = y
  return r
tensor([[1., 1.],
        [1., 1.]])
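Scripting is recursive for functions too: an undecorated Python function called from a scripted function is compiled along with it. A sketch; both function names are made up for illustration:

```python
import torch

def clamp_to_unit(t: torch.Tensor) -> torch.Tensor:
    # An ordinary Python function with no decorator
    return t.clamp(0.0, 1.0)

@torch.jit.script
def scale_and_clamp(x: torch.Tensor, s: float) -> torch.Tensor:
    # The call below makes the compiler script clamp_to_unit as well
    return clamp_to_unit(x * s)

print(scale_and_clamp(torch.tensor([0.2, 0.6]), 2.0))  # tensor([0.4000, 1.0000])
print(scale_and_clamp.code)
```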
Scripting a Function Using example_inputs
Example inputs can be used to annotate a function's arguments.
Note
This requires MonkeyType:
pip install MonkeyType
import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])
print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
scripted_fn(20, 100)
<class 'torch.jit.ScriptFunction'>
def test_sum(a: int,
    b: int) -> int:
  return torch.add(a, b)
120
Scripting a Module
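By default, scripting a Module compiles its forward method and recursively compiles any methods, submodules, and functions called by forward. If a Module uses only features supported in TorchScript, the original module code does not need to change. script constructs a torch.jit.ScriptModule that holds copies of the original module's attributes, parameters, and methods.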
Scripting a simple module with a parameter:
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))
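The result can be checked directly: the scripted module and its submodule are ScriptModule instances, and the parameter values carry over. A sketch that repeats the module above so it runs on its own:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        # (N, M) matrix-vector product with an (M,) input gives (N,); Linear maps N -> M
        return self.linear(self.weight.mv(input))

module = MyModule(2, 3)
scripted_module = torch.jit.script(module)

print(isinstance(scripted_module, torch.jit.ScriptModule))         # True
print(isinstance(scripted_module.linear, torch.jit.ScriptModule))  # True: submodule compiled
print(torch.equal(scripted_module.weight, module.weight))          # True: same values
print(scripted_module(torch.rand(3)).shape)                         # torch.Size([3])
print(scripted_module.code)
```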
Scripting a module with traced submodules:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.jit.trace produces ScriptModules for conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())
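Running the scripted module shows that the traced convolutions are ScriptModule submodules and that conv2 accepts a spatial size (12x12) different from the one it was traced with (16x16). A sketch that repeats the module above so it runs standalone:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())
print(isinstance(scripted_module.conv1, torch.jit.ScriptModule))  # True

# Each unpadded 5x5 convolution shrinks the spatial size by 4: 16 -> 12 -> 8
out = scripted_module(torch.rand(1, 1, 16, 16))
print(out.shape)  # torch.Size([1, 20, 8, 8])
```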
To compile a method other than forward (and recursively compile anything it calls), add the @torch.jit.export decorator to that method. To opt a method out of compilation, use @torch.jit.ignore or @torch.jit.unused.
A module with an exported method and an ignored method:
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
# In training mode (the default), forward would call python_only_fn and stop in pdb;
# switch to eval mode to skip it when running non-interactively
scripted_module.eval()
print(scripted_module(torch.randn(2, 2)))
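@torch.jit.unused differs from @torch.jit.ignore in that the method's body is replaced by a stub that raises an exception if the compiled code ever calls it, so the module can still be scripted as long as that path is not taken. A sketch adapted to avoid interactive debugging; the module and attribute names are hypothetical:

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, use_python_path: bool):
        super().__init__()
        self.use_python_path = use_python_path

    @torch.jit.unused
    def python_path(self, x: torch.Tensor) -> torch.Tensor:
        # Not compiled: try/except would otherwise be rejected by the compiler
        try:
            return x + 10
        except RuntimeError:
            return x

    def forward(self, x):
        if self.use_python_path:
            return self.python_path(x)
        return x + 10

ok = torch.jit.script(MyModule(use_python_path=False))
print(ok(torch.zeros(2)))  # tensor([10., 10.])

bad = torch.jit.script(MyModule(use_python_path=True))
raised = False
try:
    bad(torch.zeros(2))  # reaches the stub for the unused method
except Exception as e:
    raised = True
    print(type(e).__name__)
print(raised)  # True
```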
Annotating forward of an nn.Module using example_inputs:
import torch
import torch.nn as nn
from typing import List, NamedTuple

class MyModule(NamedTuple):
    result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# Runs the pdt_model in eager mode with the inputs provided and annotates the arguments of forward
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# Run the scripted_model with actual inputs
print(scripted_model([20]))
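Without MonkeyType, the annotation that example_inputs would infer for a can be written by hand. A sketch; the names `Result` and `AnnotatedModule` are hypothetical stand-ins for the classes above:

```python
import torch
from typing import List, NamedTuple

class Result(NamedTuple):
    result: List[int]

class AnnotatedModule(torch.nn.Module):
    def forward(self, a: List[int]) -> Result:
        # Same behavior as TestNNModule, with the argument type written out explicitly
        return Result(result=a)

scripted_model = torch.jit.script(AnnotatedModule())
out = scripted_model([20])
print(out)
```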