news 2026/1/21 15:43:27

Day 42 Dataset 和 Dataloader 类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 42 Dataset 和 Dataloader 类

@浙大疏锦行

一、核心定位

核心角色核心作用
Dataset「数据容器」/「数据加工厂」定义单条数据的读取、预处理逻辑(如从 CSV 读一行、编码、填充、标准化),支持按索引取数
DataLoader「数据搬运工」/「批量调度器」封装Dataset,实现批量加载、数据打乱、多线程读取、分批迭代,解决显存和效率问题

简单来说:

  • Dataset解决 “单条数据怎么来” 的问题;
  • DataLoader解决 “怎么把数据批量喂给模型” 的问题。

二、Dataset类详解

PyTorch 提供了内置Dataset(如TensorDatasetImageFolder),但实际项目中,必须自定义Dataset,需继承torch.utils.data.Dataset并实现两个核心方法:

  1. __len__():返回数据集的总条数(让DataLoader知道有多少数据);
  2. __getitem__(idx):根据索引idx返回单条数据(特征 + 标签),是预处理的核心。

1. 自定义Dataset实战

import torch from torch.utils.data import Dataset, DataLoader import pandas as pd import numpy as np from sklearn.preprocessing import StandardScaler class CreditDefaultDataset(Dataset): """ 信用违约预测自定义Dataset :param data_path: CSV数据路径 :param is_train: 是否为训练集(用于区分训练/测试的标准化) :param scaler: 标准化器(训练集拟合,测试集复用) """ def __init__(self, data_path, is_train=True, scaler=None): self.data_path = data_path self.is_train = is_train self.scaler = scaler # 1. 读取数据 + 预处理(复用你之前的逻辑) self.data = self._preprocess_data() # 2. 分离特征和标签 self.X = self.data.drop(['Credit Default'], axis=1).values # 特征数组 self.y = self.data['Credit Default'].values # 标签数组 # 3. 标准化(训练集拟合scaler,测试集复用) if self.is_train: self.scaler = StandardScaler() self.X = self.scaler.fit_transform(self.X) else: assert self.scaler is not None, "测试集必须传入训练集的scaler!" self.X = self.scaler.transform(self.X) def _preprocess_data(self): """封装你之前的预处理逻辑(编码、填充缺失值)""" data = pd.read_csv(self.data_path) data2 = pd.read_csv(self.data_path) # 1. 字符串变量编码 # Home Ownership 标签编码 home_mapping = {'Own Home':1, 'Rent':2, 'Have Mortgage':3, 'Home Mortgage':4} data['Home Ownership'] = data['Home Ownership'].map(home_mapping) # Years in current job 标签编码 years_mapping = {'< 1 year':1, '1 year':2, '2 years':3, '3 years':4, '4 years':5, '5 years':6, '6 years':7, '7 years':8, '8 years':9, '9 years':10, '10+ years':11} data['Years in current job'] = data['Years in current job'].map(years_mapping) # Purpose 独热编码 + bool转int data = pd.get_dummies(data, columns=['Purpose']) new_cols = [col for col in data.columns if col not in data2.columns] for col in new_cols: data[col] = data[col].astype(int) # Term 映射 + 重命名 term_mapping = {'Short Term':0, 'Long Term':1} data['Term'] = data['Term'].map(term_mapping) data.rename(columns={'Term': 'Long Term'}, inplace=True) # 2. 缺失值填充(连续特征用众数) cont_feats = data.select_dtypes(include=['int64', 'float64']).columns.tolist() for feat in cont_feats: mode_val = data[feat].mode()[0] data[feat].fillna(mode_val, inplace=True) return data def __len__(self): """返回数据集总条数(必须实现)""" return len(self.X) def __getitem__(self, idx): """ 根据索引返回单条数据(必须实现) :param idx: 数据索引(int) :return: (特征tensor, 标签tensor) """ # 取单条数据(numpy数组) x = self.X[idx] y = self.y[idx] # 转换为PyTorch张量(MLP需要float32类型) x_tensor = torch.from_numpy(x).float() y_tensor = torch.from_numpy(np.array(y)).float() # 二分类标签用float(适配BCELoss) return x_tensor, y_tensor

2.Dataset核心方法解释

