概述

tensorflow的自带例程用两个文件演示了“全连接层+softmax回归”实现mnist图片识别的功能。一个文件是mnist.py,在之前一篇文章《tensorflow 11:双隐层+softmax回归实现mnist图片识别》已经介绍过了。不过mnist.py侧重搭建计算图,没有调用过程。

本文讲解fully_connected_feed.py这个文件,主要讲解调用过程相关的知识点。

fully_connected_feed.py速览

fully_connected_feed.py所做的工作除了正常的参数解析、建立图、读取数据、循环训练,还包括了保存图结构、信息汇总、保存检查点文件。

参数解析

fully_connected_feed.py把模型的(超)参数全部解析到全局变量FLAGS里面,然后其它地方用FLAGS获取用户传参。

参数解析相关知识请看之前的一篇博文《tensorflow 9. 参数解析和经典入口函数tf.app.run》

状态可视化

tensorflow提供了一个工具可视化训练过程,该工具叫tensorboard。为tensorboard提供的信息一般分两类:图结构和即时信息。

保存图结构

为了保存图结构,在计算图简历完毕以后,只要实例化一个SummaryWriter就行了。

# 实例化一个 SummaryWriter 输出 summaries 和 Graph.
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

只要调用一次summary_writer.flush()或者summary_writer.close(), 计算图的结构就会写入tensorboard的日志文件中。
如下:
tensorflow 12:双隐层+softmax回归实现mnist图片识别之二-LMLPHP

name_scope

可以看到上图的结构是分层显示的,双击一个模块,会显示模块内的详细信息。双击隐层1之后效果如下:
tensorflow 12:双隐层+softmax回归实现mnist图片识别之二-LMLPHP

这种分层的效果是用name_scope实现的:

  with tf.name_scope('hidden1'):
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
                            name='weights')
    biases = tf.Variable(tf.zeros([hidden1_units]), name='biases')
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)

在结构图上,所有with tf.name_scope(‘hidden1’)模块内的节点都会包含在‘hidden1’方块内。可以嵌套使用name_scope。

即时信息(汇总)

除了图结构,我们还关心训练过程的loss值等即时信息。也可以把这些信息写入tensorboard日志文件,然后从网页观察。汇总即时信息需要以下几步。

构建计算图时

为了保存即时信息,在构建计算图时需要构建即时信息的节点,并在最优化时指定global_step。比如loss信息的添加方式如下:

def training(loss, learning_rate):
  # 为保存loss的值添加一个标量汇总(scalar summary).
  tf.summary.scalar('loss', loss)
  # 根据给定的学习率创建梯度下降优化器
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # 创建一个变量来跟踪global step.
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # 在训练节点,使用optimizer将梯度下降法应用到可调参数上来最小化损失
  # (同时不断增加 global step 计数器) .
  train_op = optimizer.minimize(loss=loss,global_step=global_step)
  return train_op

除了标量,还可以汇总直方图、图片、音频等数据。

构建计算图之后

所有的即时数据(在这里只有一个)都要在图表构建阶段合并至一个操作(op)中。

summary_op = tf.merge_all_summaries()

session建立之后

在创建好会话(session)之后,可以实例化一个tf.train.SummaryWriter,用于写入包含了图表本身和即时数据具体值的事件文件。

# 实例化一个 SummaryWriter 输出 summaries 和 Graph.
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

运行summary_op

最后,每次运行summary_op时,都会往事件文件中写入最新的即时数据,函数的输出会传入SummaryWriter的add_summary()函数,并指定当前的global_step。

if step % 100 == 0:
    # 更新事件文件.还是调用sess.run()方法
    summary_str = sess.run(summary, feed_dict=feed_dict)
    summary_writer.add_summary(summary_str, global_step=step)
    summary_writer.flush()

可视化总结

借用《TensorFlow实现全连接fully_connected_feed.py》中的一张图:
tensorflow 12:双隐层+softmax回归实现mnist图片识别之二-LMLPHP

即时信息的汇总效果如下:

tensorflow 12:双隐层+softmax回归实现mnist图片识别之二-LMLPHP

保存检查点(checkpoint)

为了得到可以用来后续恢复模型以进一步训练或评估的检查点文件(checkpoint file),我们实例化一个tf.train.Saver。这一步不依赖计算图的建立状态。

saver = tf.train.Saver(max_to_keep=5)

saver默认只保留最近的5个ckpt文件,可以通过max_to_keep来改变。

