Numba messes up dtype when broadcasting

Time:02-03

I want to save storage by using small dtypes. However, when I add a number to an array or multiply an array by a number, Numba changes the dtype to int64:

Pure Numpy

In:

def f():
    a=np.ones(10, dtype=np.uint8)
    return a + 1
f()

Out:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

Now with numba:

In:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + 1
f()

Out:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)

One solution is to replace a + 1 with a + np.ones(a.shape, dtype=a.dtype), but I cannot imagine anything uglier.

Thanks a lot for help!

CodePudding user response:

As you mentioned in the comments, this is probably because Numba's default integer type is int64, so the literal 1 is typed as int64 and the smaller uint8 array gets promoted to it.

Why not just convert it?

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return (a + 1).astype('uint8')
f()

Output:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

That's less ugly than a + np.ones(a.shape, dtype=a.dtype). ;)

CodePudding user response:

You can use np.ones_like:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.ones_like(a)
f()

Output:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

...or np.full_like:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.full_like(a, 100)
f()

Output:

array([101, 101, 101, 101, 101, 101, 101, 101, 101, 101], dtype=uint8)

CodePudding user response:

You can fix this if you are willing to make your function accept inputs. I rewrote your function using the signature_or_function argument of njit:

@numba.njit(signature_or_function='uint8[:](uint8)')
def f(x):
    a = np.ones(10, dtype=np.uint8)
    return a + x

f(1)
# array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

See the Numba documentation on signatures. If you define signatures, Numba compiles a specialized function for each one and tries to use a compatible pre-compiled specialization for any argument types that don't match a signature exactly. The signature here says that the function returns an array of unsigned 8-bit integers ('uint8[:]') and takes a single unsigned 8-bit integer value as input.

Note that in this case I had to make the function accept an input, because Numba seems to default to treating integer literals (e.g., the 1 in a + 1) as int64 values. If you specify that the input is a uint8 and don't add a more permissive signature, the compiled function treats the input as uint8 and has no reason to up-convert.
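The same idea seems to work without an explicit signature, as long as the scalar arrives already typed as uint8. The sketch below is only an illustration: the name add_scalar is made up for this example, and the try/except fallback is there solely so the snippet still runs as plain NumPy if Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback for illustration only: run as plain Python/NumPy.
    def njit(func):
        return func

@njit
def add_scalar(a, x):
    # Both operands are uint8, so no promotion to int64 is needed.
    return a + x

a = np.ones(10, dtype=np.uint8)
result = add_scalar(a, np.uint8(1))  # pass a typed scalar, not a bare 1
print(result.dtype)
```

Calling add_scalar(a, 1) instead would hand Numba a Python int, which it types as a 64-bit integer, so the typed scalar is what keeps the result small.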

CodePudding user response:

I guess the simplest thing is to just add two np.uint8 values:

import numpy as np
from numba import njit

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.uint8(1)
print(f().dtype)

Output:

uint8

I find this more elegant than changing the type of the full array or working with np.ones or np.full.
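Since the question also mentions multiplication, here is a minimal sketch of the same trick with the * operator. The function name scale and the factor 3 are just placeholders for this example, and the fallback decorator only exists so the snippet runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback for illustration only: run as plain Python/NumPy.
    def njit(func):
        return func

@njit
def scale():
    a = np.ones(10, dtype=np.uint8)
    # Wrap the literal so both operands are uint8.
    return a * np.uint8(3)

out = scale()
print(out.dtype)
```

Keep in mind that uint8 arithmetic wraps around above 255, so this only helps when the results actually fit into the small dtype.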
