Problem description
I am trying to feed minibatches of numpy arrays to my model, but I'm stuck with batching. Using 'tf.train.shuffle_batch' raises an error because the 'images' array is larger than 2 GB. I tried to go around it and create placeholders, but when I try to feed the arrays they are still represented by tf.Tensor objects. My main concern is that I defined the operations under the model class and the objects don't get called before running the session. Does anyone have an idea how to handle this issue?
def main(mode, steps):
    config = Configuration(mode, steps)
    if config.TRAIN_MODE:
        images, labels = read_data(config.simID)
        assert images.shape[0] == labels.shape[0]
        images_placeholder = tf.placeholder(images.dtype,
                                            images.shape)
        labels_placeholder = tf.placeholder(labels.dtype,
                                            labels.shape)
        dataset = tf.data.Dataset.from_tensor_slices(
            (images_placeholder, labels_placeholder))
        # shuffle
        dataset = dataset.shuffle(buffer_size=1000)
        # batch
        dataset = dataset.batch(batch_size=config.batch_size)
        iterator = dataset.make_initializable_iterator()
        image, label = iterator.get_next()
        model = Model(config, image, label)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer,
                     feed_dict={images_placeholder: images,
                                labels_placeholder: labels})
            # ...
            for step in xrange(steps):
                sess.run(model.optimize)
Recommended answer
You are using the initializable iterator of tf.data to feed data to your model. This means that you can parametrize the dataset in terms of placeholders, and then call an initializer op for the iterator to prepare it for use.
If you use the initializable iterator, or any other iterator from tf.data, to feed inputs to your model, you should not use the feed_dict argument of sess.run to do the data feeding. Instead, define your model in terms of the outputs of iterator.get_next() and omit feed_dict from sess.run.
Something along these lines:
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()

# use get_next outputs to define the model
model = Model(config, image_batch, label_batch)

# placeholders are fed while initializing the iterator
sess.run(iterator.initializer,
         feed_dict={images_placeholder: images,
                    labels_placeholder: labels})

for step in xrange(steps):
    # the iterator feeds image and label batches in the background
    sess.run(model.optimize)
The iterator will feed data to your model in the background; additional feeding via feed_dict is not necessary.
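As a rough illustration of what dataset.shuffle(buffer_size).batch(batch_size) does behind the scenes, here is a pure-Python sketch (not TensorFlow code; shuffle_batch is a hypothetical helper written for this example): tf.data keeps a buffer of buffer_size elements, draws a random element from the buffer as it refills from the stream, and groups the drawn elements into batches.

```python
import random

def shuffle_batch(stream, buffer_size, batch_size, seed=0):
    # Keep a sliding buffer of up to `buffer_size` elements; once it is
    # full, emit a uniformly chosen element and refill from the stream.
    rng = random.Random(seed)
    buffer, batches, batch = [], [], []
    for item in stream:
        buffer.append(item)
        if len(buffer) > buffer_size:
            batch.append(buffer.pop(rng.randrange(len(buffer))))
            if len(batch) == batch_size:
                batches.append(batch)
                batch = []
    # Drain whatever is left in the buffer at the end of the stream.
    while buffer:
        batch.append(buffer.pop(rng.randrange(len(buffer))))
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
    if batch:
        batches.append(batch)  # possibly smaller final batch
    return batches

batches = shuffle_batch(range(10), buffer_size=4, batch_size=3)
```

Note that this is why buffer_size matters: only elements currently in the buffer can be swapped, so a buffer much smaller than the dataset gives only a weak, local shuffle.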