BlockingQueue是一个线程安全的队列,是caffe训练时用于数据同步的重要数据结构,本文对其做简要分析。


// Queue of T built on std::queue. The operations whose definitions are
// visible in this file (try_pop, pop) lock sync_->mutex_ before touching
// queue_; the remaining members are presumably guarded the same way --
// confirm against their definitions.
template<typename T>
class BlockingQueue {
public:
    explicit BlockingQueue();

    // NOTE(review): presumably appends t and wakes a thread blocked in
    // pop(); the definition is not shown here -- confirm.
    void push(const T& t);
    // Non-blocking: returns false immediately when the queue is empty;
    // otherwise copies the front element into *t, removes it and returns true.
    bool try_pop(T* t);
    // Blocking: waits until the queue is non-empty, then removes and returns
    // the front element.
    // This logs a message if the thread needs to be blocked (a non-empty
    // log_on_wait is emitted through LOG_EVERY_N(INFO, 1000)),
    // useful for detecting e.g. when data feeding is too slow
    T pop(const string& log_on_wait = "");
    // NOTE(review): presumably the non-blocking counterpart of peek();
    // the definition is not shown here -- confirm.
    bool try_peek(T* t);

    // Return element without removing it
    T peek();
    size_t size() const;

protected:
    class sync;  // forward declaration; defined with a mutex and a condition variable
    std::queue<T> queue_;
    shared_ptr<sync> sync_;
    DISABLE_COPY_AND_ASSIGN(BlockingQueue);
};

用于线程同步的互斥量与条件变量:


// Synchronization state of one BlockingQueue: the mutex that try_pop()/pop()
// lock, and the condition variable that pop() waits on while the queue is
// empty.
template<typename T>
class BlockingQueue<T>::
sync {
public:
    // Declared mutable so it can be locked through a const access path
    // (NOTE(review): presumably for size() const -- confirm).
    mutable boost::mutex mutex_;
    boost::condition_variable condition_;
};

构造函数:


// Constructs an empty queue and allocates the sync object (mutex plus
// condition variable) that the queue operations synchronize through.
template<typename T>
BlockingQueue<T>::BlockingQueue() : sync_(new sync()) {}

下面仅给出阻塞的pop与非阻塞的try_pop的实现,用以说明BlockingQueue的工作方式:

// Non-blocking pop.
// Takes the queue mutex, and if an element is available copies the front
// element into *t, removes it from the queue and returns true. When the
// queue is empty it returns false immediately and leaves *t untouched.
template<typename T>
bool BlockingQueue<T>::try_pop(T* t) {
    boost::mutex::scoped_lock guard(sync_->mutex_);

    const bool has_item = !queue_.empty();
    if (has_item) {
        *t = queue_.front();
        queue_.pop();
    }
    return has_item;  // false means "nothing to pop", returned without waiting
}

// Blocking pop.
// Takes the queue mutex and loops until the queue holds an element, then
// removes the front element and returns it by value. On each pass that
// finds the queue empty, a non-empty log_on_wait is reported through
// LOG_EVERY_N(INFO, 1000) before the thread waits on the condition variable.
template<typename T>
T BlockingQueue<T>::pop(const string& log_on_wait) {
    boost::mutex::scoped_lock guard(sync_->mutex_);

    while (true) {
        if (!queue_.empty()) {
            T head = queue_.front();
            queue_.pop();
            return head;
        }

        if (!log_on_wait.empty()) {
            LOG_EVERY_N(INFO, 1000)<< log_on_wait;
        }
        // Block on the condition variable; emptiness is re-checked after
        // every wake-up.
        sync_->condition_.wait(guard);
    }
}

模板显式实例化:

// Explicit instantiations (not specializations) of BlockingQueue for each
// element type listed below.
template class BlockingQueue<Batch<float>*>;
template class BlockingQueue<Batch<double>*>;
template class BlockingQueue<Datum*>;
template class BlockingQueue<shared_ptr<DataReader::QueuePair> >;

其中:

// One prefetched batch: a pair of blobs that Forward_cpu copies into the
// layer's top blobs.
template <typename Dtype>
class Batch {
public:
    // Copied into top[0] by Forward_cpu.
    Blob<Dtype> data_;
    // Copied into top[1] by Forward_cpu when output_labels_ is set.
    Blob<Dtype> label_;
};

DataLayer使用线程读取Batch(image, label)并push到队列中,前向传播时再从队列中pop出来使用:

// CPU forward pass of the prefetching data layer.
// Pops one batch from prefetch_full_ (blocking while that queue is empty),
// reshapes top[0] to the batch's data blob and copies the data into it,
// does the same for the labels into top[1] when output_labels_ is set, and
// finally pushes the batch onto prefetch_free_. `bottom` is not used.
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
        const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {

    // Blocks until a batch is available; the string is logged while waiting.
    Batch<Dtype>* const ready = prefetch_full_.pop("Data layer prefetch queue empty");

    // Data: match the output shape to the loaded blob, then copy.
    Blob<Dtype>& loaded_data = ready->data_;
    Blob<Dtype>* const data_out = top[0];
    data_out->ReshapeLike(loaded_data);
    caffe_copy(loaded_data.count(), loaded_data.cpu_data(),
               data_out->mutable_cpu_data());

    // Labels: same reshape-then-copy, only when the layer outputs labels.
    if (this->output_labels_) {
        Blob<Dtype>& loaded_labels = ready->label_;
        Blob<Dtype>* const label_out = top[1];
        label_out->ReshapeLike(loaded_labels);
        caffe_copy(loaded_labels.count(), loaded_labels.cpu_data(),
                   label_out->mutable_cpu_data());
    }

    // Hand the consumed batch over to the free queue.
    prefetch_free_.push(ready);
}

BlockingQueue成员:

// Data layer base class that derives from both BaseDataLayer<Dtype> and
// InternalThread. Batches circulate between two blocking queues: Forward_cpu
// pops a batch from prefetch_full_, copies it into the top blobs and pushes
// it onto prefetch_free_.
template <typename Dtype>
class BasePrefetchingDataLayer :
        public BaseDataLayer<Dtype>, public InternalThread {
	// ... (other members omitted in this excerpt)
protected:
    // Backing storage for PREFETCH_COUNT batches.
    // NOTE(review): presumably the two queues below hold pointers into this
    // array -- confirm against the constructor.
    Batch<Dtype> prefetch_[PREFETCH_COUNT];
    // Batches that Forward_cpu has finished with.
    BlockingQueue<Batch<Dtype>*> prefetch_free_;
    // Batches waiting to be consumed by Forward_cpu.
    BlockingQueue<Batch<Dtype>*> prefetch_full_;
};
10-07 10:45