How to slice a tensor with values along an axis in a numpy-optimized way

Time:02-04

I need to slice a rank-3 tensor, selecting a different index along the second dimension for each entry of the first dimension.
I need to do this for arbitrary rank-3 tensors, but as an example, consider the following (3,4,3) tensor A

[[[ 0  1  2]
[ 3  4  5]
[ 6  7  8]
[ 9 10 11]]

[[12 13 14]
[15 16 17]
[18 19 20]
[21 22 23]]

[[24 25 26]
[27 28 29]
[30 31 32]
[33 34 35]]]

and the following index list indices for the second dimension

[1,2,3]

I then want to obtain the following (3,3) tensor out

[[ 3  4  5]
[18 19 20]
[33 34 35]]

I know how to write it with a for loop:

import numpy as np

out = []
for i, ind in enumerate(indices):
    # pick row `ind` from the i-th (4,3) slice of A
    out.append(A[i, ind, :])
out = np.array(out)

But I was wondering whether there is a more efficient way to write such a function using only vectorized numpy operations, without a Python loop.

CodePudding user response:

In addition to providing the indices for the second axis you also need to provide the indices for the first axis. You describe your task as follows:

  • In row 0 select column 1
  • In row 1 select column 2
  • In row 2 select column 3

You already have the column indices [1, 2, 3], but you also need to provide the row indices [0, 1, 2]:

result = A[np.arange(len(A)), [1,2,3], ...]
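
To check this against the example from the question, here is a minimal self-contained sketch. It assumes A can be built with np.arange(36).reshape(3, 4, 3), which reproduces the tensor shown above. The names loop_out, fancy_out and take_out are only illustrative, and the np.take_along_axis variant is an alternative added here for comparison, not part of the original answer.

```python
import numpy as np

# Rebuild the example tensor from the question (assumed construction).
A = np.arange(36).reshape(3, 4, 3)
indices = np.array([1, 2, 3])

# Reference result: the explicit loop from the question.
loop_out = np.array([A[i, ind, :] for i, ind in enumerate(indices)])

# Advanced indexing: pair each first-axis index with its second-axis index.
# The two index arrays are broadcast together; the trailing axis is kept whole.
fancy_out = A[np.arange(A.shape[0]), indices, :]

# Alternative: take_along_axis needs an index array with the same number of
# dimensions as A, so reshape indices to (3, 1, 1) and drop the length-1 axis.
take_out = np.take_along_axis(A, indices[:, None, None], axis=1)[:, 0, :]

print(fancy_out)
```

All three variants should give the same (3,3) array as the out shown in the question. Using A.shape[0] instead of len(A) is equivalent here; it only makes the axis being enumerated explicit.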