我对整个领域有点陌生,因此决定使用MNIST数据集。我几乎修改了https://github.com/pytorch/examples/blob/master/mnist/main.py中的整个代码,只有一个重大更改:数据加载。我不想在Torchvision中使用预加载的数据集。所以我用MNIST in CSV

我是通过从数据集继承并制作新的数据加载器来从CSV文件加载数据的。
以下是相关代码:

mean = 33.318421449829934
sd = 78.56749081851163
# mean = 0.1307
# sd = 0.3081
import numpy as np
from torch.utils.data import Dataset, DataLoader

class dataset(Dataset):
    def __init__(self, csv, transform=None):
        data = pd.read_csv(csv, header=None)
        self.X = np.array(data.iloc[:, 1:]).reshape(-1, 28, 28, 1).astype('float32')
        self.Y = np.array(data.iloc[:, 0])

        del data
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        item = self.X[idx]
        label = self.Y[idx]

        if self.transform:
            item = self.transform(item)

        return (item, label)

import torchvision.transforms as transforms
trainData = dataset('mnist_train.csv', transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (sd,))
]))
testData = dataset('mnist_test.csv', transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (sd,))
]))

train_loader = DataLoader(dataset=trainData,
                         batch_size=10,
                         shuffle=True,
                         )
test_loader = DataLoader(dataset=testData,
                        batch_size=10,
                        shuffle=True,
                        )


但是,此代码为我提供了您在图片中看到的绝对奇怪的训练错误图,以及11%的最终验证错误,因为它将所有内容归类为“ 7”。


我设法将问题归结为如何规范化数据,以及是否使用示例代码中给出的值(0.1307和0.3081)进行转换.Normalize以及将数据读取为'uint8'类型都可以正常工作。
请注意,在这两种情况下提供的数据差异很小。对0到1的值进行0.1307和0.3081归一化,与对0到255的值进行33.31和78.56归一化具有相同的效果。该值甚至大体相同(黑色像素对应于-0.4241,而-0.4242在第二)。

如果您想在IPython Notebook中清楚地看到此问题,请查看https://colab.research.google.com/drive/1W1qx7IADpnn5e5w97IcxVvmZAaMK9vL3

我无法理解是什么原因导致这两种加载数据的方式略有不同。任何帮助将不胜感激。

最佳答案

长话短说:您需要将item = self.X[idx]更改为item = self.X[idx].copy()

长话短说:T.ToTensor()运行torch.from_numpy,它返回一个张量,该张量将numpy数组dataset.X的内存作为别名。还有T.Normalize() works inplace,因此每次抽取样本时都会减去mean并除以std,从而导致数据集退化。

编辑:关于为什么它可以在原始MNIST加载程序中运行,因此兔子洞甚至更深。 MNIST中的关键行是将映像transformed放入PIL.Image实例中。该操作声称仅在缓冲区不连续的情况下才复制(在我们的情况下),但是在hood下,它检查是否跨步(它是跨步的),从而进行复制。因此,幸运的是,默认的Torchvision管道涉及一个副本,因此T.Normalize()的就地操作不会破坏我们self.data实例的内存中的MNIST

关于python - MNIST Pytorch中的验证错误意外增加,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53652015/

10-12 19:39