How to flatten a gradient (list of tensors) in tf.function (graph mode)

Time:01-18

I want to do some linear algebra (e.g. tf.matmul) with the gradient. By default, tape.gradient returns the gradient as a list of tensors, and those tensors may have different shapes. My solution has been to reshape the gradient into a single vector. This works in eager mode, but now I want to compile my code with tf.function, and it seems there is no way to write a function that 'flattens' the gradient in graph mode.

import tensorflow as tf

grad = [tf.ones((2,10)), tf.ones((3,))]  # an example of what a gradient from tape.gradient can look like

# this works for flattening the gradient in eager mode only
def flatten_grad(grad):
    return tf.concat([tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))) for i in range(len(grad))], 0)

I tried converting it like this, but it doesn't work with tf.function either.

@tf.function
def flatten_grad1(grad):
    temp = [None]*len(grad)
    for i in tf.range(len(grad)):
        i = tf.cast(i, tf.int32)
        temp[i] = tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i])))
    return tf.concat(temp, 0)

I tried TensorArrays, but it also does not work.

@tf.function
def flatten_grad2(grad):
    temp = tf.TensorArray(tf.float32, size=len(grad), infer_shape=False)
    for i in tf.range(len(grad)):
        i = tf.cast(i, tf.int32)
        temp = temp.write(i, tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))))
    return temp.concat()

Answer:

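Maybe you could try iterating directly over your list of tensors instead of fetching individual tensors by index. Inside tf.function, AutoGraph turns `for i in tf.range(...)` into a graph loop, so `i` is a symbolic tensor and cannot be used to index a Python list. A plain Python `for` over the list, on the other hand, is simply unrolled when the function is traced: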

import tensorflow as tf

grad = [tf.ones((2,10)), tf.ones((3,))]  # an example of what a gradient from tape.gradient can look like

@tf.function
def flatten_grad1(grad):
    temp = [None]*len(grad)
    for i, g in enumerate(grad):
        temp[i] = tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), ))
    return tf.concat(temp, axis=0)
print(flatten_grad1(grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)
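If you prefer to keep the index-based loop from your flatten_grad1, a minimal variant (a sketch, assuming a TF 2.x version with AutoGraph) is to use Python's built-in range instead of tf.range: len(grad) is an ordinary int at trace time, so AutoGraph leaves the loop as plain Python and i stays a Python int. The name flatten_grad_indexed is only illustrative, and tf.reshape(g, [-1]) is used as shorthand for reshaping to (reduce_prod(shape),).

```python
import tensorflow as tf

grad = [tf.ones((2, 10)), tf.ones((3,))]

@tf.function
def flatten_grad_indexed(grad):
    temp = [None] * len(grad)
    # Python `range`, not `tf.range`: len(grad) is a plain int at trace time,
    # so this loop is unrolled and `i` is a Python int that can index lists.
    for i in range(len(grad)):
        temp[i] = tf.reshape(grad[i], [-1])  # [-1] flattens to 1-D
    return tf.concat(temp, axis=0)

flat = flatten_grad_indexed(grad)
print(flat.shape)  # (23,)
```

Because the loop is unrolled, tf.function will likely retrace if you later pass a list with a different number of tensors or different shapes.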

With tf.TensorArray:

@tf.function
def flatten_grad2(grad):
    temp = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
    for g in grad:
        temp = temp.write(temp.size(), tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), )))
    return temp.concat()

print(flatten_grad2(grad))
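To connect this back to the linear algebra in the question, here is a sketch that computes a gradient with tf.GradientTape inside tf.function, flattens it, and applies tf.matmul to it. The toy variables w and b, the loss, and the helper unflatten are all hypothetical; unflatten is one possible way to map a flat vector back to the per-variable shapes, using tf.split with each variable's static element count.

```python
import tensorflow as tf

# Hypothetical toy variables shaped like the example gradient: (2, 10) and (3,).
w = tf.Variable(tf.ones((2, 10)))
b = tf.Variable(tf.ones((3,)))
variables = [w, b]

@tf.function
def flat_grad_and_norm_sq():
    with tf.GradientTape() as tape:
        # Hypothetical loss chosen so that dL/dw = 2*w and dL/db = 3.
        loss = tf.reduce_sum(w ** 2) + tf.reduce_sum(3.0 * b)
    grads = tape.gradient(loss, variables)
    # Python loop over a Python list: unrolled at trace time.
    flat = tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0)
    # Reshape to a column vector so tf.matmul can compute g^T g (a 1x1 matrix).
    col = tf.reshape(flat, [-1, 1])
    norm_sq = tf.matmul(col, col, transpose_a=True)
    return flat, norm_sq

def unflatten(flat, like):
    # Static element count of each variable, as Python ints.
    sizes = [v.shape.num_elements() for v in like]
    pieces = tf.split(flat, sizes, axis=0)
    # Restore each piece to the shape of its variable.
    return [tf.reshape(p, v.shape.as_list()) for p, v in zip(pieces, like)]

flat, norm_sq = flat_grad_and_norm_sq()
print(flat.shape)            # (23,)
print(float(norm_sq[0, 0]))  # 107.0, i.e. 20 * 2**2 + 3 * 3**2
pieces = unflatten(flat, variables)
print([p.shape for p in pieces])
```

The list returned by unflatten has the same structure as the output of tape.gradient, so it should be usable wherever that list is expected, for example zipped with the variables in an optimizer's apply_gradients.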

Answer:

Hi, I think the biggest problem is the loops, since computing with Python loops is generally discouraged.

Here's an example of how to flatten using TF functions. Your gradient variables look kind of weird, though; normally they should have a consistent shape with a batch dimension.

import tensorflow as tf
import numpy as np

@tf.function
def flatten(arr):
    # Keep the first (batch) dimension and flatten all remaining ones into one.
    dim = tf.math.reduce_prod(tf.shape(arr)[1:])
    return tf.reshape(arr, [-1, dim])

grad = tf.Variable(np.random.randn(100, 10, 10, 3))

flatten_grad = flatten(grad)
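Gradients returned by tape.gradient take the shape of their variables, so there is usually no shared leading batch axis to preserve. If the goal is a single vector, one possible alternative (a suggestion, not part of the answer above) is to apply tf.reshape(g, [-1]) to every leaf with tf.nest.map_structure and concatenate the results; this also handles nested lists, tuples or dicts. It still loops internally, it just saves writing the loop by hand. The name flatten_nested is hypothetical.

```python
import tensorflow as tf

@tf.function
def flatten_nested(grads):
    # Flatten every leaf tensor to 1-D while keeping the container structure.
    flat_leaves = tf.nest.map_structure(lambda g: tf.reshape(g, [-1]), grads)
    # tf.nest.flatten returns the leaves as a list in a deterministic order.
    return tf.concat(tf.nest.flatten(flat_leaves), axis=0)

grad = [tf.ones((2, 10)), tf.ones((3,))]
print(flatten_nested(grad).shape)  # (23,)

nested = {"dense": [tf.ones((2, 2)), tf.ones((2,))], "bias": tf.ones((3,))}
print(flatten_nested(nested).shape)  # (9,)
```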