I have code that uses the TensorFlow API, but I added a function call to it that, as far as I understand, cannot run inside TensorFlow tensor/graph code (I am not sure about that). In the code below, the call is:
image= io.imread(image_path , plugin='simpleitk')
but the rest is all TensorFlow. Currently, the run fails for an unknown reason. Is there a way to change the code (either the io.imread call or the rest) so that the two work together?
!pip install SimpleITK
# Image preprocessing utils
import skimage.io as io
def _read_with_simpleitk(image_path):
    """Eagerly read one image file with skimage's SimpleITK plugin.

    Called through ``tf.py_function``, so ``image_path`` arrives as an eager
    scalar string tensor and ``.numpy()`` is available.

    Returns:
        float32 tensor; 2-D images get an explicit trailing channel axis.
    """
    path = image_path.numpy().decode("utf-8")
    image = io.imread(path, plugin="simpleitk")
    # Same scaling as the original pipeline: integer images map into [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    # tf.image.resize needs (H, W, C); give single-channel 2-D images a C axis.
    if image.shape.rank == 2:
        image = image[..., tf.newaxis]
    return image


@tf.function
def parse_images(image_path):
    """Load the image at ``image_path`` and resize it to 224x224.

    ``io.imread`` is plain Python and needs a real ``str`` path. Inside a
    ``tf.function`` or ``Dataset.map``, however, the argument is a symbolic
    tensor. The file is therefore read through ``tf.py_function``, which runs
    the Python helper eagerly at execution time.

    Args:
        image_path: scalar ``tf.string`` tensor holding a file path.

    Returns:
        float32 tensor of shape (224, 224, C).
    """
    image = tf.py_function(_read_with_simpleitk, [image_path], Tout=tf.float32)
    # py_function drops static shape info, and resize must know the rank.
    # NOTE(review): assumes files decode to 2-D or (H, W, C) arrays; a 3-D
    # volume (e.g. some .mhd files) would be treated as (H, W, C) - confirm.
    image.set_shape([None, None, None])
    image = tf.image.resize(image, size=[224, 224])
    return image
It is called by this code:
import sys
# Create TensorFlow dataset
BATCH_SIZE = 64
# Shuffle the cheap file-path strings *before* decoding. A 1024-element
# buffer of decoded 224x224 float32 images would hold hundreds of MB in RAM.
# NOTE(review): assumes my_train_images holds file paths, as parse_images
# expects - confirm against where it is built.
train_ds = (
    tf.data.Dataset.from_tensor_slices(my_train_images)
    .shuffle(1024)
    .map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    # drop_remainder=True discards the last partial batch of each epoch.
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
Answer:
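skimage.io fails here in both graph mode and eager execution mode. Dataset.map always traces the mapped function into a graph, so io.imread receives a symbolic tensor instead of a file-path string. For standard image formats, you could use tf.io.read_file instead: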
Generate random data:
import numpy
from PIL import Image
# Write five random 300x300 RGB noise images, saved as RGBA PNGs, to the
# current working directory as result_image0.png ... result_image4.png.
# NOTE(review): the loader reads them from /content, presumably the notebook's
# working directory (Colab) - confirm when running elsewhere.
for idx in range(5):
    pixels = (numpy.random.rand(300, 300, 3) * 255).astype('uint8')
    rgba_image = Image.fromarray(pixels).convert('RGBA')
    rgba_image.save(f'result_image{idx}.png')
Process the data:
import tensorflow as tf
import matplotlib.pyplot as plt
normalization_layer = tf.keras.layers.Rescaling(1./255)


@tf.function
def parse_images(image_path):
    """Read a PNG file and return it as a 224x224 image scaled by 1/255.

    Args:
        image_path: scalar string tensor with the path of a PNG file.

    Returns:
        Float tensor of shape (224, 224, 3).
    """
    encoded = tf.io.read_file(image_path)
    # channels=3 forces RGB output, dropping the alpha channel of RGBA files.
    decoded = tf.image.decode_png(encoded, channels=3)
    scaled = normalization_layer(decoded)
    return tf.image.resize(scaled, size=[224, 224])
# Build a small pipeline over the five generated PNGs and display one image.
image_paths = ['/content/result_image{}.png'.format(k) for k in range(5)]
pipeline = tf.data.Dataset.from_tensor_slices(image_paths)
pipeline = pipeline.map(parse_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
pipeline = pipeline.shuffle(5)
pipeline = pipeline.batch(2, drop_remainder=True)
train_ds = pipeline.prefetch(tf.data.experimental.AUTOTUNE)

images = next(iter(train_ds.take(1)))
image = images[0]  # (224, 224, 3)
plt.imshow(image.numpy())
Update 1: For *.mhd images, you can try using tf.py_function with the SimpleITK library:
import tensorflow as tf
import matplotlib.pyplot as plt
import SimpleITK as sitk
normalization_layer = tf.keras.layers.Rescaling(1./255)


def parse_images(image_path):
    """Read one image with SimpleITK and resize it to 224x224.

    Must run eagerly (e.g. through ``tf.py_function``), because it calls
    ``image_path.numpy()`` to get a Python ``str`` for SimpleITK.

    Args:
        image_path: eager scalar string tensor holding a file path.

    Returns:
        Float tensor resized to 224x224 along its two leading spatial axes.
    """
    itkimage = sitk.ReadImage(image_path.numpy().decode("utf-8"))
    image = sitk.GetArrayFromImage(itkimage)
    # NOTE(review): 1/255 assumes 8-bit intensities; .mhd data is often int16
    # (e.g. CT), which would need a different normalization - confirm.
    image = normalization_layer(image)
    # tf.image.resize rejects rank-2 input, so give single-slice grayscale
    # images an explicit channel axis -> (H, W, 1).
    # NOTE(review): a 3-D volume comes back depth-first and would be resized
    # as if it were (H, W, C) - confirm that is intended.
    if image.shape.rank == 2:
        image = image[..., tf.newaxis]
    return tf.image.resize(image, size=[224, 224])
def _load_mhd(path):
    """Run the eager SimpleITK loader from inside the tf.data graph."""
    image = tf.py_function(parse_images, [path], tf.float32)
    # py_function returns a tensor of unknown static shape. Pin the known parts
    # so batches (and any downstream Keras model) see (224, 224, C).
    # NOTE(review): assumes parse_images yields a rank-3 (H, W, C) image.
    image.set_shape([224, 224, None])
    return image


mhd_paths = ['/content/result_image0.mhd',
             '/content/result_image1.mhd',
             '/content/result_image2.mhd',
             '/content/result_image3.mhd',
             '/content/result_image4.mhd']
train_ds = (
    tf.data.Dataset.from_tensor_slices(mhd_paths)
    .map(_load_mhd)
    .shuffle(5)
    .batch(2, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
images = next(iter(train_ds.take(1)))
image = images[0]  # (224, 224, C)
# imshow rejects (H, W, 1) arrays; squeeze drops a singleton channel axis and
# leaves (H, W, 3) untouched.
plt.imshow(image.numpy().squeeze())

