I am building a model that applies a random shuffle to the data along the first non-batch axis, applies a series of Conv1D layers, and then applies the inverse of the shuffle. Unfortunately, the Lambda layer that calls tf.gather loses the batch dimension (None), and I'm not sure why.
Below is a minimal example of what happens.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
dim = 90
input_img = keras.Input(shape=(dim, 4))
# Get random shuffle order
order = layers.Lambda(lambda x: tf.random.shuffle(tf.range(x)))(dim)
# Apply shuffle
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))(input_img, order)
model = keras.models.Model(
    inputs=[input_img],
    outputs=tensor,
)
Here the summary is as follows:
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 90, 4)]           0
_________________________________________________________________
lambda_51 (Lambda)           (90, 90, 4)               0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
I want the output shape of lambda_51 to be (None, 90, 4) instead.
CodePudding user response:
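Wrap input_img and order in a list when you pass them to the Lambda layer. A Keras layer treats only its first positional argument as the input, so in your version the lambda receives input_img alone; x[0] and x[1] then index the batch axis, which is where the (90, 90, 4) shape comes from.
With the list, the layer becomes: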
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))([input_img, order])
and the summary shows the expected output shape:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 90, 4)]           0
_________________________________________________________________
lambda_3 (Lambda)            (None, 90, 4)             0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
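
To see where the (90, 90, 4) shape comes from, the same lambda can be run eagerly, outside of any Keras model. This is only a sketch: the batch size of 8 and the names batch, gather, wrong and right are made up for illustration, and it assumes that the Lambda layer in the question hands the function input_img alone, without order being part of x.

```python
import tensorflow as tf

dim = 90
batch = tf.zeros((8, dim, 4))             # hypothetical batch of 8 samples
order = tf.random.shuffle(tf.range(dim))  # random permutation of 0..dim-1

# The same function that is wrapped by the Lambda layer in the question.
gather = lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1)

# Called on the batch tensor alone: x[0] and x[1] are the first two samples,
# each of shape (90, 4), so sample 1 is misused as a (90, 4) index tensor.
wrong = gather(batch)
print(wrong.shape)   # (90, 90, 4)

# Called on a list: x[0] is the whole batch and x[1] is the permutation.
right = gather([batch, order])
print(right.shape)   # (8, 90, 4)
```

The first call reproduces the shape from the question; the second keeps the batch axis, here 8, which Keras reports as None in the summary.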

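The question also mentions undoing the shuffle after the Conv1D layers. One possible approach, which is not part of the answer above, is to derive the inverse permutation with tf.argsort and gather a second time. The sketch below checks the idea on plain eager tensors; the batch size and the names inverse, shuffled and restored are illustrative assumptions.

```python
import tensorflow as tf

dim = 90
x = tf.random.normal((8, dim, 4))         # hypothetical batch of 8 samples
order = tf.random.shuffle(tf.range(dim))  # random permutation of 0..dim-1

# argsort of a permutation is its inverse: order[inverse[j]] == j.
inverse = tf.argsort(order)

shuffled = tf.gather(x, order, axis=1)           # apply the shuffle
restored = tf.gather(shuffled, inverse, axis=1)  # undo the shuffle

print(bool(tf.reduce_all(tf.equal(restored, x))))  # True
```

Inside a Keras model the same two gathers would be wrapped in Lambda layers, with the tensors passed as a list as shown in the answer.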