news 2026/6/22 19:27:39

Day 38 - Dataset 和 DataLoader

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 38 - Dataset 和 DataLoader

在深度学习任务中,数据处理是至关重要的一环。面对大规模数据集,显存往往无法一次性存储所有数据,因此需要采用分批训练(Batch Training)的策略。PyTorch 提供了两个核心工具类来解决数据加载和预处理的问题:DatasetDataLoader

本文将深入探讨这两个类的原理、用法以及它们之间的关系,并以经典的 MNIST 手写数字数据集为例进行演示。

一、 PyTorch 数据处理核心架构

在 PyTorch 中,数据处理流程被解耦为两个独立的部分:

  1. Dataset (数据集):负责定义“数据是什么”,即如何获取单个样本及其对应的标签,以及如何进行预处理。
  2. DataLoader (数据加载器):负责定义“如何加载数据”,即如何将 Dataset 中的样本组装成批次(Batch),并提供多线程加载、随机打乱等功能。

形象比喻

  • Dataset就像是厨师,他的工作是负责把每一个菜品(样本)切好、洗好、调好味(预处理)。
  • DataLoader就像是服务员,他的工作是把厨师做好的菜品,按照订单的要求(Batch Size),打包好端给客人(模型)。

二、 Dataset 类详解

torch.utils.data.Dataset是一个抽象基类,所有自定义的数据集都必须继承它,并实现其核心接口。

1. 核心魔术方法

PyTorch 要求 Dataset 子类必须实现以下两个魔术方法(Magic Methods):

  • __len__(self)
    • 作用:返回数据集的样本总数。
    • 调用方式:当使用len(dataset)时自动调用。
    • 意义:DataLoader 需要知道数据集的大小,以便计算一个 Epoch 需要多少个 Batch。
  • __getitem__(self, idx)
    • 作用:根据索引idx获取单个样本的数据和标签。
    • 调用方式:当使用dataset[idx]时自动调用。
    • 意义:这是数据读取和预处理发生的地方。

2. Python 魔术方法原理解析

为了更好地理解__len____getitem__,我们来看一个简单的 Python 自定义类示例:

class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] # 实现索引访问功能 def __getitem__(self, idx): return self.data[idx] # 实现长度获取功能 def __len__(self): return len(self.data) # 实例化对象 my_list_obj = MyList() # 1. 测试 __getitem__ # 对象可以直接使用 [] 索引访问,像内置列表一样 print(f"索引为2的元素: {my_list_obj[2]}") # 输出: 30 # 2. 测试 __len__ # 对象可以直接使用 len() 函数 print(f"列表长度: {len(my_list_obj)}") # 输出: 5

3. 自定义 Dataset 示例

基于上述原理,一个典型的自定义 Dataset 结构如下:

from torch.utils.data import Dataset class MNIST(Dataset): def __init__(self, root, train=True, transform=None): """ 初始化:加载文件路径、标签文件等 """ # 假设 fetch_mnist_data 是一个自定义函数,用于读取数据 self.data, self.targets = fetch_mnist_data(root, train) self.transform = transform # 预处理操作流水线 def __len__(self): """ 返回数据集大小 """ return len(self.data) def __getitem__(self, idx): """ 获取指定索引 idx 的样本 """ # 1. 根据索引获取原始数据和标签 img, target = self.data[idx], self.targets[idx] # 2. 应用预处理(如转 Tensor、归一化等) if self.transform is not None: img = self.transform(img) return img, target

三、 实战:使用 torchvision 加载 MNIST

torchvision是 PyTorch 官方的计算机视觉库,其中torchvision.datasets模块内置了许多常用数据集(如 MNIST, CIFAR10, ImageNet 等),它们都已经实现了 Dataset 的接口。

1. 数据预处理 (Transforms)

在加载图像数据时,通常需要进行一系列预处理,如转为张量(Tensor)、归一化(Normalize)等。

from torchvision import transforms # 定义预处理流水线 transform = transforms.Compose([ transforms.ToTensor(), # 将图像转换为 PyTorch 张量,并将像素值归一化到 [0, 1] transforms.Normalize((0.1307,), (0.3081,)) # 标准化:(x - mean) / std。参数为 MNIST 数据集的全局均值和标准差 ])

2. 加载数据集

from torchvision import datasets # 加载训练集 train_dataset = datasets.MNIST( root='./data', # 数据存储路径 train=True, # True 表示加载训练集 download=True, # 如果路径下不存在数据,是否自动下载 transform=transform # 应用上面定义的预处理 ) # 加载测试集 test_dataset = datasets.MNIST( root='./data', train=False, # False 表示加载测试集 transform=transform )

注意:在 PyTorch 的设计哲学中,数据预处理通常是在加载阶段(即__getitem__被调用时)动态进行的,而不是先处理好再保存。这样做可以节省磁盘空间,并支持动态的数据增强。

