I have a list of colors and a function closest_color(pixel, colors) that compares the given pixel's RGB values with my list of colors and outputs the closest color from the list.
I need to apply this function to a whole image. When I apply it pixel by pixel (using two nested for-loops), it is slow. Is there a better way to achieve this with numpy?
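For reference, the slow per-pixel approach described above might look like this. The body of closest_color is an assumption (squared Euclidean distance in RGB), since the question doesn't show it, and `quantize_slow` is a hypothetical name for the nested-loop driver.

```python
import numpy as np

def closest_color(pixel, colors):
    # Assumed implementation: return the palette entry with the smallest
    # squared Euclidean distance to `pixel`. int32 avoids uint8 wrap-around.
    colors = np.asarray(colors, dtype=np.int32)
    dists = ((colors - np.asarray(pixel, dtype=np.int32)) ** 2).sum(axis=1)
    return colors[np.argmin(dists)]

def quantize_slow(image, colors):
    # The slow baseline: two nested Python loops, one function call per pixel.
    out = np.empty_like(image)
    for y in range(image.shape[0]):
        for x in range(image.shape[1]):
            out[y, x] = closest_color(image[y, x], colors)
    return out
```

Every pixel pays Python's interpreter overhead, which is what the answers below try to avoid.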
Answer:
The task is to turn a picture into a palette version of it. You define a palette, and then you need to find, for every pixel, the nearest neighbor match in the defined palette for that pixel's color. You get an index from that lookup, which you can then turn into the palette color for that pixel.
This is possible using FLANN (Fast Library for Approximate Nearest Neighbors). It's not much code either. The lookups take two seconds on my old computer.
The advantage of this approach is that it can handle "large" palettes (more than a handful of colors) without requiring lots of memory.
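The FLANN code itself isn't shown here. As a swapped-in, NumPy-only sketch of the same two steps (nearest-palette-index lookup, then index-to-color), the following assumes exact brute-force distances computed over chunks of pixels, so the temporaries scale with chunk_size × len(palette) rather than with the whole image. This is not FLANN's approximate index structure; `palette_lookup` and `chunk_size` are hypothetical names.

```python
import numpy as np

def palette_lookup(image, palette, chunk_size=16384):
    """Return (indices, quantized) for an (H, W, 3) image and a (K, 3) palette.

    Exact brute-force nearest neighbour over chunks of pixels, so temporary
    arrays are O(chunk_size * K) instead of O(H * W * K).
    """
    pixels = image.reshape(-1, 3).astype(np.int32)
    pal = np.asarray(palette, dtype=np.int32)
    indices = np.empty(len(pixels), dtype=np.intp)
    for start in range(0, len(pixels), chunk_size):
        chunk = pixels[start:start + chunk_size]
        # Squared distances, shape (len(chunk), K); sqrt doesn't change argmin.
        d2 = ((chunk[:, None, :] - pal[None, :, :]) ** 2).sum(axis=2)
        indices[start:start + chunk_size] = np.argmin(d2, axis=1)
    indices = indices.reshape(image.shape[:2])
    # Step 2: turn each index into its palette color with fancy indexing.
    quantized = np.asarray(palette, dtype=image.dtype)[indices]
    return indices, quantized
```

Lowering chunk_size trades speed for memory; the index map can also be kept and reused, e.g. to save an indexed-color image.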
This isn't the quickest possible solution. I can imagine some things that need neither FLANN index structures nor extensive use of memory. See my other answer that uses numba.
Answer:
Not as fast as I would expect. This uses np.argmin to get indices into a pre-created container of colors.
```python
import numpy as np
from PIL import Image
import requests

# get some image
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/7/77/Big_Nature_(155420955).jpeg/800px-Big_Nature_(155420955).jpeg"
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
newsize = (1000, 1000)
im = im.resize(newsize)
# im.show()
im = np.asarray(im)
# Ignore above

# Now the image has shape (1000, 1000, 1, 3). The 1 is there so it's easy to subtract from the color container
image = im.reshape(im.shape[0], im.shape[1], 1, 3)

# test colors
colors = [[0, 0, 0], [255, 255, 255], [0, 0, 255]]

# Create the color container.
# It has the same dimensions as the image: (1000, 1000, number of colors, 3).
# np.ones gives float64, so the subtraction below happens in floating point.
colors_container = np.ones(shape=[image.shape[0], image.shape[1], len(colors), 3])
for i, color in enumerate(colors):
    colors_container[:, :, i, :] = color

def closest(image, color_container):
    shape = image.shape[:2]
    total_shape = shape[0] * shape[1]
    n_colors = color_container.shape[2]
    # calculate distances, shape = (x, y, number of colors)
    distances = np.sqrt(np.sum((color_container - image) ** 2, axis=3))
    # get the position of the smallest distance, i.e. the color axis k
    # in color_container[x, y, k, :].
    # min_index has shape (x, y) before the reshape and (x*y,) after it.
    min_index = np.argmin(distances, axis=2).reshape(-1)
    # Natural index: binds each pixel position to its chosen color position
    natural_index = np.arange(total_shape)
    # Flatten for easy index access: shape (x*y, number of colors, 3)
    reshaped_container = color_container.reshape(-1, n_colors, 3)
    # For every pixel, pick the color with the smallest distance
    color_view = reshaped_container[natural_index, min_index].reshape(shape[0], shape[1], 3)
    return color_view

# NOTE: don't make both arrays uint8, or the subtraction will overflow
result_image = closest(image, colors_container)
Image.fromarray(result_image.astype(np.uint8)).show()
```
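The pre-built colors_container isn't strictly needed: NumPy broadcasting can compare every pixel against every palette color directly, and np.argmin's result can index a (K, 3) palette array. A minimal sketch, assuming an (H, W, 3) uint8 image and squared distances (sqrt doesn't change the argmin); `closest_broadcast` is a hypothetical name. It still materializes an (H, W, K, 3) temporary, so memory grows with the palette size.

```python
import numpy as np

def closest_broadcast(image, colors):
    palette = np.asarray(colors, dtype=np.int32)          # (K, 3)
    # (H, W, 1, 3) - (K, 3) broadcasts to (H, W, K, 3); int32 avoids uint8 wrap-around.
    diff = image[:, :, None, :].astype(np.int32) - palette
    sqdist = (diff ** 2).sum(axis=-1)                     # (H, W, K)
    nearest = np.argmin(sqdist, axis=-1)                  # (H, W) palette indices
    return palette[nearest].astype(np.uint8)              # (H, W, 3) colors
```

With the arrays above, `Image.fromarray(closest_broadcast(im, colors)).show()` should display the same result.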
Answer:
Here are two variants using numba, a JIT compiler for python code.
```python
import numpy as np
from numba import njit, prange
```
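The first variant uses more numpy primitives (np.argmin, plus array arithmetic like `(im[i,j] - palette) ** 2` that likely allocates small temporary arrays for every pixel) and hence "more" memory. Maybe that little bit of memory has an effect, or maybe numba calls the numpy routines as they are, without being able to optimize them.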
```python
@njit(parallel=True)
def lookup1(palette, im):
    palette = palette.astype(np.int32)  # signed ints, so differences can't wrap around
    (rows, cols) = im.shape[:2]
    result = np.zeros((rows, cols), dtype=np.uint8)  # palette index per pixel (up to 256 colors)
    for i in prange(rows):  # parallelize over rows
        for j in range(cols):
            sqdists = ((im[i, j] - palette) ** 2).sum(axis=1)
            index = np.argmin(sqdists)
            result[i, j] = index
    return result
```
I get ~180-190 ms per run on lena.jpg and a palette of 125 colors.
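The answer doesn't say how the 125-color palette was built; one plausible assumption is a regular 5×5×5 grid over the three channels. Also, lookup1 and lookup2 return palette indices, not colors, so a final fancy-indexing step produces the quantized image. A sketch, where `make_grid_palette`, `indices_to_image`, and the level values are hypothetical choices:

```python
import numpy as np

def make_grid_palette(levels_per_channel=5):
    # Evenly spaced levels per channel, e.g. 0, 64, 128, 191, 255 for 5 levels.
    levels = np.linspace(0, 255, levels_per_channel).round().astype(np.uint8)
    b, g, r = np.meshgrid(levels, levels, levels, indexing="ij")
    return np.stack([b.ravel(), g.ravel(), r.ravel()], axis=1)  # (levels**3, 3)

def indices_to_image(palette, indices):
    # Fancy indexing: an (H, W) index map becomes an (H, W, 3) color image.
    return palette[indices]

# Usage (needs numba and an (H, W, 3) uint8 image `im`):
#   palette = make_grid_palette()
#   quantized = indices_to_image(palette, lookup1(palette, im))
```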
The second variant uses more hand-written code to replace most of the numpy primitives, which makes it even faster.
```python
@njit(parallel=True)
def lookup2(palette, im):
    palette = palette.astype(np.int32)  # signed ints, so differences can't wrap around
    (rows, cols) = im.shape[:2]
    result = np.zeros((rows, cols), dtype=np.uint8)
    for i in prange(rows):  # parallelize over this
        for j in range(cols):
            pb, pg, pr = im[i, j]  # take pixel apart
            bestindex = -1
            bestdist = 2**20  # larger than the maximum possible 3 * 255**2
            for index in range(len(palette)):
                cb, cg, cr = palette[index]  # take palette color apart
                dist = (pb - cb)**2 + (pg - cg)**2 + (pr - cr)**2
                if dist < bestdist:
                    bestdist = dist
                    bestindex = index
            result[i, j] = bestindex
    return result
```
30 ms per run!
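To check a hand-written loop like lookup2 against np.argmin, one option is to run the same loop logic in plain Python on a tiny input (slow, but no numba required) and compare the index maps. A sketch; `lookup2_reference` is a hypothetical name. The strict `<` means ties resolve to the first palette entry, which matches np.argmin.

```python
import numpy as np

def lookup2_reference(palette, im):
    # Same algorithm as lookup2, without @njit/prange; Python ints avoid uint8 wrap-around.
    rows, cols = im.shape[:2]
    result = np.zeros((rows, cols), dtype=np.uint8)
    for i in range(rows):
        for j in range(cols):
            pb, pg, pr = (int(v) for v in im[i, j])
            bestindex, bestdist = -1, 2**20
            for index in range(len(palette)):
                cb, cg, cr = (int(v) for v in palette[index])
                dist = (pb - cb)**2 + (pg - cg)**2 + (pr - cr)**2
                if dist < bestdist:
                    bestdist, bestindex = dist, index
            result[i, j] = bestindex
    return result

# Compare against a vectorized argmin on small random data.
rng = np.random.default_rng(0)
palette = rng.integers(0, 256, size=(16, 3), dtype=np.uint8)
im = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)
diff = im[:, :, None, :].astype(np.int32) - palette.astype(np.int32)
expected = np.argmin((diff ** 2).sum(axis=-1), axis=-1)
assert (lookup2_reference(palette, im) == expected).all()
```

The same comparison can be run against the compiled lookup2 output once numba is available.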
I think that's within an order of magnitude of the theoretical maximum. I estimate that from the number of required math operations:
- per palette entry: A = 10 ops
  - 3 subtracts, 3 squares, 3 adds, 1 compare
- per pixel: B = 1375 ops
  - len(palette) * (A + 1), one index increment
- per row: C = 704512 ops
  - ncols * (B + 1), one index increment
- per image: D = 360710656 ops
  - nrows * (C + 1), one index increment
At 30 ms per run on my ancient quad-core with hyperthreading, that gives about 12000 MIPS (I won't say FLOP/s because no floating point is involved). That means close to one instruction per cycle. I'm sure the code lacks some SIMD vectorization; one could investigate what LLVM makes of these loops, but I won't bother with that now.
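The estimate can be reproduced with a few lines of arithmetic, assuming (as the figures imply) a 512×512 image such as lena.jpg, the 125-color palette, and the measured 30 ms per run:

```python
# Operation counts from the estimate above (integer ops, not FLOPs).
palette_size, ncols, nrows = 125, 512, 512
seconds = 0.030                    # measured time per run of lookup2

A = 3 + 3 + 3 + 1                  # per palette entry: 3 subtracts, 3 squares, 3 adds, 1 compare
B = palette_size * (A + 1)         # per pixel: +1 for the loop index increment
C = ncols * (B + 1)                # per row
D = nrows * (C + 1)                # per image

mips = D / seconds / 1e6
print(A, B, C, D, round(mips))     # → 10 1375 704512 360710656 12024
```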
Code written in Cython might be able to beat this, because there you can pin down the types of variables even more.

