- I have a 3D tensor X, of shape [7, 240, 768].
- I have another tensor mask_idx of shape [7, 240] which contains 0/False and 1/True, where 0/False means I don't want to update the value in X[i][j] and 1/True means I want to set X[i][j] = tf.zeros([768]).
I have tried using tf.where(mask_idx, tf.zeros([7, 240, 768]), X), but I get this error:
*** tensorflow.python.framework.errors_impl.InvalidArgumentError: condition [7,240], then [7,240,768], and else [7,240,768] must be broadcastable [Op:SelectV2]
Can anyone suggest the correct approach to it?
CodePudding user response:
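TF checks whether shapes are broadcastable by aligning dimensions from right to left, so a [7, 240] condition gets compared against the trailing [240, 768] of the other two arguments and fails. One simple fix is to add a trailing dimension to the mask, i.e., make its shape (7, 240, 1):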
tf.where(mask_idx[...,None], 0, X)
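
For reference, here is a small self-contained sketch of the same idea. It uses stand-in shapes (2, 3, 4) instead of (7, 240, 768) and made-up mask values, and it assumes the mask may be stored as 0/1 integers, in which case it needs a cast to bool before being passed to tf.where. It also uses tf.zeros_like(X) rather than a bare 0 just to keep the dtypes explicit.

```python
import tensorflow as tf

# Stand-in shapes (2, 3, 4) instead of (7, 240, 768); values are illustrative only.
X = tf.ones([2, 3, 4])
mask_idx = tf.constant([[1, 0, 0],
                        [0, 1, 1]])  # 0/1 integers, as described in the question

# tf.where expects a boolean condition, so cast an integer mask first.
mask = tf.cast(mask_idx, tf.bool)

# Add a trailing axis: (2, 3) -> (2, 3, 1), which broadcasts against (2, 3, 4).
result = tf.where(mask[..., None], tf.zeros_like(X), X)
```

If mask_idx is already a boolean tensor, the cast is a no-op and can be dropped.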
