How to remove black canvas from image in TensorFlow

Time:01-14

I'm currently working with the TensorFlow dataset 'tf_flowers' and noticed that a lot of the images consist mostly of black canvas, like these (images: flower1, flower2). Is there an easy way to remove or filter it out? Preferably it should work on batches and compile into a graph with @tf.function, as I also plan to use it on bigger datasets with dataset.map(...).

Answer:

The black pixels are just padding. Padding is a simple operation that gives all network inputs the same size: a batch contains images of, e.g., size 223x221 because the smaller images are padded with black (zero-valued) pixels up to that size.
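To make the padding concrete, here is a small NumPy sketch of what padded batching does. It mirrors the default behaviour of tf.data's padded_batch, which pads each image with zeros at the end of every dimension up to the largest size in the batch; the pad_batch helper and the toy images are hypothetical, for illustration only.

```python
import numpy as np


def pad_batch(images):
    """Zero-pad HxWxC images at the bottom/right to the largest H and W, then stack."""
    max_h = max(img.shape[0] for img in images)
    max_w = max(img.shape[1] for img in images)
    batch = np.zeros((len(images), max_h, max_w, images[0].shape[2]), dtype=images[0].dtype)
    for i, img in enumerate(images):
        # copy the image into the top-left corner; the remaining pixels stay 0 (black)
        batch[i, :img.shape[0], :img.shape[1]] = img
    return batch


# two hypothetical all-grey images of different sizes
small = np.full((2, 3, 3), 200, np.uint8)
large = np.full((4, 5, 3), 200, np.uint8)
batch = pad_batch([small, large])
print(batch.shape)         # (2, 4, 5, 3)
print(batch[0, 2:].max())  # 0: the extra rows of the small image are black
```

The smaller image ends up in the top-left corner of its slot, with a black band below and to the right of it, which is the "black canvas" seen in the examples.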

An alternative to padding, which avoids adding black pixels to the images, is to preprocess them by:

  • removing the padding with a crop operation
  • resizing the cropped images to a common size (e.g. 223x221)

You can do all of these operations in plain Python, thanks to TensorFlow's map function combined with a Python-function wrapper. First, define your Python function:

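```python
def py_preprocess_image(numpy_image):
    input_size = numpy_image.shape  # e.g. (223, 221, 3) for a padded RGB image
    image_proc = crop_by_removing_padding(numpy_image)
    # resize back to the original (padded) size so all outputs share one shape
    image_proc = resize(image_proc, new_size=input_size)
    return image_proc
```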

Then, given your TensorFlow dataset train_data, map the above Python function over each input:

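```python
import tensorflow as tf


def tf_preprocess_image(image):
    # tf.numpy_function (the TF2 replacement for tf.py_func) runs the NumPy/OpenCV
    # code on the image; Tout must match the returned dtype (uint8 for tf_flowers)
    out = tf.numpy_function(py_preprocess_image, inp=[image], Tout=tf.uint8)
    # numpy_function drops the static shape; the output has the input's shape
    out.set_shape(image.shape)
    return out


# train_data is your tensorflow dataset of images
train_data = train_data.map(tf_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
```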

Now you only need to define crop_by_removing_padding and resize. They operate on ordinary NumPy arrays, so they can be written in plain Python (here with NumPy and OpenCV). For example:

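```python
import cv2
import numpy as np


def crop_by_removing_padding(img):
    # a pixel counts as content if any channel is non-zero (not pure black)
    mask = img.any(axis=-1) if img.ndim == 3 else img != 0
    coords = np.argwhere(mask)
    if coords.size == 0:  # all-black image: nothing to crop
        return img
    # padding sits at the bottom/right, so keep everything up to the last non-black row/column
    xmax, ymax = coords.max(axis=0)
    return img[:xmax + 1, :ymax + 1]


def resize(img, new_size):
    # cv2.resize expects dsize as (width, height), hence the swapped indices
    return cv2.resize(img, (new_size[1], new_size[0]), interpolation=cv2.INTER_CUBIC)
```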
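The helpers above run through tf.numpy_function, so that step executes in the Python interpreter rather than as graph ops. Since the question asks for something that compiles with @tf.function and works on batches, below is a sketch of the same crop-then-resize idea using only TensorFlow ops. It assumes [H, W, C] images padded at the bottom/right with exact zeros; TARGET_SIZE and the function names are hypothetical choices. Like the NumPy version, it would also crop genuinely black rows or columns at the bottom/right edge of a photo.

```python
import tensorflow as tf

TARGET_SIZE = (223, 221)  # assumed output (height, width), matching the example size


def crop_and_resize(image):
    """Crop trailing all-black rows/columns of an [H, W, C] image, then resize."""
    shape = tf.shape(image)
    # True where at least one channel of the pixel is non-zero (not black)
    mask = tf.reduce_any(tf.not_equal(image, 0), axis=-1)
    row_has_content = tf.cast(tf.reduce_any(mask, axis=1), tf.int32)  # [H]
    col_has_content = tf.cast(tf.reduce_any(mask, axis=0), tf.int32)  # [W]
    # (index + 1) of the last non-black row/column, or 0 if there is none
    h = tf.reduce_max(tf.range(1, shape[0] + 1) * row_has_content)
    w = tf.reduce_max(tf.range(1, shape[1] + 1) * col_has_content)
    # an all-black image has no content to keep, so leave it uncropped
    h = tf.where(h > 0, h, shape[0])
    w = tf.where(w > 0, w, shape[1])
    # tf.image.resize (bilinear by default) returns float32
    return tf.image.resize(image[:h, :w], TARGET_SIZE)


@tf.function
def crop_and_resize_batch(images):
    """Apply crop_and_resize to every image of a padded [B, H, W, C] batch."""
    return tf.map_fn(
        crop_and_resize,
        images,
        fn_output_signature=tf.TensorSpec(TARGET_SIZE + (images.shape[-1],), tf.float32),
    )
```

On a dataset of (image, label) pairs this could be applied after padded batching, for example `ds.padded_batch(32).map(lambda x, y: (crop_and_resize_batch(x), y))`. tf.map_fn is used because each image in the batch needs its own crop size, while the resized outputs all share one shape and can be stacked back into a batch.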