I have two NumPy arrays like this:
a = [False, False, False, False, False, True, False, False]
b = [1, 2, 3, 4, 5, 6, 7, 8]
I need to sum over b, but only up to the index where the corresponding element of a is True, not over the full array.
In other words, I want 1 + 2 + 3 + 4 + 5 = 15 instead of 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 = 36.
I need an efficient solution. I think I need to mask all elements of b from the first True in a onward and set them to 0.
Side note: my code uses jax.numpy rather than plain NumPy, but I guess that doesn't really matter.
CodePudding user response:
You can use a cumulative sum:
np.sum(b[np.cumsum(a)==0])
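To spell out why this one-liner works, here is a minimal NumPy sketch using the arrays from the question. The names running, mask, total_zeroed and total_no_true are only illustrative. The last lines show the edge case where a has no True at all, in which case the whole array is summed.

```python
import numpy as np

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# cumsum treats True as 1, so the running count stays 0 until the first True.
running = np.cumsum(a)   # [0 0 0 0 0 1 1 1]
mask = running == 0      # True only for positions before the first True in a

# Boolean indexing keeps just the elements before the first True.
total = np.sum(b[mask])  # 1 + 2 + 3 + 4 + 5

# Same result, but zeroing the unwanted elements instead of dropping them,
# which is the approach described in the question.
total_zeroed = np.where(mask, b, 0).sum()

# Edge case: with no True in a, the mask keeps everything,
# so the full sum of b is returned.
no_true = np.zeros_like(a)
total_no_true = np.sum(b[np.cumsum(no_true) == 0])

print(total, total_zeroed, total_no_true)
```

The np.where form keeps the array shape fixed, which may matter if the code later has to run under jax.jit, where boolean indexing with a non-constant mask is not supported.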
CodePudding user response:
I would suggest converting the array to a list with .tolist() and then calling .index() to get the index of the first True: i = a.tolist().index(True).
Then slice and sum: total = numpy.sum(b[:i])
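One thing to watch for: list.index raises ValueError when the value is not found, so this fails if a contains no True. A small sketch that wraps the two steps in a helper (sum_before_first_true is a hypothetical name, and falling back to the full sum is an assumption about what you would want in that case):

```python
import numpy as np

def sum_before_first_true(a, b):
    """Sum b up to, but not including, the first True in a (illustrative helper)."""
    flags = a.tolist()
    try:
        i = flags.index(True)
    except ValueError:
        # Assumption: if a has no True, sum the whole array.
        i = len(flags)
    return np.sum(b[:i])

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

print(sum_before_first_true(a, b))
```

Note that .tolist() copies the whole array into Python objects, so this is likely to be slower than a vectorized approach on large arrays.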
CodePudding user response:
I can think of two ways of doing this. You could construct a mask with cumsum (this also works in regular NumPy):
a = jnp.array([False, False, False, False, False, True, False, False])
b = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
mask = a.cumsum() == 0
b.sum(where=mask) # 15
Or you could find the first True index with jnp.where (note that the size argument only exists in JAX's version of jnp.where, not in numpy's):
idx = jnp.where(a, size=1)[0][0]
b[:idx].sum() # 15
You might do some microbenchmarks to determine which is more efficient for the size of arrays that you're concerned with.
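As a rough starting point, such a microbenchmark could look like the sketch below. It uses plain NumPy and timeit so it runs without JAX, and it swaps in np.argmax for the index lookup, so it is not the exact JAX code above. The array size n, the position of the first True, and the function names are all arbitrary assumptions. For JAX you would typically also wrap each function in jax.jit and call .block_until_ready() on the result before stopping the timer.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
n = 100_000                      # hypothetical array size
b = rng.integers(0, 10, size=n)
a = np.zeros(n, dtype=bool)
a[n // 2] = True                 # put the first True in the middle

def with_mask():
    # Mask approach: sum only where no True has been seen yet.
    return b.sum(where=np.cumsum(a) == 0)

def with_index():
    # Index approach: argmax returns the first True (assumes a has at least one).
    idx = int(np.argmax(a))
    return b[:idx].sum()

# Both approaches should agree before timing them.
assert with_mask() == with_index()

t_mask = timeit.timeit(with_mask, number=100)
t_index = timeit.timeit(with_index, number=100)
print(f"mask:  {t_mask:.4f} s for 100 runs")
print(f"index: {t_index:.4f} s for 100 runs")
```

The timings depend on the array size and on where the first True falls, so it is worth trying values that match your real data.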
