news 2026/2/28 22:09:07

DAY 39 Dataset和Dataloader

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DAY 39 Dataset和Dataloader

一、数据介绍

CIFAR 是机器学习和计算机视觉领域中广泛使用的图像分类基准数据集,由加拿大高级研究学院(Canadian Institute for Advanced Research,CIFAR)的研究团队发布,主要用于小尺寸图像的分类任务,是入门和验证图像分类模型性能的经典数据集。

1、数据集的核心版本

CIFAR 数据集主要分为两个核心版本,二者在类别复杂度和样本划分上有明显区别:

  1. CIFAR-10

    • 类别数量:10 个互斥的图像类别,包括飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。
    • 样本规模:共 60000 张 32×32 的彩色 RGB 图像,每个类别包含 6000 张图像;其中 50000 张为训练集(每个类别 5000 张),10000 张为测试集(每个类别 1000 张)。
    • 特点:类别间区分度相对清晰,适合入门级图像分类模型的训练和验证。
  2. CIFAR-100

    • 类别层级:包含 100 个细分类别,同时这些类别又归属于 20 个粗分类别(如 “水生哺乳动物” 包含海狮、海豹等细分类别)。
    • 样本规模:同样是 60000 张 32×32 的彩色 RGB 图像,每个细分类别包含 600 张图像;训练集 50000 张(每个细分类别 500 张),测试集 10000 张(每个细分类别 100 张)。
    • 特点:类别数量更多且部分类别相似度高,任务难度显著高于 CIFAR-10,常用于验证模型的细粒度分类能力。

2、数据集特点

  1. 图像尺寸小:32×32 的分辨率远低于真实场景的图像,模型学习到的特征相对有限,容易出现过拟合。
  2. 数据多样性:图像涵盖了自然和人造物体,且包含不同角度、光照和背景的样本,能一定程度上模拟真实世界的图像分布。
  3. 无标注噪声:数据集标注质量高,无明显标注错误,适合作为模型性能的客观基准。

二、实例化

1. 数据预处理

import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块 import matplotlib.pyplot as plt # 设置随机种子,确保结果可复现 torch.manual_seed(42) # 1. 数据预处理,该写法非常类似于管道pipeline # transforms 模块提供了一系列常用的图像预处理操作 # 1. 数据预处理:先归一化到[0,1],再标准化(适配CIFAR-10的3通道参数) transform = transforms.Compose([ transforms.ToTensor(), # 转换为张量,将HWC格式的PIL图转为CHW格式,同时归一化到[0,1] # CIFAR-10经典的均值和标准差(3通道,对应RGB) transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)) ]) # 2. 加载CIFAR-10数据集(修改root为压缩包所在文件夹) train_dataset = datasets.CIFAR10( root=r"D:\PythonStudy", # 压缩包所在的文件夹路径(关键!不是压缩包本身) train=True, download=True, # 检测到文件夹内有压缩包时,仅解压不下载 transform=transform ) test_dataset = datasets.CIFAR10( root=r"D:\PythonStudy", # 同训练集的root路径 train=False, transform=transform ) # 可选:验证数据集基本信息 print(f"CIFAR-10训练集数量:{len(train_dataset)}") print(f"CIFAR-10测试集数量:{len(test_dataset)}") print(f"单张图片形状:{train_dataset[0][0].shape}") # 输出 torch.Size([3, 32, 32])

2. Dataset类

现在我们想要取出来一个图片,看看长啥样,因为datasets.CIFAR10(若使用 CIFAR-100 则为datasets.CIFAR100)本质上继承了torch.utils.data.Dataset,所以自然需要有对应的方法。

import matplotlib.pyplot as plt # 随机选择一张图片,可以重复运行,每次都会随机选择 sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引 # len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字 image, label = train_dataset[sample_idx] # 获取图片和标签

PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

- __len__():返回数据集的样本总数。

