I am trying to use the np.where() function with nested lists. I would like to find the index, in the outer list, of the element that satisfies a given condition.
For example, if I have the following code:
arr = [[1,1], [2,2],[3,3]]
a = np.where(arr == [2,2])
then ideally I would like the code to return 'a' as 1, since [2,2] is at index 1 of the nested list.
However, I am just getting an empty array back as a result.
Of course, I can make it work easily with an explicit for loop such as
for n in range(len(arr)):
    if arr[n] == [2,2]:
        a = n
but I would like to do this with a single expression inside the function, i.e. np.where(write the entire code here).
Is there a way to do this?
CodePudding user response:
The best solution is the one mentioned by @Michael Szczesny, but you can also do this with np.where:
a = np.where(np.array(arr) == [2, 2])[0]
resulted_ind = np.where(np.bincount(a) == 2)[0] # --> [1]
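The 2 in the second line above is the number of columns: a row matches only when every one of its elements matches. A minimal sketch that generalizes this, assuming a 2D integer array; the names target, rows, counts and matching_rows are illustrative and not part of the original answer:

```python
import numpy as np

arr = np.array([[1, 1], [2, 2], [3, 3]])
target = [2, 2]

# Row index of every element that equals the target at its column position.
rows = np.where(arr == target)[0]

# Count the matching elements per row; minlength keeps one slot per row.
counts = np.bincount(rows, minlength=arr.shape[0])

# A row matches only if all of its columns matched.
matching_rows = np.where(counts == arr.shape[1])[0]
print(matching_rows)  # [1]
```

Using arr.shape[1] instead of a hard-coded 2 keeps the check correct if the rows have a different length.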
CodePudding user response:
Well, you can write your own function to do so.
You'll need to:
- Find every row equal to the one you are looking for
- Get the indices of the rows found (you can use where)
numpy comparison
You can use a comparison operator to see whether each element satisfies the condition, such as:
import numpy as np

np_arr = np.array([1, 2, 3, 4, 5])
print(np_arr < 3)
This returns a boolean array in which each element is True where the condition is satisfied and False otherwise:
[ True True False False False]
For a 2D array you'll get a 2D boolean array:
to_find = np.array([2, 2])
np_arr = np.array(
    [
        [1, 1],
        [2, 2],
        [3, 3],
        [2, 2]
    ]
)
print(np_arr == to_find)
The result is:
[[False False]
[ True True]
[False False]
[ True True]]
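Now we are looking for rows in which every value is True, so we can use the all method of ndarray. Its axis argument selects the direction to reduce over: axis=0 checks down each column, axis=1 checks across each row, and omitting it checks the whole array. We want one result per row, so we use axis=1: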
to_find = np.array([2, 2])
np_arr = np.array(
    [
        [1, 1],
        [2, 2],
        [3, 3],
        [2, 2]
    ]
)
print((np_arr == to_find).all(axis=1))
The result is:
[False True False True]
Get the indices of the True values
Finally, you want the indices where the values are True:
np.where((np_arr == to_find).all(axis=1))
The result would be:
(array([1, 3]),)
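The steps above can be wrapped into one function. This is only a sketch: find_row_indices is a hypothetical name chosen for illustration, and it assumes a 2D input whose rows have the same length as the row being searched for:

```python
import numpy as np

def find_row_indices(array_2d, row):
    """Return the indices of the rows of array_2d that equal row."""
    array_2d = np.asarray(array_2d)
    row = np.asarray(row)
    # Compare element-wise, then require every column in a row to match.
    mask = (array_2d == row).all(axis=1)
    # np.where returns a tuple with one array per dimension; take the first.
    return np.where(mask)[0]

indices = find_row_indices([[1, 1], [2, 2], [3, 3], [2, 2]], [2, 2])
print(indices)  # [1 3]
```

When no row matches, the function returns an empty array rather than raising an error.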
CodePudding user response:
numpy runs in Python, so you can use both basic Python lists and numpy arrays (which are more like MATLAB matrices).
A list of lists:
In [43]: alist = [[1,1], [2,2],[3,3]]
A list has an index method, which tests against each element of the list (elements here are 2 element lists):
In [44]: alist.index([2,2])
Out[44]: 1
In [45]: alist.index([2,3])
Traceback (most recent call last):
Input In [45] in <cell line: 1>
alist.index([2,3])
ValueError: [2, 3] is not in list
alist==[2,2] returns False, because the list is not the same as the [2,2] list.
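A short sketch of both points, using plain Python only. The helper index_or_none is a hypothetical name introduced here for illustration; it simply turns the ValueError into None:

```python
alist = [[1, 1], [2, 2], [3, 3]]

# Comparing two lists yields a single bool, not an element-wise result.
comparison = alist == [2, 2]
print(comparison)  # False

def index_or_none(items, target):
    """Return the index of target in items, or None if it is absent."""
    try:
        return items.index(target)
    except ValueError:
        return None

print(index_or_none(alist, [2, 2]))  # 1
print(index_or_none(alist, [2, 3]))  # None
```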
If we make an array from that list:
In [46]: arr = np.array(alist)
In [47]: arr
Out[47]:
array([[1, 1],
[2, 2],
[3, 3]])
we can do an == test - but it compares numeric elements.
In [48]: arr == np.array([2,2])
Out[48]:
array([[False, False],
[ True, True],
[False, False]])
Underlying this comparison is the concept of broadcasting, which allows it to compare a (3,2) array with a (2,) array (a 2d with a 1d). Here it's trivial, but it can be much more complicated.
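To make the broadcasting step visible, the sketch below stretches the (2,) array to the (3,2) shape explicitly with np.broadcast_to and checks that the comparison gives the same result. The names target, stretched and same are illustrative only:

```python
import numpy as np

arr = np.array([[1, 1], [2, 2], [3, 3]])  # shape (3, 2)
target = np.array([2, 2])                 # shape (2,)

# The (2,) array is treated as (1, 2) and repeated along axis 0 to (3, 2).
stretched = np.broadcast_to(target, arr.shape)
print(stretched.shape)  # (3, 2)

# Comparing against the explicit (3, 2) version gives the same boolean array.
same = np.array_equal(arr == target, arr == stretched)
print(same)  # True
```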
To find rows where all values are True, use:
In [50]: (arr == np.array([2,2])).all(axis=1)
Out[50]: array([False, True, False])
and where finds the True in that array (the result is a tuple with 1 array):
In [51]: np.where(_)
Out[51]: (array([1]),)
In Octave the equivalent is (note that Octave indices start at 1, so the matching row is reported as 2):
>> arr = [[1,1];[2,2];[3,3]]
arr =
1 1
2 2
3 3
>> all(arr == [2,2],2)
ans =
0
1
0
>> find(all(arr == [2,2],2))
ans = 2
