TL;DR: Given two tensors t1 and t2, each holding b samples of shape (c, h, w) (i.e., each tensor has shape (b, c, h, w)), I'm trying to efficiently compute the pairwise distance between t1[i] and t2[j] for all i, j.
Some more context: I've extracted ResNet18 activations for both my train and test data (CIFAR10) and I'm trying to implement k-nearest neighbours. Possible pseudo-code:
for te in test_activations:
    distances = []
    for tr in train_activations:
        distances.append(||te - tr||)
    neighbors = k_smallest_elements(distances)
    prediction(te) = majority_vote(labels(neighbors))
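For reference, the pseudo-code above could be written as a runnable (but slow) baseline along these lines. This is only a sketch: the function name knn_predict_loop, the label tensor train_labels, and the tiny shapes are assumptions for illustration and are not part of the original setup.

```python
import torch

def knn_predict_loop(test_acts, train_acts, train_labels, k):
    """Brute-force k-NN mirroring the pseudo-code (hypothetical helper)."""
    preds = []
    for te in test_acts:
        # Euclidean distance between this test sample and every train sample
        distances = torch.stack([(te - tr).norm() for tr in train_acts])
        # indices of the k smallest distances
        neighbors = distances.topk(k, largest=False).indices
        # majority vote over the neighbours' labels
        preds.append(train_labels[neighbors].mode().values)
    return torch.stack(preds)

torch.manual_seed(0)
train_acts = torch.randn(20, 4, 3, 3)        # small stand-in for (b, c, h, w)
train_labels = torch.randint(0, 10, (20,))   # assumed integer class labels
test_acts = train_acts[:5].clone()           # copies, so each nearest neighbour is known
preds = knn_predict_loop(test_acts, train_acts, train_labels, k=1)
```

Because the test samples here are copies of the first five train samples, with k=1 each prediction should simply be that sample's own label, which makes the baseline easy to sanity-check before vectorising it.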
I'm trying to vectorise this process given batches from the test and train activation datasets. I've tried iterating over the batches (rather than the samples) and using torch.cdist(train_batch, test_batch), but I'm not sure how this function handles multi-dimensional tensors, as the documentation states:
torch.cdist(x1, x2, ...):
If x1 has shape B×P×M and x2 has shape B×R×M then the output will have shape B×P×R
This doesn't seem to cover my case (see below).
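To make the documented shape rule concrete, here is a small sketch with 3-D inputs. The sizes B, P, R, M are arbitrary values chosen for illustration.

```python
import torch

B, P, R, M = 2, 5, 7, 3
x1 = torch.randn(B, P, M)
x2 = torch.randn(B, R, M)

out = torch.cdist(x1, x2)   # p=2 (Euclidean distance) by default
print(out.shape)            # torch.Size([2, 5, 7])

# out[b, i, j] is the distance between row i of x1[b] and row j of x2[b]
manual = (x1[0, 1] - x2[0, 4]).norm()
matches = torch.allclose(out[0, 1, 4], manual, atol=1e-5)
```

So the last dimension is treated as the feature vector, and distances are only computed between rows that share the same leading batch index.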
A minimal example:
import torch

b, c, h, w = 1000, 128, 28, 28  # actual dimensions in my problem
train_batch = torch.randn(b, c, h, w)
test_batch = torch.randn(b, c, h, w)
d = torch.cdist(train_batch, test_batch)
You can think of test_batch and train_batch as the tensors in the nested loop for test_batch in test: for train_batch in train: ..., and the expected output would have shape (b,).
Answer:
It is common to have to reshape your data before feeding it to a built-in PyTorch operator. As you've said, torch.cdist works with two inputs shaped (B, P, M) and (B, R, M) and returns a tensor shaped (B, P, R).
Instead, you have two tensors shaped the same way: (b, c, h, w). If we match those dimensions we have: B=b, M=c, while P=h*w (from the 1st tensor) and R=h*w (from the 2nd tensor). This requires flattening the spatial dimensions together, which gives (b, c, h*w), and then swapping the last two axes to get (b, h*w, c). Something like:
>>> x1 = train_batch.flatten(2).transpose(1,2)
>>> x2 = test_batch.flatten(2).transpose(1,2)
>>> d = torch.cdist(x1, x2)
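Now d contains the distances between all possible pairs (train_batch[b, :, iy, ix], test_batch[b, :, jy, jx]) and is shaped (b, h*w, h*w), where spatial position (iy, ix) corresponds to flattened index iy*w + ix. Note that each pair shares the same batch index b.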
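The indexing can be checked on a small example. This is a sketch with reduced, made-up dimensions; the sample index n and the positions (iy, ix), (jy, jx) are arbitrary choices for the check.

```python
import torch

b, c, h, w = 2, 8, 4, 4   # small stand-in dimensions
train_batch = torch.randn(b, c, h, w)
test_batch = torch.randn(b, c, h, w)

x1 = train_batch.flatten(2).transpose(1, 2)   # (b, h*w, c)
x2 = test_batch.flatten(2).transpose(1, 2)    # (b, h*w, c)
d = torch.cdist(x1, x2)                       # (b, h*w, h*w)

# flatten(2) merges (h, w) in row-major order, so (y, x) maps to y * w + x
n, iy, ix, jy, jx = 1, 2, 3, 0, 1
expected = (train_batch[n, :, iy, ix] - test_batch[n, :, jy, jx]).norm()
entry = d[n, iy * w + ix, jy * w + jx]
```

Here entry and expected should agree up to floating-point error. Each distance compares two c-dimensional vectors taken from spatial positions of train_batch[n] and test_batch[n], i.e. from the same index n in both batches.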
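You can then apply k-NN by selecting the k smallest distances along the last dimension, for example with torch.topk(d, k, largest=False) (or argmin for a single nearest neighbour), to retrieve the k closest neighbours in the test batch for each element of the training batch. argmax would return the farthest element rather than the nearest.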
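If the goal is instead one distance per (test sample, train sample) pair, as in the question's pseudo-code, one possible variation is to flatten each sample into a single vector of length c*h*w, so that torch.cdist returns an (n_test, n_train) matrix. The sketch below assumes that reading of the question; the helper name knn_predict and the label tensor train_labels are hypothetical.

```python
import torch

def knn_predict(test_batch, train_batch, train_labels, k):
    """Vectorised k-NN over whole samples (hypothetical helper)."""
    te = test_batch.flatten(1)                      # (n_test, c*h*w)
    tr = train_batch.flatten(1)                     # (n_train, c*h*w)
    d = torch.cdist(te, tr)                         # (n_test, n_train)
    idx = d.topk(k, dim=1, largest=False).indices   # k nearest train samples
    votes = train_labels[idx]                       # (n_test, k) neighbour labels
    return votes.mode(dim=1).values                 # majority vote per test sample

torch.manual_seed(0)
train_batch = torch.randn(20, 4, 3, 3)       # small stand-in for (b, c, h, w)
train_labels = torch.randint(0, 10, (20,))   # assumed integer class labels
test_batch = train_batch[:5].clone()         # copies, so nearest neighbours are known
preds = knn_predict(test_batch, train_batch, train_labels, k=1)
```

With the question's real dimensions, each flattened batch of 1000 samples holds about 100 million float32 values (roughly 400 MB), so in practice it may be necessary to iterate over train batches and keep a running top-k per test sample rather than computing all distances at once.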
