Let us use the following code:
#!/usr/bin/env python3
# encoding: utf-8
import numpy as np, tensorflow as tf # tf.__version__==2.7.0
sample_array=np.random.uniform(size=(2**10, 120, 20))
to_select=[5, 6, 9, 4]
sample_tensor=tf.convert_to_tensor(value=sample_array)
sample_array[:, :, to_select] # Works okay
sample_tensor[:, :, to_select] # TypeError: how do I do this with a tensor?
Basically, how do I get those elements as a tensor of the appropriate shape, just like in NumPy? I tried tf.slice and tf.gather, but I cannot figure out the proper arguments to pass.
I could convert it to NumPy and back, but I am not sure whether that would sacrifice efficiency, or whether it would still work as part of a custom training loop.
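The training-loop concern can be checked directly. The sketch below (a small stand-in tensor, not the question's exact data) compares a NumPy round trip with a pure TensorFlow selection inside a GradientTape: the NumPy version produces a fresh constant, so the tape cannot trace it back to the variable and the gradient comes back as None, while tf.gather along the last axis keeps the gradient path intact.

```python
import numpy as np
import tensorflow as tf

# Small stand-in for the question's data; shape (batch, time, features)
x = tf.Variable(np.random.uniform(size=(4, 3, 20)))
to_select = [5, 6, 9, 4]

with tf.GradientTape(persistent=True) as tape:
    # Round trip through NumPy: the result is a new constant tensor,
    # so the tape has no record linking it back to x
    via_numpy = tf.convert_to_tensor(x.numpy()[:, :, to_select])
    loss_numpy = tf.reduce_sum(via_numpy)
    # Pure TensorFlow selection along the last axis stays on the tape
    via_tf = tf.gather(x, to_select, axis=-1)
    loss_tf = tf.reduce_sum(via_tf)

grad_numpy = tape.gradient(loss_numpy, x)  # None: the gradient path is broken
grad_tf = tape.gradient(loss_tf, x)        # ones at the selected features, zeros elsewhere
```

Both versions produce the same values; only the TensorFlow one is usable when gradients must flow back through the selection.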
Answer:
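I have reduced the dimensions so the result is easier to inspect. Check first whether this gives what you want. Note that tf.gather selects along axis 0 by default, so pass axis=2 (or axis=-1) to select along the last dimension, as sample_array[:, :, to_select] does.
import numpy as np
import tensorflow as tf
sample_array=np.random.randint(100,size=( 10, 10, 20))
to_select=tf.constant([5, 6, 9, 4])
sample_tensor=tf.convert_to_tensor(value=sample_array)
print(sample_array[:, :, to_select.numpy()]) # NumPy reference result
print(tf.gather(sample_tensor, indices=to_select, axis=2)) # same values, shape (10, 10, 4)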
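Because tf.gather is an ordinary TensorFlow op, it also runs inside a tf.function, which is how a custom training step is usually compiled. A minimal sketch using the question's original shapes; the function name select_features is hypothetical, chosen only for illustration.

```python
import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(2**10, 120, 20))
to_select = tf.constant([5, 6, 9, 4])

@tf.function
def select_features(x, indices):
    # axis=-1 picks entries along the last dimension, like x[:, :, indices] in NumPy
    return tf.gather(x, indices, axis=-1)

result = select_features(tf.convert_to_tensor(sample_array), to_select)
print(result.shape)  # (1024, 120, 4)
```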
Answer:
One simple solution is to use tf.concat:
import numpy as np
import tensorflow as tf
sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value = sample_array)
numpy_way = sample_array[:, :, to_select]
tf_way = tf.concat([tf.expand_dims(sample_tensor[:, :, i], axis=-1) for i in to_select], axis=-1) # slice the tensor (not the NumPy array), one column per index
print(numpy_way)
print(tf_way)
[[[0.95155085 0.27463579 0.74310211 0.73047673]
[0.16477047 0.04026846 0.10771453 0.3344928 ]]
[[0.2969326 0.8663296 0.64625728 0.71089697]
[0.51603801 0.45761795 0.59975939 0.35596491]]]
tf.Tensor(
[[[0.95155085 0.27463579 0.74310211 0.73047673]
[0.16477047 0.04026846 0.10771453 0.3344928 ]]
[[0.2969326 0.8663296 0.64625728 0.71089697]
[0.51603801 0.45761795 0.59975939 0.35596491]]], shape=(2, 2, 4), dtype=float64)
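The expand_dims + concat combination can be collapsed into a single tf.stack call, which adds the new last axis itself. A sketch of that variant, reusing the answer's shapes:

```python
import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value=sample_array)

# Each sample_tensor[:, :, i] has shape (2, 2); stacking on a new last
# axis gives (2, 2, len(to_select)) without a separate expand_dims step
tf_stacked = tf.stack([sample_tensor[:, :, i] for i in to_select], axis=-1)
print(tf_stacked.shape)  # (2, 2, 4)
```

Both this and the concat version build one slice op per index, which is fine for a short, fixed Python list. When the indices are long or arrive as a tensor, a single tf.gather(..., axis=-1) call avoids the Python loop.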
A more complicated solution would combine tf.meshgrid and tf.gather_nd. Check this post or this post, and finally this one.
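The meshgrid + gather_nd route works by building an explicit (batch, time, feature) coordinate for every output element. A sketch under the assumption that the selection is along the last axis of a rank-3 tensor; it is not taken from the linked posts.

```python
import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(8, 5, 20))
sample_tensor = tf.convert_to_tensor(value=sample_array)
to_select = tf.constant([5, 6, 9, 4])  # int32, same dtype as tf.range below

shape = tf.shape(sample_tensor)
# One coordinate grid per output axis: batch index, time index, selected feature
b, t, k = tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]), to_select, indexing="ij")
# Full coordinates for every output element: shape (8, 5, 4, 3)
indices = tf.stack([b, t, k], axis=-1)
# gather_nd reads one scalar per coordinate triple: shape (8, 5, 4)
result = tf.gather_nd(sample_tensor, indices)
```

This generalizes to selecting different features per position, but for a single shared index list tf.gather with an axis argument does the same job with far less bookkeeping.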
