I have 6 tensors of shape (batch_size, S, S, 1) and I want to combine them into one Python list of shape (batch_size, S*S, 6), so that each innermost list holds the corresponding element from every one of the six tensors.
Can this be done without loops? What's an efficient way to do it?
Answer:
Let batch_size=10 and S=4 for the purpose of this example:
>>> import torch
>>> x = [torch.rand(10, 4, 4, 1) for _ in range(6)]
The first step is to concatenate the tensors along the last dimension (axis=3):
>>> y = torch.cat(x, -1)
>>> y.shape
torch.Size([10, 4, 4, 6])
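To confirm what the concatenation does, here is a small self-contained check using the same example sizes. It also shows an equivalent formulation with torch.stack (squeeze the trailing singleton axis, then stack along a new last axis). This is a sketch for illustration, not something the answer requires.

```python
import torch

# Same setup as above: six tensors of shape (batch_size, S, S, 1)
batch_size, S = 10, 4
x = [torch.rand(batch_size, S, S, 1) for _ in range(6)]

# Concatenate along the last axis: shape (batch_size, S, S, 6)
y = torch.cat(x, -1)

# Equivalent: drop each tensor's trailing size-1 axis, then stack the
# six (batch_size, S, S) tensors along a new last axis
y_stacked = torch.stack([t.squeeze(-1) for t in x], dim=-1)

print(y.shape)                    # torch.Size([10, 4, 4, 6])
print(torch.equal(y, y_stacked))  # True

# Channel k of the result holds the values of input tensor k
print(all(torch.equal(y[..., k], x[k][..., 0]) for k in range(6)))  # True
```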
Then merge axis=1 and axis=2 into a single axis. Since the two axes are adjacent, flatten can do this directly:
>>> y = torch.cat(x, -1).flatten(1, 2)
>>> y.shape
torch.Size([10, 16, 6])
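The result is still a tensor. Since the question asks for a Python list, Tensor.tolist() can convert it into nested lists of floats. A minimal sketch wrapping both steps, where combine_to_list is a hypothetical helper name:

```python
import torch

def combine_to_list(tensors):
    """Hypothetical helper: merge N tensors of shape (B, S, S, 1)
    into a nested Python list of shape (B, S*S, N)."""
    y = torch.cat(tensors, dim=-1)  # (B, S, S, N): one channel per input tensor
    y = y.flatten(1, 2)             # (B, S*S, N): merge the two adjacent spatial axes
    return y.tolist()               # nested Python lists of floats

batch_size, S = 10, 4
x = [torch.rand(batch_size, S, S, 1) for _ in range(6)]
out = combine_to_list(x)

print(len(out), len(out[0]), len(out[0][0]))  # 10 16 6

# flatten uses row-major order, so row i*S + j of each batch entry
# gathers position (i, j) from all six tensors
b, i, j = 3, 2, 1
print(out[b][i * S + j] == [t[b, i, j, 0].item() for t in x])  # True
```

Only the final tolist() call leaves the tensor world; everything before it runs as vectorized tensor operations with no Python loops.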
