Not a Tensor API call inside a Tensor block

Time:01-31

I have code that uses the TensorFlow API, but I added a function call to it that, as far as I understand (I'm not sure), cannot be used inside TensorFlow code. In the code below, that call is:

image = io.imread(image_path, plugin='simpleitk')

while the rest is all TensorFlow. Currently the run fails for an unknown reason. Is there a way to change the code (either the io.imread call or the rest) so that the two work together?

    !pip install SimpleITK
    import tensorflow as tf
    # Image preprocessing utils
    import skimage.io as io
    @tf.function
    def parse_images(image_path):
        
        #-------------- skimage.io.imread is not a TensorFlow op
        
        image = io.imread(image_path, plugin='simpleitk')
        
        #---------------
        
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, size=[224, 224])
    
        return image

It is called by this code:

import tensorflow as tf

# Create TensorFlow dataset
BATCH_SIZE = 64

# my_train_images: list of image file paths (defined elsewhere)
train_ds = tf.data.Dataset.from_tensor_slices(my_train_images)

train_ds = (
    train_ds
    .map(parse_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(1024)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
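A minimal sketch of why the call fails: Dataset.map traces the mapping function into a graph, so its Python body runs with a symbolic tf.Tensor in place of a real path string, and skimage.io.imread has nothing it can open. The function name inspect_arg and the file names are hypothetical; the function only records what it receives while TensorFlow traces it.

```python
import tensorflow as tf

# Filled by a Python side effect, which only happens while TensorFlow traces the function
observations = []

def inspect_arg(image_path):  # hypothetical name, for illustration only
    observations.append({
        "is_python_str": isinstance(image_path, str),
        "is_tf_tensor": isinstance(image_path, tf.Tensor),
        "executing_eagerly": tf.executing_eagerly(),
    })
    return image_path

# Hypothetical file names; the files are never opened
ds = tf.data.Dataset.from_tensor_slices(["a.png", "b.png"]).map(inspect_arg)
for _ in ds:  # iterating runs the traced graph, not the Python body again
    pass

print(observations)
```

Every recorded entry should report is_python_str=False, is_tf_tensor=True and executing_eagerly=False, which is exactly the situation in which a NumPy-based reader such as io.imread cannot work.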

Answer:

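skimage.io.imread fails here in both graph and eager execution. It expects a Python string or file-like object, but inside Dataset.map (which traces the function into a graph whether or not it is decorated with @tf.function) image_path is a tf.Tensor. Reading the file with TensorFlow's own ops, starting with tf.io.read_file, avoids the problem: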

Generate random data

import numpy
from PIL import Image

for i in range(5):
  imarray = numpy.random.rand(300,300,3) * 255
  im = Image.fromarray(imarray.astype('uint8')).convert('RGBA')
  im.save('result_image{}.png'.format(i))

Process data

import tensorflow as tf
import matplotlib.pyplot as plt

normalization_layer = tf.keras.layers.Rescaling(1./255)

@tf.function
def parse_images(image_path):
    raw = tf.io.read_file(image_path) 
    image = tf.image.decode_png(raw, channels=3)
    image = tf.image.resize(normalization_layer(image), size=[224, 224])
    return image

train_ds = tf.data.Dataset.from_tensor_slices(['/content/result_image0.png', 
                                               '/content/result_image1.png', 
                                               '/content/result_image2.png', 
                                               '/content/result_image3.png', 
                                               '/content/result_image4.png']) 

train_ds = (
    train_ds
    .map(parse_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(5)
    .batch(2, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

images = next(iter(train_ds.take(1)))

image = images[0] # (224, 224, 3)
plt.imshow(image.numpy())
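The parse_images above only handles PNG files because it calls tf.image.decode_png. As a hedged variant (not part of the original answer), tf.io.decode_image detects BMP, GIF, JPEG and PNG from the file bytes, and expand_animations=False keeps the result 3-D so that tf.image.resize accepts it. It still cannot read .mhd volumes. The name parse_any_image and the temporary demo files are assumptions for illustration.

```python
import os
import tempfile
import tensorflow as tf

def parse_any_image(image_path):  # hypothetical name
    raw = tf.io.read_file(image_path)
    # decode_image sniffs the format; expand_animations=False guarantees a
    # 3-D [height, width, channels] tensor even for GIFs
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 -> [0, 1]
    return tf.image.resize(image, size=[224, 224])

# Write one PNG and one JPEG to a temporary directory for the demo
tmp_dir = tempfile.mkdtemp()
pixels = tf.cast(tf.random.uniform([300, 300, 3], maxval=256, dtype=tf.int32), tf.uint8)
png_path = os.path.join(tmp_dir, "demo.png")
jpg_path = os.path.join(tmp_dir, "demo.jpg")
tf.io.write_file(png_path, tf.io.encode_png(pixels))
tf.io.write_file(jpg_path, tf.io.encode_jpeg(pixels))

ds = (
    tf.data.Dataset.from_tensor_slices([png_path, jpg_path])
    .map(parse_any_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2)
)
batch = next(iter(ds))
print(batch.shape)
```

Because every step is a TensorFlow op, the function can be traced by Dataset.map and parallelized with num_parallel_calls.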


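Update 1: TensorFlow has no built-in decoder for *.mhd images, so for those you can wrap the SimpleITK reader in tf.py_function. tf.py_function runs its Python function eagerly, so the path tensor can be turned into a Python string with .numpy():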

import tensorflow as tf
import matplotlib.pyplot as plt
import SimpleITK as sitk

normalization_layer = tf.keras.layers.Rescaling(1./255)

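def parse_images(image_path):
    # Called through tf.py_function, so this body runs eagerly and
    # image_path.numpy() returns the path as bytes
    itkimage = sitk.ReadImage(image_path.numpy().decode("utf-8"))
    # NumPy array; tf.image.resize needs 3-D (H, W, C) input, so a
    # single-channel 2-D image would need a trailing channel axis first
    image = sitk.GetArrayFromImage(itkimage)
    image = tf.image.resize(normalization_layer(image), size=[224, 224])
    return image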

train_ds = tf.data.Dataset.from_tensor_slices(['/content/result_image0.mhd', 
                                               '/content/result_image1.mhd', 
                                               '/content/result_image2.mhd', 
                                               '/content/result_image3.mhd', 
                                               '/content/result_image4.mhd']) 
train_ds = (
    train_ds
    .map(lambda x: tf.py_function(parse_images, [x], tf.float32))
    .shuffle(5)
    .batch(2, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

images = next(iter(train_ds.take(1)))

image = images[0] # (224, 224, 3)
plt.imshow(image.numpy())
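One side effect of tf.py_function is that its output has no static shape, which can break later Keras layers that need to know the image size. A minimal sketch of restoring it with tf.ensure_shape, assuming 3-channel 224x224 outputs; to keep it runnable without SimpleITK or .mhd files, it substitutes numpy.load on .npy files for sitk.ReadImage/GetArrayFromImage, and the names load_with_python and image{i}.npy are hypothetical.

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

def load_with_python(path_tensor):  # hypothetical stand-in for the SimpleITK reader
    # Runs eagerly inside tf.py_function, so .numpy() gives the real path bytes
    path = path_tensor.numpy().decode("utf-8")
    array = np.load(path).astype(np.float32) / 255.0  # assumed uint8 (H, W, 3) data
    return tf.image.resize(array, size=[224, 224])

def parse_images(path):
    image = tf.py_function(load_with_python, [path], tf.float32)
    # py_function drops static shape information; ensure_shape restores it
    # and raises at runtime if an image does not match
    return tf.ensure_shape(image, [224, 224, 3])

# Create a few hypothetical .npy "images" for the demo
tmp_dir = tempfile.mkdtemp()
rng = np.random.default_rng(0)
paths = []
for i in range(4):
    p = os.path.join(tmp_dir, f"image{i}.npy")
    np.save(p, (rng.random((300, 300, 3)) * 255).astype(np.uint8))
    paths.append(p)

ds = (
    tf.data.Dataset.from_tensor_slices(paths)
    .map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2, drop_remainder=True)
)
print(ds.element_spec)
batch = next(iter(ds))
```

Note that the Python body of a tf.py_function still holds the GIL, so num_parallel_calls may give less speedup than it does for pure TensorFlow ops.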