方法作用
__init__初始化:读取数据、预处理、分离特征 / 标签、标准化(核心预处理逻辑都在这里)
__len__返回数据总条数,DataLoader会用这个方法知道 “有多少批数据”
__getitem__按索引取单条数据,是DataLoader批量加载的基础(每次取 1 条,再拼 batch)
自定义方法(如_preprocess_data封装预处理逻辑,让代码更整洁(非必须,但推荐)

三、DataLoader类详解

DataLoaderDataset的 “上层封装”,核心作用是Dataset中的单条数据拼成批次,并提供高效读取能力。

1. 核心参数

参数作用训练集 / 测试集建议
dataset传入自定义的Dataset实例(必须)-
batch_size每批数据的条数(如 32、64)训练集:32/64;测试集:可更大(如 128)
shuffle是否打乱数据(避免模型学习顺序)训练集:True;测试集:False
num_workers多线程读取数据(加速)Windows:0;Linux/Mac:4/8(根据 CPU 核数)
drop_last是否丢弃最后一批不足batch_size的数据训练集:True;测试集:False
pin_memory是否锁定内存(GPU 训练时加速数据传输)GPU 训练:True;CPU:False

2.DataLoader实战

# 数据路径(替换为你的实际路径) DATA_PATH = r"E:\study\PythonStudy\python60-days-challenge-master\data.csv" # Step1:创建训练集Dataset(拟合scaler) train_dataset = CreditDefaultDataset(data_path=DATA_PATH, is_train=True) # 提取训练集的scaler(供测试集复用) train_scaler = train_dataset.scaler # Step2:划分训练/测试集(可选:如果Dataset已包含全量数据,这里拆分) # 注意:也可以在Dataset中直接划分,这里用切片示例 train_size = int(0.8 * len(train_dataset)) test_size = len(train_dataset) - train_size train_subset, test_subset = torch.utils.data.random_split( train_dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42) ) # 测试集Dataset(复用训练集的scaler) test_dataset = CreditDefaultDataset(data_path=DATA_PATH, is_train=False, scaler=train_scaler) # Step3:创建DataLoader(核心) train_loader = DataLoader( dataset=train_subset, batch_size=32, shuffle=True, # 训练集打乱 num_workers=0, # Windows建议设0,避免多线程报错 drop_last=True, # 丢弃最后一批不足32条的数据 pin_memory=True if torch.cuda.is_available() else False # GPU加速 ) test_loader = DataLoader( dataset=test_subset, batch_size=64, shuffle=False, # 测试集不打乱 num_workers=0, drop_last=False, pin_memory=True if torch.cuda.is_available() else False ) # Step4:迭代DataLoader(训练/测试时的核心用法) # 示例:遍历训练集批次 print("===== 训练集批次示例 =====") for batch_idx, (x_batch, y_batch) in enumerate(train_loader): print(f"批次 {batch_idx+1}:") print(f" 特征形状:{x_batch.shape} (batch_size={x_batch.shape[0]}, 特征数={x_batch.shape[1]})") print(f" 标签形状:{y_batch.shape}") # 训练时:将批次数据移到GPU → 前向传播 → 反向传播 if batch_idx == 2: # 只打印前3批 break # 示例:遍历测试集批次 print("\n===== 测试集批次示例 =====") for batch_idx, (x_batch, y_batch) in enumerate(test_loader): print(f"批次 {batch_idx+1}:特征形状={x_batch.shape},标签形状={y_batch.shape}") if batch_idx == 1: break

3.DataLoader迭代逻辑解释

  • enumerate(train_loader)会逐批返回(批次索引, (特征批次, 标签批次))
  • 特征批次形状:(batch_size, 特征数)(如 32 个样本,20 个特征 →(32, 20));
  • 标签批次形状:(batch_size,)(如 32 个样本 →(32,));
  • 训练时,每批数据会被喂给模型:outputs = model(x_batch)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/1/20 20:33:45

蓝桥杯JAVA--启蒙之路(三)语句

一前言今天依旧更新有关JAVA基础的知识&#xff0c;唉。自从更新JAVA之后浏览量什么的都下降了&#xff0c;可能是大家也不喜欢这么枯燥的基础学习吧&#xff0c;但是基础还是很重要的&#xff0c;明天和后天可能会停更&#xff0c;因为我要回家了。二主要内容if条件判断&#…

作者头像 李华
网站建设 2026/1/16 20:29:01

金融级情绪识别模型训练全攻略(基于千万级对话数据的优化经验)

第一章&#xff1a;金融客服Agent情绪识别的技术背景与业务价值 在金融服务领域&#xff0c;客户与客服代理&#xff08;Agent&#xff09;之间的交互质量直接影响用户满意度与品牌信任度。随着人工智能技术的发展&#xff0c;尤其是自然语言处理与语音情感分析的进步&#xff…

作者头像 李华
网站建设 2026/1/18 3:20:28

计算机系统基础 bufbomb 实验三

听报告无事&#xff0c;顺手写下做过的实验报告,话不多说&#xff0c;开始正文1、实验目的加深对IA-32函数调用规则和栈帧结构的理解。2、实验原理对目标程序实施缓冲区溢出攻击&#xff0c;通过造成缓冲区溢出来破坏目标程序的栈帧结构&#xff0c;继而执行一些原来程序中没有…

作者头像 李华
网站建设 2026/1/21 14:13:14

Tomcat内存机制以及按场景调优

Tomcat内存机制深度解析与场景化调优 Tomcat作为Java生态中最主流的Web容器&#xff0c;其内存管理直接决定应用的稳定性、响应速度和并发能力。本文将从内存机制底层原理、内存区域划分、常见问题根源&#xff0c;到不同业务场景的调优策略&#xff0c;进行超详细、全维度的拆…

作者头像 李华
网站建设 2026/1/20 20:14:21

ConvertX:自托管的在线文件转换器

ConvertX&#xff1a;自托管的在线文件转换器 在当今信息化时代&#xff0c;文件格式的多样性带来了很多不便。无论是处理文档、图像、视频还是音频&#xff0c;往往需要将文件转换成适合自己需求的格式。为了解决这一问题&#xff0c;ConvertX应运而生&#xff0c;它是一款强大…

作者头像 李华
网站建设 2026/1/19 3:16:15

2025年支持企业实现社会价值与商业价值的战略

在2025年&#xff0c;企业面临的挑战是同时实现社会价值与商业价值。通过创新战略&#xff0c;企业可以有效应对这一挑战。首先&#xff0c;构建以社会责任为核心的商业模式&#xff0c;将信任与责任感融入品牌之中&#xff0c;能够带来更高的顾客忠诚度和市场竞争力。其次&…

作者头像 李华