本文介绍了output_graph.pb上的tf.GraphKeys.TRAINABLE_VARIABLES导致空列表的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试从已保存的模型中提取所有权重/偏差 output_graph.pb

I'm trying to extract all the weights/biases from a saved model output_graph.pb.

我读了模型:

def create_graph(modelFullPath):
    """Creates a graph from saved GraphDef file and returns a saver."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

GRAPH_DIR = r'C:\tmp\output_graph.pb'
create_graph(GRAPH_DIR)

尝试这种希望,我将能够提取每一层中的所有权重/偏差

And attempted this hoping I would be able to extract all weights/biaseswithin each layer.

with tf.Session() as sess:
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    print (len(all_vars))

但是,我得到的值为0镜头

However, I'm getting a value of 0 as the len.

最终目标是提取权重和偏差并将其保存到文本文件/np.arrays中。

Final goal is to extract the weights and biases and save it to a text file/np.arrays.

推荐答案

tf.import_graph_def()函数不足信息以重构 tf.GraphKeys.TRAINABLE_VARIABLES 集合(为此,您需要)。但是,如果 output.pb 包含冻结的 GraphDef ,则所有权重将存储在节点在图中。要提取它们,可以执行以下操作:

The tf.import_graph_def() function doesn't have enough information to reconstruct the tf.GraphKeys.TRAINABLE_VARIABLES collection (for that, you would need a MetaGraphDef). However, if output.pb contains a "frozen" GraphDef, then all of the weights will be stored in tf.constant() nodes in the graph. To extract them, you can do something like the following:

create_graph(GRAPH_DIR)

constant_values = {}

with tf.Session() as sess:
  constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
  for constant_op in constant_ops:
    constant_values[constant_op.name] = sess.run(constant_op.outputs[0])

请注意, constant_values 可能会包含比权重更多的值,因此您可能需要进一步过滤 op.name 或其他条件。

Note that constant_values will probably contain more values than just the weights, so you may need to filter further by op.name or some other criterion.

这篇关于output_graph.pb上的tf.GraphKeys.TRAINABLE_VARIABLES导致空列表的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-29 07:33