I have a list of indices generated by tf.random.uniform and a tensor called tensor_big. I need to create a new tensor, tensor_small, that contains all the elements of tensor_big whose second coordinate (column index) appears in the indices list.
Example:
Indices = [1, ......]
Then I need a new tensor containing the weights at positions [0,1], [1,1], [2,1], etc., and the same for every other index in the list.
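Put another way, tensor_small is what you get by taking each selected column of tensor_big and stacking those columns side by side. Below is a slow but explicit sketch of that definition. The small 3×4 tensor and the indices [1, 3] are made-up illustrative values, not from the original question:

```python
import tensorflow as tf

# Illustrative values (assumed): 3 rows, 4 columns
tensor_big = tf.constant([[0., 1., 2., 3.],
                          [4., 5., 6., 7.],
                          [8., 9., 10., 11.]])
indices = [1, 3]  # hypothetical column indices

# For each index i, take column i (elements [0, i], [1, i], [2, i])
# and stack the selected columns side by side
columns = [tensor_big[:, i] for i in indices]
tensor_small = tf.stack(columns, axis=1)

print(tensor_small.numpy().tolist())
# [[1.0, 3.0], [5.0, 7.0], [9.0, 11.0]]
```

A Python loop like this works, but it is slow for 410 indices over 3136 rows. A single vectorized op that produces the same result is what the question is asking for.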
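```python
import tensorflow as tf

if __name__ == '__main__':
    # 3136 x 512 tensor of random floats in [0, 1)
    tensor_big = tf.random.uniform(
        (3136, 512), minval=0, maxval=None, dtype=tf.dtypes.float32, seed=None, name=None
    )
    # 410 random column indices in [0, 512)
    indices = tf.random.uniform(shape=[410, ], minval=0, maxval=512, dtype=tf.dtypes.int32, seed=None, name=None)

    # Iterates over the rows; weight[1] is the element at column 1 of each row
    for weight in tensor_big:
        print(weight[1])

    # Placeholder: this is the part I don't know how to write
    tensor_small = tf.reshape(tf.gather(tensor_big, WHERE_SECOND_COORDINATE_INSIDE_INDICES), (3136, 410))
    print(tensor_small)
```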
Answer:
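You can use tf.gather with the argument axis=1 to select the columns (the second coordinate) directly. The result already has shape (3136, 410), so no tf.reshape is needed: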
```python
tensor_small = tf.gather(tensor_big, indices, axis=1)
```
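A self-contained sketch of the answer using the shapes from the question, with a spot check that tensor_small[r, j] equals tensor_big[r, indices[j]]. The last part assumes you might want distinct columns. tf.random.uniform samples with replacement, so indices can contain repeats, and tf.gather simply copies a repeated column again. Whether repeats are acceptable is an assumption about your intent:

```python
import tensorflow as tf

tensor_big = tf.random.uniform((3136, 512), dtype=tf.float32)
indices = tf.random.uniform(shape=[410], minval=0, maxval=512, dtype=tf.int32)

# axis=1 gathers along the second coordinate: every row is kept,
# and column j of the result is column indices[j] of tensor_big
tensor_small = tf.gather(tensor_big, indices, axis=1)
print(tensor_small.shape)  # (3136, 410)

# Spot check for one column: tensor_small[:, j] == tensor_big[:, indices[j]]
j = 0
tf.debugging.assert_equal(tensor_small[:, j], tensor_big[:, indices[j]])

# Optional (assumes distinct columns are wanted): sample indices without
# replacement by shuffling 0..511 and keeping the first 410
distinct_indices = tf.random.shuffle(tf.range(512))[:410]
tensor_small_distinct = tf.gather(tensor_big, distinct_indices, axis=1)
print(tensor_small_distinct.shape)  # (3136, 410)
```

The output shape follows tf.gather's rule: the dimensions of tensor_big before axis, then the shape of indices, then the dimensions after axis. That gives (3136,) + (410,) + () = (3136, 410).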
