Why does TensorFlow's tf.reshape() break the flow of gradients?

Question

I create a tf.Variable() and define a simple function of that variable. I then flatten the original variable with tf.reshape() and call tf.gradients() on the function with respect to the flattened variable. Why does that return [None]?

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph mode

var = tf.Variable(np.ones((5,5)), dtype = tf.float32)
f = tf.reduce_sum(tf.reduce_sum(tf.square(var)))
var_f = tf.reshape(var, [-1])
print(tf.gradients(f, var_f))

When executed, the code block above returns [None]. Is this a bug? Please help!

Recommended answer

You are taking the derivative of f with respect to var_f, but f is not a function of var_f; it is a function of var. The reshape creates a new tensor downstream of var, and nothing in f depends on that tensor, so there is no path to differentiate along. That is why you get [None]. Now, if you change the code to:

 var = tf.Variable(np.ones((5,5)), dtype = tf.float32)
 var_f = tf.reshape(var, [-1])
 f = tf.reduce_sum(tf.reduce_sum(tf.square(var_f)))
 grad = tf.gradients(f,var_f)
 print(grad)

Your gradient will be defined:

<tf.Tensor 'gradients_28/Square_32_grad/mul_1:0' shape=(25,) dtype=float32>
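The same ordering rule (flatten first, then build f from the flattened tensor) can also be tried in eager mode. The following is only a minimal sketch, assuming TensorFlow 2.x and its tf.GradientTape API, neither of which appears in the original question; the names grad_flat and grad_var are illustrative.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x with eager execution

var = tf.Variable(np.ones((5, 5)), dtype=tf.float32)

# persistent=True lets us ask the same tape for more than one gradient.
with tf.GradientTape(persistent=True) as tape:
    var_f = tf.reshape(var, [-1])         # flatten first ...
    f = tf.reduce_sum(tf.square(var_f))   # ... then build f FROM var_f

# f now depends on var_f, so this gradient is defined.
grad_flat = tape.gradient(f, var_f)
# f also depends on var (through the reshape), so this is defined too.
grad_var = tape.gradient(f, var)
del tape  # release the resources held by the persistent tape

print(grad_flat.shape)
print(grad_var.shape)
```

Because f is computed from var_f here, both gradients exist; only their shapes differ (flattened versus 5x5).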

The graph visualization for the following code is given below:

 var = tf.Variable(np.ones((5,5)), dtype = tf.float32, name='var')
 f = tf.reduce_sum(tf.reduce_sum(tf.square(var)), name='f')
 var_f = tf.reshape(var, [-1], name='var_f')
 grad_1 = tf.gradients(f,var_f, name='grad_1')
 grad_2 = tf.gradients(f,var, name='grad_2')

grad_1 is not defined, while grad_2 is defined. The back-propagation graphs (gradient graphs) of the two gradients are shown.
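If the goal is simply a flattened gradient while f stays defined in terms of var, one option is to differentiate with respect to var and reshape the result afterwards. The sketch below is illustrative rather than part of the original answer: it assumes a TensorFlow version that provides tf.compat.v1 (1.13 or later, including 2.x), and the names graph, grad_wrt_reshaped, grad_wrt_var and flat_grad are made up for this example.

```python
import numpy as np
import tensorflow as tf

# Build everything inside an explicit graph so that tf.gradients can be
# used even when eager execution is the default (assumption: tf.compat.v1
# is available).
graph = tf.Graph()
with graph.as_default():
    var = tf.Variable(np.ones((5, 5)), dtype=tf.float32, name='var')
    f = tf.reduce_sum(tf.square(var), name='f')

    # var_f is a new tensor computed FROM var; f does not depend on it.
    var_f = tf.reshape(var, [-1], name='var_f')

    grad_wrt_reshaped = tf.gradients(f, var_f)  # no path from var_f to f
    grad_wrt_var = tf.gradients(f, var)         # var is an input of f

    # Differentiate w.r.t. var, then flatten the gradient itself.
    flat_grad = tf.reshape(grad_wrt_var[0], [-1])

    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    flat_grad_value = sess.run(flat_grad)

print(grad_wrt_reshaped)      # [None]
print(flat_grad_value.shape)  # (25,)
```

Reshaping the gradient rather than the variable keeps f connected to the tensor being differentiated, which is the condition tf.gradients needs.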


09-23 01:58