异步任务

异步任务#

torch.jit.fork()#

torch.jit.fork() 创建异步任务执行函数,并获取该执行结果的引用值。

import torch
torch.jit.fork
<function torch.jit._async.fork(func, *args, **kwargs)>

fork() 会立即返回,因此 func 的返回值可能尚未计算完成。要强制完成任务并访问返回值,请在 Future 上调用 torch.jit.wait()fork() 调用时,如果 func 返回类型为 T,则其类型为 torch.jit.Future[T]fork() 调用可以任意嵌套,并且可以接受位置参数和关键字参数。异步执行仅在 TorchScript 中运行时才会发生。如果在纯 Python 中运行,fork() 不会并行执行。在跟踪过程中调用 fork() 时,也不会并行执行,但 fork()torch.jit.wait() 调用将被捕获在导出的 IR 图中。

警告

fork() 任务将以非确定性方式执行。建议仅对不修改其输入、模块属性或全局状态的纯函数生成并行 fork() 任务。

fork 自由函数:

import torch
from torch import Tensor

def foo(a : Tensor, b : int) -> Tensor:
    return a + b

def bar(a):
    fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)
script_bar = torch.jit.script(bar)
x = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(x) == bar(x)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (x,)).graph
assert "fork" in str(graph)

fork 模块:

import torch
from torch import Tensor
class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b : int):
        return a + b
class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mod = AddMod()
    def forward(self, a):
        fut = torch.jit.fork(self.mod, a, b=2)
        return torch.jit.wait(fut)
x = torch.tensor(2)
mod = Mod()
assert mod(x) == torch.jit.script(mod).forward(x)