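I'm reimplementing the example from Deep MNIST for Experts, which builds a simple CNN and runs it on the MNIST dataset. However, I get an error on this line: h_conv1 = tf.nn.leaky_relu(conv2d(x_image, W_conv1) + b_conv1, alpha=1/3). The relevant part of the script is: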

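import tensorflow as tf

# conv2d and max_pool_2x2 as defined in the Deep MNIST for Experts tutorial
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)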

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# Real data
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# First layer pair
W_conv1 = weight_variable([3, 3, 1, 60])
b_conv1 = bias_variable([60])

x_image = tf.reshape(x, [-1, 28, 28, 1])

h_conv1 = tf.nn.leaky_relu(conv2d(x_image, W_conv1) + b_conv1, alpha=1/3)
h_pool1 = max_pool_2x2(h_conv1)


The traceback is:

Traceback (most recent call last):
  File "scn_mnist.py", line 118, in <module>
    h_conv1 = tf.nn.leaky_relu(conv2d(x_image, W_conv1) + b_conv1, alpha=1/3)
  File "/root/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/nn_ops.py", line 1543, in leaky_relu
    return math_ops.maximum(alpha * features, features)
  File "/root/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 885, in binary_op_wrapper
    y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
  File "/root/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 836, in convert_to_tensor
    as_ref=False)
  File "/root/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 926, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/root/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 774, in _TensorTensorConversionFunction
    (dtype.name, t.dtype.name, str(t)))
ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: 'Tensor("add:0", shape=(?, 28, 28, 60), dtype=float32)'
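The nn_ops.py frame in the traceback shows that leaky_relu is computed as maximum(alpha * features, features). A minimal pure-Python sketch of that formula, assuming plain lists of floats instead of TensorFlow tensors, makes the role of alpha concrete:

```python
# Pure-Python sketch of the formula shown in the nn_ops.py frame:
# leaky_relu(x) = max(alpha * x, x), applied elementwise.
# Assumption: inputs are plain Python floats, not TensorFlow tensors.


def leaky_relu(features, alpha):
    """Return max(alpha * x, x) for each x; a leaky ReLU when 0 <= alpha < 1."""
    return [max(alpha * x, x) for x in features]


xs = [-3.0, 0.0, 2.0]
print(leaky_relu(xs, 1.0 / 3))  # negative inputs are scaled by 1/3
print(leaky_relu(xs, 0))        # alpha = 0 collapses to a plain ReLU
```

Note that the max() form only behaves like a leaky ReLU for alpha below 1; with alpha = 0 the negative branch becomes zero.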


I checked the dtypes of x_image (tf.float32), W_conv1 (tf.float32_ref) and b_conv1 (tf.float32_ref). They all look fine, so I can't figure out what's wrong.

Best answer

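The problem is alpha. The script runs under Python 2.7 (see the paths in the traceback), where 1/3 is integer division and evaluates to the int 0. TensorFlow converts that int to an int32 tensor, and multiplying it by the float32 features then fails with the conversion error above. Pass a float instead, e.g. alpha=0.3, or alpha=1.0/3 to keep the intended value (adding from __future__ import division also works).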
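To see the failure without TensorFlow, here is a small emulation. It is not TensorFlow code: it assumes that Python 2 '/' floors when both operands are ints, and that TensorFlow 1.x turns Python ints into int32 tensors and Python floats into float32 tensors. The helpers py2_div, default_dtype and check_alpha are hypothetical names made up for this sketch.

```python
# Emulation (not TensorFlow code) of why alpha=1/3 failed under Python 2.7.
# Assumptions: Python 2 '/' on two ints floors to an int, and TF 1.x converts
# Python ints to int32 tensors and Python floats to float32 tensors.


def py2_div(a, b):
    """Hypothetical helper: mimic Python 2 '/', flooring when both are ints."""
    if isinstance(a, int) and isinstance(b, int):
        return a // b
    return a / b


def default_dtype(value):
    """Hypothetical helper: dtype TF 1.x would infer for a Python number."""
    return "int32" if isinstance(value, int) else "float32"


def check_alpha(alpha, features_dtype="float32"):
    """Hypothetical helper: raise a traceback-style error on dtype mismatch."""
    alpha_dtype = default_dtype(alpha)
    if alpha_dtype != features_dtype:
        raise ValueError(
            "Tensor conversion requested dtype %s for Tensor with dtype %s"
            % (alpha_dtype, features_dtype)
        )
    return alpha


alpha = py2_div(1, 3)
print(alpha, default_dtype(alpha))  # 0 int32

try:
    check_alpha(alpha)
except ValueError as err:
    print("fails:", err)

# Either float spelling passes the check.
for candidate in (0.3, 1.0 / 3):
    print(candidate, default_dtype(check_alpha(candidate)))
```

Of the two fixes, 1.0/3 keeps the slope the original code intended, while 0.3 is only an approximation of it.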

08-19 20:28