NumPy: use all rows in multidimensional integer indexing

Time:01-12

I have a multidimensional NumPy array:

import numpy as np
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

   array([[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]])

Let's say I want to access the elements 2, 5, and 9, that is, column 1 of row 0, column 1 of row 1, and column 2 of row 2. I know I can do it with integer indexing:

>>> a[np.arange(3), np.array([1, 1, 2])]
array([2, 5, 9])
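When two integer arrays are used as indices, NumPy pairs them element-wise: entry i of the result is a[rows[i], cols[i]]. A minimal sketch checking that equivalence against an explicit loop (the names `rows`, `cols`, `fancy`, and `looped` are only for illustration):

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
rows = np.arange(3)          # one row index per output element
cols = np.array([1, 1, 2])   # the column to take from each row

# Advanced indexing: element i is a[rows[i], cols[i]]
fancy = a[rows, cols]

# The same selection spelled out with a Python loop over the index pairs
looped = np.array([a[r, c] for r, c in zip(rows, cols)])

print(fancy)  # [2 5 9]
assert np.array_equal(fancy, looped)
```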

But is there a way to select every row without using np.arange? I already know that a[:, np.array([1, 1, 2])] does not work, since it returns

array([[2, 2, 3],
       [5, 5, 6],
       [8, 8, 9]])

There is probably some easy way, but I missed it in the documentation.
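The slice-based attempt fails because a slice combined with an integer array is not paired element-wise: the slice keeps every row, and the array picks columns 1, 1, and 2 from each of those rows, giving a (3, 3) result. The wanted values sit on its main diagonal, as this sketch shows. It only illustrates the behavior and is not a recommended way to index, since it builds the full n-by-n intermediate:

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
cols = np.array([1, 1, 2])

# Slice + integer array: every row (from ':') combined with every entry of cols
outer = a[:, cols]
print(outer.shape)  # (3, 3)

# Row i of outer holds a[i, cols[0]], a[i, cols[1]], a[i, cols[2]],
# so the per-row picks a[i, cols[i]] lie on the main diagonal
print(np.diagonal(outer))  # [2 5 9]
```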

Answer:

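Use np.take_along_axis. Its index array must have the same number of dimensions as a, so turn the column indices into a column vector with [:, None] and squeeze the (3, 1) result back to 1-D: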

np.take_along_axis(a, np.array([1,1,2])[:,None], 1).squeeze() # correction by @hilberts_drinking_problem

Out: array([2, 5, 9])
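The take_along_axis call can be wrapped in a small helper. In this sketch `pick_per_row` is a hypothetical name, and it drops the extra axis with [:, 0] instead of .squeeze(), so a single-row input still returns a 1-D array rather than a 0-d one:

```python
import numpy as np

def pick_per_row(arr, cols):
    """Return arr[i, cols[i]] for every row i (hypothetical helper)."""
    cols = np.asarray(cols)
    # Shape (n,) -> (n, 1) so the indices have the same ndim as arr
    idx = cols[:, None]
    # Gather along axis 1 (result shape (n, 1)), then drop the length-1 axis
    return np.take_along_axis(arr, idx, axis=1)[:, 0]

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(pick_per_row(a, [1, 1, 2]))  # [2 5 9]
```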

Answer:

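This is not a shortcut for selecting every index along a dimension, but if you want to select specific elements, you can pass a list of indices for each dimension. NumPy pairs the lists element-wise: the first entries of the lists give the first element, the second entries give the second, and so on.

For example: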

def some(i):
    return [i, i + 1]

a[some(0), [1, 2]]

returns array([2, 6])
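Because plain Python sequences are converted to index arrays and paired the same way, the built-in range can stand in for np.arange when every row gets one column index. A minimal sketch (the `cols` list is illustrative):

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
cols = [1, 1, 2]

# range(len(a)) yields row indices 0, 1, 2; NumPy pairs them with cols
# element-wise, selecting a[0, 1], a[1, 1], a[2, 2]
picked = a[range(len(a)), cols]
print(picked)  # [2 5 9]
```

This still spells out the row indices explicitly; it only avoids the np.arange call.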
