# Tested: the process-pool variant below was not found to be faster than the multithreaded preloading.

import datetime
import os
import threading
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import logging
import cv2

import torch
from torch.utils.data import Dataset
import xml.etree.ElementTree as ET
from common import data_transforms
CLASSES = ('mouse',)

def aaa(listDataset):
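    """Pre-load every sample of the dataset into listDataset.img_d so that later
    __getitem__ calls become simple dictionary lookups; intended to run in a
    background worker."""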
    for index in range(len(listDataset.img_files)):
        if index in listDataset.img_d:
            pass
        else:
            img_path = listDataset.img_files[index % len(listDataset.img_files)].rstrip()
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:
                raise Exception("Read image error: {}".format(img_path))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w, c = img.shape
            label_path = listDataset.label_files[index % len(listDataset.img_files)].rstrip()
            if os.path.exists(label_path):
                # labels = np.loadtxt(label_path).reshape(-1, 5)

                tree = ET.parse(label_path)
                objs = tree.findall('object')
                # num_objs = len(objs)
                # labels = np.zeros(shape=(num_objs, 5), dtype=np.float64)
                labels = []
                for ix, obj in enumerate(objs):
                    if obj.find("difficult") is not None and obj.find("difficult").text == '1':
                        continue
                    bbox = obj.find('bndbox')
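                    # Clip box corners to the frame; the 1/1279/719 bounds apparently assume 1280x720 source images.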
                    x1 = max(float(bbox.find('xmin').text), 1)  # - 1
                    y1 = max(float(bbox.find('ymin').text), 1)  # - 1
                    x2 = min(float(bbox.find('xmax').text), 1279)  # - 1
                    y2 = min(float(bbox.find('ymax').text), 719)  # - 1
                    cls = listDataset._class_to_ind[obj.find('name').text.lower().strip()]
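                    # Convert the VOC corner box to a normalized YOLO-style label: [class, cx, cy, bw, bh].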
                    label_ = [cls, ((x1 + x2) / 2) / w, ((y1 + y2) / 2) / h, (x2 - x1) / w, (y2 - y1) / h]
                    # labels[ix, :] = [cls, ((x1 + x2) / 2) / padded_w, ((y1 + y2) / 2) / padded_h,
                    #                  w / padded_w, h / padded_h]
                    # labels[ix, :] = [cls, ((x1 + x2) / 2) / padded_w, ((y1 + y2) / 2) / padded_h,
                    #                  w_new / padded_w, h_new / padded_h]
                    labels.append(label_)
                labels = np.asarray(labels)
            else:
                logging.info("label does not exist: {}".format(label_path))
                labels = np.zeros((1, 5), np.float32)

            sample = {'image': img, 'label': labels, "img_path": img_path}
            if listDataset.transforms is not None:
                sample = listDataset.transforms(sample)
            listDataset.img_d[index] = sample
class MyThread(threading.Thread):
    """Alternative preloader kept for reference: fills the cache through a process
    pool. Note that ProcessPoolExecutor pickles the dataset into a child process,
    so the img_d filled there never reaches the parent process; this is presumably
    why it was not found to be faster than the threaded version."""

    def __init__(self, arg):
        super(MyThread, self).__init__()  # Note: the parent class initializer must be called explicitly.
        self.listDataset = arg

    def run(self):  # The function this thread executes.
        from concurrent.futures import ProcessPoolExecutor

        pool = ProcessPoolExecutor(max_workers=30)
        # Submit the whole dataset once; mapping aaa over the dataset itself would
        # instead call __getitem__ for every index in this process.
        future = pool.submit(aaa, self.listDataset)
        future.result()
        print("data_load ok")


class COCODataset(Dataset):
    def __init__(self, list_path, img_size, is_training, is_debug=False, data_size=1440*10, is_scene=False):

        if is_scene:
            all_files = [list(map(lambda x: os.path.join(root, x), files)) for root, _, files in
                         os.walk(list_path, topdown=False) if os.path.basename(root) == 'Annotations']

            self.label_files = []
            for i in range(len(all_files)):
                self.label_files += all_files[i]
            if len(self.label_files)>data_size:
                self.label_files=self.label_files[:data_size]
            self.img_files = [file.replace('Annotations', 'JPEGImages').replace('xml', 'jpg') for file in
                              self.label_files]
        else:
            subset = 'trainval.txt' if is_training else 'test.txt'
            list_path_txt = os.path.join(list_path, 'ImageSets', 'Main', subset)
            with open(list_path_txt, 'r') as file:
                self.train_files_ = file.readlines()
            # Truncating train_files_ also caps the img/label lists derived from it below.
            if len(self.train_files_) > data_size:
                self.train_files_ = self.train_files_[:data_size]
            self.img_files = [os.path.join(list_path, 'JPEGImages', '%s.jpg' % train_file.strip('\n'))
                              for train_file in self.train_files_]
            self.label_files = [os.path.join(list_path, 'Annotations', '%s.xml' % train_file.strip('\n'))
                                for train_file in self.train_files_]

        self.img_size = img_size  # (w, h)
        self.max_objects = 10
        self.is_debug = is_debug

        #  transforms and augmentation
        self.transforms = data_transforms.Compose()
        # if is_training:
        #     self.transforms.add(data_transforms.ImageBaseAug())
        # self.transforms.add(data_transforms.KeepAspect())
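        # Active pipeline: resize to img_size, then convert to tensors (ToTensor presumably pads labels up to max_objects).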
        self.transforms.add(data_transforms.ResizeImage(self.img_size))
        self.transforms.add(data_transforms.ToTensor(self.max_objects, self.is_debug))

        self._class_to_ind = dict(list(zip(CLASSES, list(range(len(CLASSES))))))
        self.img_d = {}
        self.lock = threading.RLock()

        # Kick off background preloading of all samples into img_d; a single task is
        # submitted, so only one of the pool's workers is actually used.
        executor = ThreadPoolExecutor(max_workers=20)
        future = executor.submit(aaa, self)

        # t = MyThread(self)
        # t.start()
    def __getitem__(self, index):
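        """Return the cached sample if it is already in img_d; otherwise read the
        image, parse its VOC annotation, apply the transforms and cache the result."""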
        if index in self.img_d:
            return self.img_d[index]
        else:
            # self.lock.acquire()
            img_path = self.img_files[index % len(self.img_files)].rstrip()
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:
                raise Exception("Read image error: {}".format(img_path))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w, c = img.shape
            label_path = self.label_files[index % len(self.img_files)].rstrip()
            if os.path.exists(label_path):
                # labels = np.loadtxt(label_path).reshape(-1, 5)

                tree = ET.parse(label_path)
                objs = tree.findall('object')
                # num_objs = len(objs)
                # labels = np.zeros(shape=(num_objs, 5), dtype=np.float64)
                labels = []
                for ix, obj in enumerate(objs):
                    if obj.find("difficult") is not None and obj.find("difficult").text == '1':
                        continue
                    bbox = obj.find('bndbox')
                    x1 = max(float(bbox.find('xmin').text), 1)  # - 1
                    y1 = max(float(bbox.find('ymin').text), 1)  # - 1
                    x2 = min(float(bbox.find('xmax').text), 1279)  # - 1
                    y2 = min(float(bbox.find('ymax').text), 719)  # - 1
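                    # Normalize the class name; some annotations apparently misspell "mouse" as "mousse".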
                    cls = self._class_to_ind[obj.find('name').text.lower().strip().replace("mousse", "mouse")]
                    label_ = [cls, ((x1 + x2) / 2) / w, ((y1 + y2) / 2) / h, (x2 - x1) / w, (y2 - y1) / h]
                    # labels[ix, :] = [cls, ((x1 + x2) / 2) / padded_w, ((y1 + y2) / 2) / padded_h,
                    #                  w / padded_w, h / padded_h]
                    # labels[ix, :] = [cls, ((x1 + x2) / 2) / padded_w, ((y1 + y2) / 2) / padded_h,
                    #                  w_new / padded_w, h_new / padded_h]
                    labels.append(label_)
                labels = np.asarray(labels)
            else:
                logging.info("label does not exist: {}".format(label_path))
                labels = np.zeros((1, 5), np.float32)

            sample = {'image': img, 'label': labels, "img_path": img_path}
            if self.transforms is not None:
                sample = self.transforms(sample)
            self.img_d[index] = sample
            # self.lock.release()
            return sample

    def __len__(self):
        return len(self.img_files)


#  use for test dataloader
if __name__ == "__main__":
    dataloader = torch.utils.data.DataLoader(COCODataset(r"D:\data\tiny_data\VOC2007",
                                                         (416, 416),is_training=True, is_debug=True),
                                             batch_size=16,
                                             shuffle=False, num_workers=0, pin_memory=False)
    for step, sample in enumerate(dataloader):
        for i, (image, label) in enumerate(zip(sample['image'], sample['label'])):
            image = image.numpy()
            h, w = image.shape[:2]
            for l in label:
                if l.sum() == 0:
                    continue
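                # Decode the normalized (cx, cy, bw, bh) label back to pixel corner coordinates for drawing.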
                x1 = int((l[1] - l[3] / 2) * w)
                y1 = int((l[2] - l[4] / 2) * h)
                x2 = int((l[1] + l[3] / 2) * w)
                y2 = int((l[2] + l[4] / 2) * h)
                cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255))
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            print(datetime.datetime.now())
            # cv2.imwrite("step{}_{}.jpg".format(step, i), image)
        # only one batch
        # break
 
