This article covers "Tensorflow: creating minibatches from numpy arrays > 2 GB". The question and answer below may serve as a useful reference for readers facing the same problem.

Problem Description

I am trying to feed minibatches of numpy arrays to my model, but I'm stuck with batching. Using 'tf.train.shuffle_batch' raises an error because the 'images' array is larger than 2 GB. I tried to go around it and create placeholders, but when I try to feed the arrays they are still represented by tf.Tensor objects. My main concern is that I defined the operations under the model class and the objects don't get called before running the session. Does anyone have an idea how to handle this issue?

import tensorflow as tf

def main(mode, steps):
    config = Configuration(mode, steps)

    if config.TRAIN_MODE:
        images, labels = read_data(config.simID)
        assert images.shape[0] == labels.shape[0]

        images_placeholder = tf.placeholder(images.dtype, images.shape)
        labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

        dataset = tf.data.Dataset.from_tensor_slices(
            (images_placeholder, labels_placeholder))

        # shuffle
        dataset = dataset.shuffle(buffer_size=1000)

        # batch
        dataset = dataset.batch(batch_size=config.batch_size)

        iterator = dataset.make_initializable_iterator()
        image, label = iterator.get_next()

        model = Model(config, image, label)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            sess.run(iterator.initializer,
                     feed_dict={images_placeholder: images,
                                labels_placeholder: labels})

            # ...

            for step in xrange(steps):
                sess.run(model.optimize)

Recommended Answer

You are using the initializable iterator of tf.data to feed data to your model. This means that you can parametrize the dataset in terms of placeholders, and then call the iterator's initializer op to prepare it for use.

If you use the initializable iterator, or any other iterator from tf.data, to feed inputs to your model, you should not try to feed data through the feed_dict argument of sess.run. Instead, define your model in terms of the outputs of iterator.get_next() and omit feed_dict from the sess.run calls that execute training steps (feed_dict is only used once, when running iterator.initializer).

Something along these lines:

iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()

# use get_next outputs to define model
model = Model(config, image_batch, label_batch)

# placeholders are fed once, while initializing the iterator
sess.run(iterator.initializer,
         feed_dict={images_placeholder: images,
                    labels_placeholder: labels})

for step in xrange(steps):
    # the iterator will feed image and label batches in the background
    sess.run(model.optimize)

The iterator will feed data to your model in the background; no additional feeding via feed_dict is necessary.
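If the tf.data pipeline is not an option, the shuffle-and-batch logic itself can also be done in plain numpy, feeding each batch through feed_dict into placeholders with a batch-sized first dimension (shape (None, ...)); the full arrays then stay in ordinary host memory and never hit the 2 GB graph-constant limit. The generator below is an illustrative sketch (the function name and parameters are my own, not part of the original answer):

```python
import numpy as np

def minibatches(images, labels, batch_size, shuffle=True, seed=None):
    """Yield (image_batch, label_batch) pairs covering one epoch.

    Only index arrays are shuffled; the data arrays themselves are
    never copied wholesale, so arbitrarily large arrays are fine.
    """
    assert images.shape[0] == labels.shape[0]
    n = images.shape[0]
    indices = np.arange(n)
    if shuffle:
        rng = np.random.RandomState(seed)
        rng.shuffle(indices)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        yield images[batch_idx], labels[batch_idx]
```

Each yielded pair would then be passed as feed_dict={image_ph: img_batch, label_ph: lbl_batch} in the training loop; the trade-off versus tf.data is that feeding happens synchronously in Python rather than in the background.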

