本文介绍了python2.7中tf.gather_nd中的星号出现语法错误的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我使用的是Python2.7,我无法更新它,我有这行代码,它在星号处引发错误,我不知道为什么?以及如何修复!

inp = tf.random.uniform(shape=[4, 6, 2], maxval=20, dtype=tf.int32)out = tf.math.reduce_max(inp,axis=2)am = tf.math.argmax(out,axis=1)o = tf.gather_nd(inp, [*enumerate(am)])

此代码是关于使用 TensorFlow 1.14 根据最大​​值一值从 3D 张量获取 2D 最大值张量.如下图所示:

解决方案

您问题中的语法错误已由 BoarGules 解释

一>.对于您要解决的问题,您可以通过以下方式获得您想要的结果:

将tensorflow导入为tf使用 tf.Graph().as_default(), tf.Session() 作为 sess:# 在 TF 2.x 中:tf.random.set_seedtf.random.set_random_seed(0)# 输入数据inp = tf.random.uniform(shape=[4, 6, 2], maxval=100, dtype=tf.int32)# 查找最后两个维度中最大值的索引s = tf.shape(inp)inp_res = tf.reshape(inp, [s[0], -1])max_idx = tf.math.argmax(inp_res,axis=1,output_type=s.dtype)# 获取行索引除以列数max_row_idx = max_idx//s[2]# 获取具有最大值的行res = tf.gather_nd(inp, tf.expand_dims(max_row_idx,axis=1),batch_dims=1)# 打印输入和结果打印(*sess.run((inp, res)), sep='\n')

输出:

[[[22 78][75 70][31 10][67 9][70 45][5 33]][[82 83][82 81][73 58][18 18][57 11][50 71]][[84 55][80 72][93 1][98 27][36 6][10 95]][[83 24][19 9][46 48][90 87][50 26][55 62]]][[22 78][82 83][98 27][90 87]]

I am using Python2.7, and I can't update it, and I have this line of code, which raise an error at the asterisk, and I don't know why? And how to fix!

inp = tf.random.uniform(shape=[4, 6, 2], maxval=20, dtype=tf.int32)

out = tf.math.reduce_max(inp, axis=2)
am = tf.math.argmax(out, axis=1)
o = tf.gather_nd(inp, [*enumerate(am)])

This code is about getting a 2D max Tensor from a 3D Tensor based on the maximum one value using TensorFlow 1.14. Like the image below illustrate:

The syntax error in your question has been explained by BoarGules. With respect to the problem that you are trying to solve, you can get the result you want with something like this:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # In TF 2.x: tf.random.set_seed
    tf.random.set_random_seed(0)
    # Input data
    inp = tf.random.uniform(shape=[4, 6, 2], maxval=100, dtype=tf.int32)

    # Find index of greatest value in last two dimensions
    s = tf.shape(inp)
    inp_res = tf.reshape(inp, [s[0], -1])
    max_idx = tf.math.argmax(inp_res, axis=1, output_type=s.dtype)
    # Get row index dividing by number of columns
    max_row_idx = max_idx // s[2]
    # Get rows with max values
    res = tf.gather_nd(inp, tf.expand_dims(max_row_idx, axis=1), batch_dims=1)
    # Print input and result
    print(*sess.run((inp, res)), sep='\n')

Output:

[[[22 78]
  [75 70]
  [31 10]
  [67  9]
  [70 45]
  [ 5 33]]

 [[82 83]
  [82 81]
  [73 58]
  [18 18]
  [57 11]
  [50 71]]

 [[84 55]
  [80 72]
  [93  1]
  [98 27]
  [36  6]
  [10 95]]

 [[83 24]
  [19  9]
  [46 48]
  [90 87]
  [50 26]
  [55 62]]]
[[22 78]
 [82 83]
 [98 27]
 [90 87]]

这篇关于python2.7中tf.gather_nd中的星号出现语法错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-29 02:17