如何定制 TVTensor 类

如何定制 TVTensor 类#

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

将创建简单的类,它继承自 torchvision.tv_tensors.TVTensor 基类。这个类将足以涵盖您需要了解的实现更复杂用例的知识。如果您需要创建携带元数据的类,可以参考 torchvision.tv_tensors.BoundingBoxes 类的实现

class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])

现在已经定义了自定义的TVTensor类,希望它能够与内置的torchvision变换以及功能性API兼容。为此,需要实现内核,该内核执行转换的核心部分,然后通过 register_kernel() 函数将其“挂钩”到我们想要支持的功能上。

下面展示了这个过程:MyTVTensor类的“水平翻转”操作创建内核,并将其注册到功能API中。

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

已经注册了内核,我们可以在 MyTVTensor 实例上调用功能性 API:

my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!

也可以采用 RandomHorizontalFlip 这一变换,因为它内部依赖于 hflip() 函数。

t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!

参数转发和确保你的内核的未来兼容性#

你正在接入的功能API是公开的,因此具有向后兼容:我们保证这些功能的参数不会被移除或重命名,除非经过适当的弃用周期。然而,不保证 forward 兼容,未来可能会添加新的参数。

想象一下,在未来的版本中,Torchvision 为其 hflip() 功能添加了新的 inplace 参数。如果你已经定义并注册了自己的内核,那么

def hflip_my_tv_tensor(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

因此,调用F.hflip(my_dp)将会失败,因为hflip会尝试将新的inplace参数传递给你的核函数,但你的核函数并不接受这个参数。

基于这个原因,我们建议总是以*args, **kwargs在你的核函数签名中定义它们,就像上面所做的那样。这样,你的核函数就能够接受我们将来可能添加的任何新参数。(从技术上讲,只添加**kwargs应该就足够了)。