I would like to do a simple element-wise division followed by a row-wise average inside a Numba `@jit` function with `nopython=True`.
import numba as nb
import numpy as np
from numba import jit, prange, typed
# Three small 3x3 float32 sample inputs (freshly allocated, C-contiguous).
A = np.asarray(
    [[2, 2, 2],
     [1, 0, 0],
     [1, 2, 1]],
    dtype=np.float32,
)
B = np.asarray(
    [[2, 0, 2],
     [0, 1, 0],
     [1, 2, 1]],
    dtype=np.float32,
)
C = np.asarray(
    [[2, 0, 1],
     [0, 1, 0],
     [1, 1, 2]],
    dtype=np.float32,
)
My jit function is:
@jit(nopython=True)
def test(a, b, c):
    """Row-wise mean of ``c / (a + b)``, counting entries where ``a + b <= 0`` as 0.

    This is the plain-NumPy formulation. In nopython mode Numba rejects it,
    because Numba does not support the ``where=`` keyword of ``np.divide``
    or the ``axis=`` argument of ``mean``.
    """
    denom = a + b
    mask = denom > 0
    # Without ``out=``, NumPy leaves every entry where ``mask`` is False
    # uninitialized, so start from an explicit array of zeros.
    div = np.divide(c, denom, out=np.zeros_like(c), where=mask)
    result = div.mean(axis=1)
    return result


test_res = test(A, B, C)
However, this throws an error. What would be the workaround? I am trying to do this without a loop; any pointers would be appreciated.
Answer (from another user):
Numba doesn't support some arguments of some NumPy functions. For example, `np.mean()` / `.mean()` does not accept the `axis` argument, and `np.divide` does not accept the `where=` keyword. You can work around this with an explicit loop (the code below assumes `import numba as nb`), for example:
@nb.njit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])")  # parallel --> , parallel=True
def test(a, b, c):
    """Row-wise mean of ``c / (a + b)``, counting entries where ``a + b <= 0`` as 0.

    Loop-based equivalent of
    ``np.divide(c, a + b, out=np.zeros_like(c), where=a + b > 0).mean(axis=1)``,
    written so that Numba's nopython mode can compile it.

    The explicit signature makes Numba compile eagerly for C-contiguous 2-D
    float32 inputs and a C-contiguous 1-D float64 result.

    Parameters
    ----------
    a, b, c : 2-D float32 arrays. ``a`` and ``b`` are read at the same
        indices as ``c``, so they are assumed to have ``c``'s shape
        (not checked here). None of the inputs is modified.

    Returns
    -------
    1-D float64 array of length ``c.shape[0]``. Each row sum is divided by
    the full row length ``c.shape[1]``, so masked-out entries count as 0.
    """
    n_rows, n_cols = c.shape
    result = np.zeros(n_rows)
    for i in range(n_rows):  # parallel --> for i in nb.prange(n_rows):
        for j in range(n_cols):
            denom = a[i, j] + b[i, j]
            if denom > 0:
                # Accumulate directly into the row total instead of writing
                # the quotient back into ``c``, so the caller's array is untouched.
                result[i] += c[i, j] / denom
    return result / n_cols
