pytorch的数据加载和处理教程非常具体地针对一个示例,有人可以帮我了解更通用的简单图像加载功能应该是什么样的吗?

教程:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

我的资料:

我在以下文件夹结构中将MINST数据集作为jpg包含在其中。 (我知道我可以只使用数据集类,但这纯粹是为了了解如何在不使用csv或复杂功能的情况下将简单图像加载到pytorch中)。

文件夹名称是标签,图像是灰度的28x28 png,无需任何转换。

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...

最佳答案

这是我为pytorch 0.4.1做的(仍应在1.3中使用)

def load_dataset():
    data_path = 'data/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    #train network

关于python - 如何将MNIST图像加载到Pytorch DataLoader中?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50052295/

10-13 06:59