- __getitem__(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

在 Python 中,__getitem__和__len__ 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。以下是关于这两个方法具体的使用方式:

1.__getitem__方法

__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

通过定义__getitem__方法,让MyList类的实例能够像 Python 内置的列表一样使用索引获取元素。

2.__len__方法

__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

from torch.utils.data import Dataset # CIFAR-10数据集的简化版本 class CIFAR10(Dataset): def __init__(self, root, train=True, transform=None): # 初始化:加载图片路径和标签 self.data, self.targets = fetch_cifar10_data(root, train) 的函数 self.transform = transform # 预处理操作 def __len__(self): return len(self.data) # 返回样本总数 def __getitem__(self, idx): # 获取指定索引的样本 # 获取指定索引的图像和标签 img, target = self.data[idx], self.targets[idx] # 应用图像预处理(如ToTensor、Normalize) if self.transform is not None: # 如果有预处理操作 img = self.transform(img) # 转换图像格式 # 这里假设 img 是一个 PIL 图像对象(CIFAR-10为RGB格式),transform 会将其转换为张量并进行归一化/标准化 return img, target # 返回处理后的图像和标签

1. 模块导入:from torch.utils.data import Dataset

  • 作用Dataset是 PyTorch 中所有自定义数据集的抽象基类,自定义数据集必须继承它,且必须实现__len____getitem__两个核心方法,否则会报错。
  • 为什么必须继承:PyTorch 的DataLoader(数据加载器)只能处理继承自Dataset的类,这是 PyTorch 数据加载流水线的核心规范。

2. 类定义:class CIFAR10(Dataset)

  • 作用:定义专用于 CIFAR-10 的数据集类,命名贴合数据集类型(若适配 CIFAR-100 则改为CIFAR100),继承Dataset后具备 PyTorch 数据集的标准特性。

3.__init__方法(初始化)

代码行核心作用细节说明
self.data, self.targets = fetch_cifar10_data(root, train)加载数据和标签fetch_cifar10_data是自定义数据加载函数,作用是从root路径读取 CIFAR-10 的图像(RGB PIL 格式)和标签;②train参数区分训练集 / 测试集,逻辑与 MNIST 版一致;③ 最终self.data存储所有图像,self.targets存储所有标签。
self.transform = transform接收预处理逻辑① 保存外部传入的预处理操作(如ToTensor()Normalize());② 预处理逻辑与 MNIST 版通用,仅 CIFAR-10 需适配 3 通道的归一化参数;③ 允许外部灵活调整预处理(如增广、归一化等),不修改数据集类本身。

4.__len__方法(返回样本总数)

  • 代码return len(self.data)
  • 作用:① 告诉DataLoader数据集的总样本数,是批量采样、打乱、多进程加载的基础;② 调用len(train_dataset)时会触发该方法,比如验证数据集大小(如 CIFAR-10 训练集返回 50000,测试集返回 10000);③ 逻辑与 MNIST 版完全一致,仅self.data的长度对应 CIFAR-10 的样本数。

5.__getitem__方法(按索引取样本,核心)

这是数据集类的核心方法DataLoader迭代时会反复调用该方法获取单个样本,拆解作用如下:

代码行核心作用细节说明
img, target = self.data[idx], self.targets[idx]读取原始样本idx是传入的样本索引(0~len (data)-1);②img是原始 RGB PIL 图像(32×32×3),target是 0-9 的整数标签(对应 CIFAR-10 的类别);③ 逻辑与 MNIST 版一致,仅img从单通道灰度图变为 3 通道 RGB 图。
if self.transform is not None: img = self.transform(img)应用预处理① 对原始 PIL 图像执行预处理(如ToTensor()转为 C×H×W 张量、Normalize()标准化);② 预处理是可选的(若transform=None则返回原始 PIL 图像);③ 逻辑与 MNIST 版完全一致,仅预处理参数适配 CIFAR-10 的 3 通道。
return img, target返回样本① 标准返回格式:(图像张量,标签),是 PyTorch 模型训练 / 测试的输入格式;② 图像张量形状为[3, 32, 32](CIFAR-10),MNIST 版为[1, 28, 28],仅形状差异,返回逻辑不变。
# 3. 取出单个样本 image, label = train_dataset[5] # 4. 修复后的可视化函数(核心修改) def imshow(img): # 关键:将均值/标准差转为PyTorch张量,并调整维度为[3,1,1]适配广播 std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1) mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1) # 反标准化(张量间运算,支持广播) img = img * std + mean # 限制像素值在[0,1](避免数值溢出导致显示异常) img = torch.clamp(img, 0, 1) # 转为numpy并调整维度(C×H×W → H×W×C) npimg = img.numpy() plt.imshow(npimg.transpose((1, 2, 0))) plt.axis('off') # 隐藏坐标轴 plt.show() # 5. 调用可视化(逻辑不变) print(f"Label: {label}") imshow(image)

3. Dataloader类

DataLoader是 PyTorch 封装的批量数据加载器,核心作用是将Dataset(数据集,定义了 “如何取单个样本”)封装为 “批量迭代器”,适配模型训练 / 测试时的批量输入需求,同时支持打乱、多进程加载等优化。

# 3. 创建数据加载器 train_loader = DataLoader( train_dataset, batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关 shuffle=True # 随机打乱数据 ) test_loader = DataLoader( test_dataset, batch_size=1000 # 每个批次1000张图片 # shuffle=False # 测试时不需要打乱数据 )

训练集加载器(train_loader)

基于训练数据集train_dataset创建批量加载器:

  • batch_size=64:每次加载 64 张图片为一个批次(选 2 的幂次方适配 GPU 计算效率);
  • shuffle=True:训练前随机打乱数据顺序,避免模型学 “数据顺序”,提升泛化能力。

测试集加载器(test_loader)

基于测试数据集test_dataset创建批量加载器:

  • batch_size=1000:测试集批次更大(无训练开销,提速评估);
  • 默认shuffle=False:测试不打乱数据,保证结果可复现,也节省计算开销。

核心:DataLoader把数据集封装成 “批量迭代器”,训练侧重打乱 + 小批次适配 GPU,测试侧重大批次提速 + 不打乱保复现。

三、总结

核心结论

  • Dataset 类:定义数据的内容和格式(即 “如何获取单个样本”),包括:
    • 数据存储路径 / 来源(如文件路径、数据库查询)。
    • 原始数据的读取方式(如图像解码为 PIL 对象、文本读取为字符串)。
    • 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过 transform 参数实现)。
    • 返回值格式(如 (image_tensor, label))。
  • DataLoader 类:定义数据的加载方式和批量处理逻辑(即 “如何高效批量获取数据”),包括:
    • 批量大小(batch_size)。
    • 是否打乱数据顺序(shuffle)。

