Torch unsqueeze & unsqueeze_

Torch unsqueeze & unsqueeze_

import set_env
import torch
from tools.tvm_utils import verify_model
# Turn off autograd globally: nothing in this script needs gradients.
torch.set_grad_enabled(False)
# Shape passed to torch.rand() below to build the shared test input.
input_shape = [10, 10]
class Unsqueeze1(torch.nn.Module):
    """Module exercising the out-of-place ``unsqueeze`` operator."""

    def forward(self, *args):
        """Return the first positional tensor with a size-1 axis inserted at dim 2.

        The input tensor itself is left unmodified; any further positional
        arguments are ignored.
        """
        tensor = args[0]
        return torch.unsqueeze(tensor, 2)

class Unsqueeze2(torch.nn.Module):
    """Module exercising the in-place ``unsqueeze_`` operator."""

    def forward(self, *args):
        """Unsqueeze the first tensor in place, then squeeze it back and double it.

        The first positional tensor is mutated: after the call it carries an
        extra size-1 axis at dim 2. The returned value is computed from a
        squeezed view, so it has the tensor's pre-call shape.
        """
        tensor = args[0]
        # In-place op under test; its return value is not needed.
        tensor.unsqueeze_(2)
        # Follow-up ops make sure values derived from the mutated tensor
        # still behave correctly.
        restored = torch.squeeze(tensor, 2)
        return torch.add(restored, restored)
# One random float32 tensor is shared by both module checks.
# NOTE(review): Unsqueeze2 mutates its argument in place; whether input_data
# keeps its original shape afterwards depends on verify_model copying its
# inputs -- confirm against tools.tvm_utils.
input_data = torch.rand(input_shape).float()
for module_cls in (Unsqueeze1, Unsqueeze2):
    verify_model(module_cls().float().eval(), input_data=input_data)
import tvm

@torch.jit.script
def fn(x):
    """Unsqueeze ``x`` in place at dim 2 and return the doubled tensor."""
    # Only the mutation of x matters here; the call's return value is dropped.
    x.unsqueeze_(2)
    return x * 2
# Import the scripted function into Relay, declaring one input named
# "input" with shape (5, 5).
input_infos = [('input', [5, 5])]
m, p = tvm.relay.frontend.from_pytorch(fn, input_infos)
# Apply the InferType pass so the main function body has a checked_type.
infer_type = tvm.relay.transform.InferType()
m2 = infer_type(m)
# Print the type Relay inferred for the output, then the shape PyTorch
# actually produces for the same function, so the two can be compared.
print(m2['main'].body.checked_type)
sample = torch.randn(5, 5)
print(fn(sample).shape)
Tensor[(5, 5), float32]
torch.Size([5, 5, 1])