在训练循环中,将定期(隔一定步数)调用saver.save()方法,向训练文件夹中写入包含了当前所有可训练变量值得检查点文件。

      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)

保存ckpt目录下的文件如下:

tensorflow 12:双隐层+softmax回归实现mnist图片识别之二-LMLPHP

这样,我们以后(暂停或断电后)就可以使用saver.restore()方法,重载模型的参数,继续训练。

saver.restore(sess, FLAGS.train_dir)

完整代码

上面说了这么多,还没上完整的fully_connected_feed.py代码。这个是我修改注释过的:

"""Trains and Evaluates the MNIST network using a feed dictionary."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=missing-docstring
import argparse
import os
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
import mnist_softmax_2hidden

# 全局变量,用来存放基本的模型(超)参数.
FLAGS = None

# 产生 placeholder variables 来表达输入张量
def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.

  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop, below.

  Args:
    batch_size: The batch size will be baked into both placeholders.

  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         mnist_softmax_2hidden.IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
  return images_placeholder, labels_placeholder

# 填充 feed_dict 用于一个指定的训练阶段(given training step)
def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict


def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """运行一个回合(one epoch)的评估过程.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = float(true_count) / num_examples
  print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))


def run_training():
  """对MNIST网络训练指定的次数(一次训练称为一个training step)"""
  # 获取用于训练,验证和测试的图像数据以及类别标签集合
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

  # 告诉TensorFlow,模型将会被构建在默认的Graph上.
  with tf.Graph().as_default():
    # 为图像特征向量数据和类标签数据创建输入占位符
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # 从前向推断模型中构建用于预测的计算图
    logits = mnist_softmax_2hidden.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # 为计算图添加计算损失的节点.
    loss = mnist_softmax_2hidden.loss(logits, labels_placeholder)

    # 为计算图添加计算和应用梯度的训练节点
    train_op = mnist_softmax_2hidden.training(loss, FLAGS.learning_rate)

    # 添加节点用于在评估过程中比较 logits 和 ground-truth labels .
    eval_correct = mnist_softmax_2hidden.evaluation(logits, labels_placeholder)

    # 基于 TF collection of Summaries构建汇总张量.
    summary = tf.summary.merge_all()

    # 添加变量初始化节点(variable initializer Op).
    init = tf.global_variables_initializer()

    # 创建一个 saver 用于写入 训练过程中的模型的检查点文件(checkpoints).
    saver = tf.train.Saver(max_to_keep=5)

    # 创建一个会话用来运行计算图中的节点
    sess = tf.Session()

    # 实例化一个 SummaryWriter 输出 summaries 和 Graph.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # And then after everything is built:

    # 运行初始化节点来初始化所有变量(Variables).
    sess.run(init)

    # 开启训练循环.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # 使用真实的图像和类标签数据集填充 feed dictionary
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # 在当前批次样本上把模型运行一步(run one step).
      # 返回值是从`train_op`和`loss`节点拿到的activations
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # 每隔100个批次就写入summaries并输出overview
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # 更新事件文件.还是调用sess.run()方法
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, global_step=step)
        summary_writer.flush()

      # 周期性的保存一个检查点文件并评估当前模型的性能
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # 在所有训练集上评估模型
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # 在验证集上评估模型.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # 在测试集上评估模型.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)

# 创建日志文件夹,启动训练过程
def main(_):
  if tf.gfile.Exists(FLAGS.log_dir):
    tf.gfile.DeleteRecursively(FLAGS.log_dir)
  tf.gfile.MakeDirs(FLAGS.log_dir)
  run_training()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--max_steps',
      type=int,
      default=2000,
      help='Number of steps to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.  Must divide evenly into the dataset sizes.'
  )
  parser.add_argument(
      '--input_data_dir',
      type=str,
      default='./data',
      help='Directory to put the input data.'
  )
  parser.add_argument(
      '--log_dir',
      type=str,
      default='./logs',
      help='Directory to put the log data.'
  )
  parser.add_argument(
      '--fake_data',
      default=False,
      help='If true, uses fake data for unit testing.',
      action='store_true'
  )

  # 把模型的(超)参数全部解析到全局变量FLAGS里面
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

参考资料

TensorFlow运作方式入门

TensorFlow实现全连接fully_connected_feed.py

tensorflow 11:双隐层+softmax回归实现mnist图片识别

tensorflow 9. 参数解析和经典入口函数tf.app.run

10-07 09:01