How to check if one `string` tensor is alphabetically greater than another `string` tensor in TensorFlow

Time:01-16

I know how to check the equality of two string tensors. Any of these works just fine:

tf.constant('foo') == 'foo' or

tf.math.equal(tf.constant('foo'), 'foo') or

tf.equal(tf.constant('foo'), 'foo')

However, I couldn't find a way to check whether one string tensor is alphabetically greater or lower than another one.

None of these works, and all of them produce the same error:

tf.constant('foo') > 'foo' or

tf.math.greater(tf.constant('foo'), 'foo') or

tf.greater(tf.constant('foo'), 'foo')

error message:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Value for attr 'T' of string is not in the list of allowed values: float, double, int32, uint8, int16, int8, int64, bfloat16, uint16, half, uint32, uint64 ; NodeDef: {{node Greater}}; Op<name=Greater; signature=x:T, y:T -> z:bool; attr=T:type,allowed=[DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64]> [Op:Greater]

Answer:

You could use tf.strings with tf.lookup.StaticHashTable:

import tensorflow as tf
import string

foo = tf.constant('foo')
bar = tf.constant('bar')

foo = tf.strings.split(tf.strings.regex_replace(foo, "(.)", r'\1 '))
bar = tf.strings.split(tf.strings.regex_replace(bar, "(.)", r'\1 '))

# map 'a'..'z' to 0..25; any other character falls back to default_value=-1
keys_tensor = tf.constant([*string.ascii_lowercase])
vals_tensor = tf.range(26)

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
    default_value=-1)

print('foo -->', table.lookup(foo))
print('bar -->', table.lookup(bar))
print(tf.greater(table.lookup(foo), table.lookup(bar)))
foo --> tf.Tensor([ 5 14 14], shape=(3,), dtype=int32)
bar --> tf.Tensor([ 1  0 17], shape=(3,), dtype=int32)
tf.Tensor([ True  True False], shape=(3,), dtype=bool)

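From here you still have to turn the element-wise booleans into a single answer: the first position where the two codes differ decides the order (`f` > `b` at position 0, so 'foo' > 'bar'). Keep in mind that both code tensors must have the same length, so strings of different lengths need padding first, and that any character outside `a`..`z` maps to the default value `-1`.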
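One way to finish the comparison, sketched in plain Python on the integer codes rather than on the boolean vector (a `False` in `[True, True, False]` cannot tell "smaller" from "equal"): walk both code sequences, let the first differing position decide, and fall back to the lengths when one sequence is a prefix of the other. `compare_codes` is a hypothetical helper, not part of TensorFlow; with eager tensors you could feed it something like `table.lookup(foo).numpy().tolist()`.

```python
def compare_codes(codes_a, codes_b):
    """Hypothetical helper: return -1, 0 or 1 for the lexicographic order
    of two integer code sequences.

    The first position where the codes differ decides; if one sequence is
    a prefix of the other, the shorter one sorts first.
    """
    for x, y in zip(codes_a, codes_b):
        if x != y:
            return 1 if x > y else -1
    # all shared positions are equal: the shorter sequence is smaller
    return (len(codes_a) > len(codes_b)) - (len(codes_a) < len(codes_b))


# the codes printed above for 'foo' and 'bar'
print(compare_codes([5, 14, 14], [1, 0, 17]))  # 1, i.e. 'foo' > 'bar'
```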

Answer:

I'll post my own answer. I encourage other folks to chime in with their solutions.

import tensorflow as tf

def maybe_cast_to_tensor(maybe_tensor):
    if not isinstance(maybe_tensor, tf.Tensor):
        return tf.constant(maybe_tensor)
    return maybe_tensor

def is_lower(op1, op2):
    op1 = maybe_cast_to_tensor(op1)
    op2 = maybe_cast_to_tensor(op2)

    # decode each string into its raw bytes (uint8); comparing these byte by byte
    # gives byte-wise order, which for UTF-8 text matches code-point order
    op1 = tf.io.decode_raw(op1, tf.uint8)
    op2 = tf.io.decode_raw(op2, tf.uint8)

    max_len = tf.math.reduce_max([tf.shape(op1), tf.shape(op2)])

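    # pad `op1`, `op2` with zero bytes to the same length; 0 sorts before every
    # other byte, so a proper prefix compares as lower (assumes no NUL bytes in the input)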
    op1 = tf.pad(op1, [[0, max_len - tf.shape(op1)[0]]])
    op2 = tf.pad(op2, [[0, max_len - tf.shape(op2)[0]]])

    # the main logic
    # vectorized check for the existence of an index `k >= 0` such that:
    # 1. op1[i] == op2[i] for all i < k
    # 2. op1[k] < op2[k]
    # first, mask[i] is True iff positions 0..i are all equal
    prev_equal_mask = tf.math.cumsum(tf.cast(op1 == op2, tf.int32)) == tf.range(tf.shape(op1)[0]) + 1
    # shift right by one so position k looks at positions 0..k-1 (k = 0 has an empty prefix)
    prev_equal_mask = tf.concat([tf.constant(True, shape=(1,)), prev_equal_mask[:-1]], 0)
    return tf.reduce_any(prev_equal_mask & (op1 < op2))

print(is_lower('ab', 'cd')) # true
print(is_lower('cd', 'ab')) # false
print(is_lower('ab', 'ab')) # false
print(is_lower('ab', 'abc')) # true
print(is_lower('abcd', 'abe')) # true
print(is_lower('abcd', 'abb')) # false
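`is_lower` above handles one pair of scalar strings. A possible batched variant, sketched under the assumption that both inputs are 1-D string tensors with the same number of elements and valid UTF-8 content, swaps a few pieces: `tf.strings.unicode_decode` instead of `decode_raw` (UTF-8 byte order and code-point order agree, so the ordering is the same), `-1` padding instead of `0` so a proper prefix always sorts first, and an exclusive `tf.math.cumprod` in place of the cumsum-and-shift trick. The name `string_less` is hypothetical.

```python
import tensorflow as tf


def string_less(a, b):
    """Hypothetical sketch: element-wise lexicographic a < b for two 1-D
    string tensors. Assumes equal batch sizes and valid UTF-8 strings."""
    a = tf.convert_to_tensor(a, dtype=tf.string)
    b = tf.convert_to_tensor(b, dtype=tf.string)

    # decode to Unicode code points (ragged), then densify, padding with -1
    ca = tf.strings.unicode_decode(a, 'UTF-8').to_tensor(default_value=-1)
    cb = tf.strings.unicode_decode(b, 'UTF-8').to_tensor(default_value=-1)

    # pad both batches to a common width; -1 sorts before every code point,
    # so a proper prefix compares as smaller
    width = tf.maximum(tf.shape(ca)[1], tf.shape(cb)[1])
    ca = tf.pad(ca, [[0, 0], [0, width - tf.shape(ca)[1]]], constant_values=-1)
    cb = tf.pad(cb, [[0, 0], [0, width - tf.shape(cb)[1]]], constant_values=-1)

    # exclusive cumprod: 1 at position k iff positions 0..k-1 are all equal
    equal = tf.cast(tf.equal(ca, cb), tf.int32)
    prefix_equal = tf.math.cumprod(equal, axis=1, exclusive=True)

    # a < b iff at some k the prefix matches and a[k] < b[k]
    hits = tf.logical_and(tf.equal(prefix_equal, 1), tf.less(ca, cb))
    return tf.reduce_any(hits, axis=1)


result = string_less(['ab', 'cd', 'ab', 'ab', 'abcd', 'abcd'],
                     ['cd', 'ab', 'ab', 'abc', 'abe', 'abb'])
print(result.numpy().tolist())  # [True, False, False, True, True, False]
```

The `-1` padding also covers a corner case of the zero-byte padding in `is_lower`: there, `'a'` and `'a\x00'` end up identical after padding and compare as equal, whereas `string_less(['a'], ['a\x00'])` returns `[True]`.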