关注B站查看更多手把手教学:

基本用法

torch.utils.data.Dataset 是 PyTorch 中一个非常重要的抽象类,它用于表示数据集,方便数据加载和预处理。通过实现这个类的两个方法 __len____getitem__,你可以自定义自己的数据集类。__len__ 方法应返回数据集的大小(即样本数),而 __getitem__ 方法则根据给定的索引返回一个样本。

以下是一个简单的示例,说明如何使用 torch.utils.data.Dataset 创建一个自定义的数据集类:

import torch  
from torch.utils.data import Dataset  
  
class MyCustomDataset(Dataset):  
    def __init__(self, data, targets):  
        """  
        参数:  
            data: 样本数据, 形状为 [num_samples, ...] (例如 [num_samples, num_channels, height, width])  
            targets: 样本标签, 形状为 [num_samples, ...] (例如 [num_samples])  
        """  
        self.data = data  
        self.targets = targets  
  
    def __len__(self):  
        # 返回数据集的样本数  
        return len(self.data)  
  
    def __getitem__(self, idx):  
        # 根据索引 idx 返回一个样本 (数据和标签)  
        return self.data[idx], self.targets[idx]  
  
# 示例数据和标签  
X = torch.randn(100, 3, 32, 32)  # 假设有 100 个 3x32x32 的样本  
y = torch.randint(0, 10, (100,))  # 假设有 100 个对应的标签 (0-9)  
  
# 创建数据集实例  
dataset = MyCustomDataset(X, y)  
  
# 可以使用 len() 获取数据集大小  
print(len(dataset))  # 输出: 100  
  
# 可以使用索引获取样本  
sample, label = dataset[0]  # 获取第一个样本和标签  
print(sample.shape)  # 输出: torch.Size([3, 32, 32])  
print(label)  # 输出: 一个整数 (0-9)

在上面的示例中,我们创建了一个名为 MyCustomDataset 的自定义数据集类,该类继承自 torch.utils.data.Dataset。在类的构造函数中,我们接收样本数据和标签,并将它们存储在类的实例变量中。我们还实现了 __len____getitem__ 方法,分别用于返回数据集的大小和根据索引获取样本。最后,我们创建了一个数据集实例,并展示了如何使用它来获取数据集的大小和样本。

标准数据集

在PyTorch的torchvision.datasets模块中,包含了多个标准的数据集,这些数据集在计算机视觉领域非常流行。以下是一些常用的标准数据集:

  1. MNIST:手写数字识别数据集,包含了大量的手写数字图片和对应的标签。
  2. CIFAR:包含CIFAR-10和CIFAR-100两个数据集,分别用于10类和100类的小图片分类任务。
  3. ImageNet:一个大规模的图片分类数据集,包含了上千万张标注过的图片,通常用于训练深度神经网络。在torchvision.datasets中,可以通过ImageFolder类来加载按文件夹组织的ImageNet风格的数据集。虽然完整的ImageNet数据集很大并不直接包含在torchvision.datasets中,但PyTorch提供了处理这种数据集的工具。
  4. COCO (Common Objects in Context):用于图像标注、目标检测和语义分割的大型数据集。它包含了图片、物体的标注框、分割掩码以及关键点等信息。
  5. LSUN (Large-scale Scene UNderstanding):场景理解的大型数据集,包含了不同类别的场景图片。
  6. FashionMNIST:类似于MNIST,但是用于时尚服装和配饰的图片分类。
  7. SVHN (Street View House Numbers):从谷歌街景图片中提取的门牌号识别数据集。
  8. PhotoTour:用于图像匹配的数据集,包含了从不同角度拍摄的同一景点的图片对。
  9. STL10:一个用于无监督学习和半监督学习的图像数据集,包含了少量的标注数据和大量的无标注数据。
  10. Kinetics:用于视频动作识别的大型数据集。
  11. CelebA (CelebFaces Attributes):用于人脸检测和属性识别的大型人脸数据集。

这些标准数据集可以通过简单地调用torchvision.datasets中的相应类来加载和预处理。例如,加载MNIST数据集可以通过以下代码实现:

import torchvision.datasets as dsets  
  
# 加载MNIST训练集  
train_dataset = dsets.MNIST(root='./data',  
                            train=True,  
                            transform=transforms.ToTensor(),  
                            download=True)  
  
# 加载MNIST测试集  
test_dataset = dsets.MNIST(root='./data',  
                           train=False,  
                           transform=transforms.ToTensor())

注意,上面的代码中使用了transforms.ToTensor()来对图片进行预处理,将其转换为PyTorch的Tensor格式。在实际使用中,你可能还需要根据具体任务添加其他的预处理步骤,比如裁剪、归一化等。这些都可以通过组合torchvision.transforms中的不同变换来实现。

03-11 10:56