如何定制 TVTensor 类#
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
将创建简单的类,它继承自 torchvision.tv_tensors.TVTensor
基类。这个类将足以涵盖您需要了解的实现更复杂用例的知识。如果您需要创建携带元数据的类,可以参考 torchvision.tv_tensors.BoundingBoxes
类的实现。
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])
现在已经定义了自定义的TVTensor类,希望它能够与内置的torchvision变换以及功能性API兼容。为此,需要实现内核,该内核执行转换的核心部分,然后通过 register_kernel()
函数将其“挂钩”到我们想要支持的功能上。
下面展示了这个过程:MyTVTensor类的“水平翻转”操作创建内核,并将其注册到功能API中。
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
已经注册了内核,我们可以在 MyTVTensor
实例上调用功能性 API:
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!
也可以采用 RandomHorizontalFlip
这一变换,因为它内部依赖于 hflip()
函数。
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!
参数转发和确保你的内核的未来兼容性#
你正在接入的功能API是公开的,因此具有向后兼容:我们保证这些功能的参数不会被移除或重命名,除非经过适当的弃用周期。然而,不保证 forward
兼容,未来可能会添加新的参数。
想象一下,在未来的版本中,Torchvision 为其 hflip()
功能添加了新的 inplace
参数。如果你已经定义并注册了自己的内核,那么
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
因此,调用F.hflip(my_dp)
将会失败,因为hflip
会尝试将新的inplace
参数传递给你的核函数,但你的核函数并不接受这个参数。
基于这个原因,我们建议总是以*args, **kwargs
在你的核函数签名中定义它们,就像上面所做的那样。这样,你的核函数就能够接受我们将来可能添加的任何新参数。(从技术上讲,只添加**kwargs
应该就足够了)。