3. 查看单个样本

由于train_dataset本质上是一个 Dataset 子类,我们可以直接通过索引访问:

import matplotlib.pyplot as plt import torch # 随机获取一个索引 sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 获取样本(自动调用 __getitem__) image, label = train_dataset[sample_idx] # 可视化(需要反归一化以便人眼观察) def imshow(img): img = img * 0.3081 + 0.1307 # 反标准化 npimg = img.numpy() plt.imshow(npimg[0], cmap='gray') plt.show() print(f"Label: {label}") imshow(image)

四、 DataLoader 类详解

torch.utils.data.DataLoader是 PyTorch 中用于加载数据的核心工具。它接收一个 Dataset 对象,并根据配置参数生成一个可迭代对象。

1. 核心功能

DataLoader 的主要职责包括:

  • Batching:将多个样本打包成一个批次。
  • Shuffling:在每个 Epoch 开始时打乱数据顺序,防止模型记忆数据的顺序特征。
  • Multiprocessing:使用多进程并行加载数据,加速数据准备过程(避免 CPU 成为瓶颈)。

2. 创建 DataLoader

from torch.utils.data import DataLoader # 训练集加载器 train_loader = DataLoader( train_dataset, batch_size=64, # 每个批次包含 64 个样本 shuffle=True # 训练时通常需要打乱数据 ) # 测试集加载器 test_loader = DataLoader( test_dataset, batch_size=1000, # 测试时显存压力较小,可以使用更大的 batch_size shuffle=False # 测试时不需要打乱顺序,以便结果对比 )

关于 Batch Size 的选择

通常选择 2 的幂次方(如 32, 64, 128),这有利于 GPU 的并行计算效率。

五、 总结:Dataset 与 DataLoader 的对比

为了清晰地区分这两个概念,我们可以从以下几个维度进行对比:

维度

Dataset

DataLoader

核心职责

定义“数据内容”和“单个样本获取方式”

定义“批量加载策略”和“迭代方式”

核心方法

__getitem__(获取单个),__len__(总数)

内部实现迭代器协议 (__iter__)

预处理位置

__getitem__中定义具体的转换逻辑

不负责预处理,直接使用 Dataset 返回的结果

并行处理

无(仅处理单样本逻辑)

支持多进程加载 (num_workers)

关键参数

root(路径),transform(变换)

batch_size,shuffle,num_workers

一句话总结

Dataset负责把数据从磁盘读出来并处理成模型能看懂的格式(Tensor),而DataLoader负责把这些 Tensor 批量、高效、随机地喂给模型进行训练。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/22 1:57:44

[C#][winform]基于yolov11的打架行为检测系统C#源码+onnx模型+评估指标曲线+精美GUI界面

【算法介绍】在社会治安管理朝着智能化、精细化方向加速推进的重要阶段,及时且精准地监测公共场所中的打架行为,已然成为维护社会秩序稳定、保障公民人身安全以及提升城市治理水平的核心任务之一。公共场所作为人员密集且流动频繁的区域,其环…

作者头像 李华
网站建设 2026/6/23 14:05:30

2022年TRC SCI1区TOP,基于随机分形搜索算法的多无人机四维航迹优化自适应冲突消解方法,深度解析+性能实测

目录1.摘要2.基于风险的4D航线与飞行冲突建模3.冲突解决和4D路线优化4.随机分形搜索算法5.结果展示6.参考文献7.代码获取8.算法辅导应用定制读者交流1.摘要 随着无人航空系统在城市低空的快速发展,安全高效的低空交通管理亟需突破。飞前四维航迹优化是实现冲突探测…

作者头像 李华
网站建设 2026/6/22 18:25:31

《智能世界2035》——华为预测十年以后智能世界的模样

导语:如果回到十年前,你会做什么?如果你知道十年后的样子,现在你会做什么?如果把 2025 比作 AI 的“青春期”,那么 2035 将是它真正走向社会的“成人礼”。华为《智能世界2035》 用130 页的战略报告介绍了 …

作者头像 李华
网站建设 2026/6/23 8:23:18

FLAC3D随机裂隙建模:从基础到复杂网络

FLAC3D随机裂隙,fractureFLAC3D作为一款功能强大的离散元数值模拟软件,在岩石力学领域有着广泛的应用。其中,随机裂隙网络的建模是岩石力学研究中的重要一环,因为它能够更好地反映实际岩石中的复杂结构。本文将介绍如何在FLAC3D中…

作者头像 李华
网站建设 2026/6/21 13:24:17

终极指南:TUnit服务虚拟化测试实践

终极指南:TUnit服务虚拟化测试实践 【免费下载链接】TUnit A modern, fast and flexible .NET testing framework 项目地址: https://gitcode.com/GitHub_Trending/tun/TUnit 在当今的软件开发中,你是否经常遇到这样的困扰:测试因为外…

作者头像 李华