我想通过重用现有神经网络(已经受过训练)的较低层来使用TensorFlow训练新的神经网络。我想删除现有网络的顶层,并用新层替换它们,我还想锁定最低层,以防止反向传播修改它们。以下是一些ascii艺术,可以对此进行总结:
*Original model* *New model*
Output Layer Output Layer (new)
| |
Hidden Layer 3 Hidden Layer 3 (copied)
| ==> |
Hidden Layer 2 Hidden Layer 2 (copied+locked)
| |
Hidden Layer 1 Hidden Layer 1 (copied+locked)
| |
Inputs Inputs
有什么好方法吗?
编辑
我的原始网络是这样创建的:
X = tf.placeholder(tf.float32, shape=(None, 500), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")
hidden1 = fully_connected(X, 300, scope="hidden1")
hidden2 = fully_connected(hidden1, 100, scope="hidden2")
hidden3 = fully_connected(hidden2, 50, scope="hidden3")
output = fully_connected(hidden3, 5, activation_fn=None, scope="output)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y)
loss = tf.reduce_mean(xentropy, name="loss")
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
training_op = optimizer.minimize(loss)
init = tf.initialize_all_variables()
saver = tf.train.Saver()
# ... Train then save the network using the saver
将加载此网络,锁定2个较低的隐藏层并替换输出层的代码是什么?如果可能的话,能够为每个输入缓存最高锁定层的输出(hidden2)将是很棒的,以加快训练速度。
额外详细信息
我查看了retrain.py和相应的How-To(非常有趣的读物)。该代码基本上会加载原始模型,然后为每个输入计算瓶颈层(即输出层之前的最后一个隐藏层)的输出。然后,它将创建一个全新的模型,并使用瓶颈输出作为输入对其进行训练。这基本上回答了关于复制+锁定层的问题:我只需要在整个训练集上运行原始模型,并存储最顶层锁定层的输出即可。但是我不知道如何处理复制但未锁定(例如,可训练)的层(例如,图中的隐藏第3层)。
谢谢!
最佳答案
TensorFlow使您可以对在每个训练步骤中更新的参数集(Variable
)进行精细控制。例如,在您的模型中,假设这些层都是完全连接的层。然后,您将为每个图层具有一个权重参数和biass参数。假设您在Variable
,W1
,b1
,W2
,b2
,W3
,b3
和Woutput
中具有相应的boutput
对象。假设您正在使用 Optimizer
接口(interface),并假定loss
是您要最小化的值,则只能通过执行以下操作来训练隐藏层和输出层:
opt = GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss, var_list=[W3, b3, Woutput, boutput])
train_op = opt.apply_gradients(grads_and_vars)
注意:
opt.minimize(loss, var_list)
与上述功能等效,但我将其分成两部分以说明细节。opt.compute_gradients
根据特定的模型参数集计算梯度,并且可以完全控制要考虑的模型参数。请注意,您必须从较早的模型初始化“隐藏第3层”参数,然后随机初始化“输出”层参数。您可以通过从原始模型还原新模型(从原始模型复制所有参数),并添加额外的tf.assign
操作来随机初始化输出层参数来实现。关于neural-network - 如何使用TensorFlow重用现有的神经网络来训练新的神经网络?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/38978972/