勇闯python的第39天@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/2/27 17:31:12

WaveTools鸣潮120帧解锁与游戏性能优化全攻略

WaveTools鸣潮120帧解锁与游戏性能优化全攻略 【免费下载链接】WaveTools 🧰鸣潮工具箱 项目地址: https://gitcode.com/gh_mirrors/wa/WaveTools 🎯 你是否在鸣潮1.2版本更新后,发现原本流畅的120帧体验突然消失了?别担心…

作者头像 李华
网站建设 2026/2/26 19:48:28

三步学会百度网盘极速下载:告别龟速的终极方案

三步学会百度网盘极速下载:告别龟速的终极方案 【免费下载链接】baidu-wangpan-parse 获取百度网盘分享文件的下载地址 项目地址: https://gitcode.com/gh_mirrors/ba/baidu-wangpan-parse 还在为百度网盘的下载速度而烦恼吗?当你明明拥有高速网络…

作者头像 李华
网站建设 2026/2/28 8:26:41

5大实用技巧:用Calibre-Douban插件智能管理电子书元数据

5大实用技巧:用Calibre-Douban插件智能管理电子书元数据 【免费下载链接】calibre-douban Calibre new douban metadata source plugin. Douban no longer provides book APIs to the public, so it can only use web crawling to obtain data. This is a calibre D…

作者头像 李华
网站建设 2026/2/27 14:09:48

飞书文档批量导出终极指南:一键解决文档迁移难题

飞书文档批量导出终极指南:一键解决文档迁移难题 【免费下载链接】feishu-doc-export 项目地址: https://gitcode.com/gh_mirrors/fe/feishu-doc-export 你是否曾经为文档迁移而头疼不已?当公司决定更换办公平台,或是需要将飞书知识库…

作者头像 李华
网站建设 2026/2/27 10:04:20

Source Han Serif思源宋体:免费开源中文字体专业应用指南

Source Han Serif思源宋体:免费开源中文字体专业应用指南 【免费下载链接】source-han-serif-ttf Source Han Serif TTF 项目地址: https://gitcode.com/gh_mirrors/so/source-han-serif-ttf 还在为字体版权问题而烦恼?Source Han Serif思源宋体为…

作者头像 李华
网站建设 2026/2/26 22:51:37

DOM Element:深入理解与操作

DOM Element:深入理解与操作 引言 在Web开发领域,DOM(文档对象模型)是一个至关重要的概念。DOM允许开发者与网页内容进行交互,如HTML和XML文档。DOM Element是DOM模型中的核心组成部分,它代表了HTML或XML文档中的每一个元素。本文将深入探讨DOM Element的概念、属性、方…

作者头像 李华