news 2026/6/23 6:44:31

Day 36 Training an MLP Neural Network


张小明

Front-end developer


Table of Contents

  • Day 36 · Training an MLP Neural Network
    • Data Preparation
    • Model Design
    • Training
    • Visualization

Day 36 · Training an MLP Neural Network

There are plenty of tutorials on installing PyTorch and CUDA, so I won't repeat them here.

import torch
torch.cuda

<module 'torch.cuda' from '/home/ubuntu24/anaconda3/envs/torch-gpu/lib/python3.13/site-packages/torch/cuda/__init__.py'>
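import torch

# Check whether CUDA is available
if torch.cuda.is_available():
    print("CUDA is available!")
    # Number of available CUDA devices
    device_count = torch.cuda.device_count()
    print(f"Number of available CUDA devices: {device_count}")
    # Index of the CUDA device currently in use
    current_device = torch.cuda.current_device()
    print(f"Current CUDA device index: {current_device}")
    # Name of the current CUDA device
    device_name = torch.cuda.get_device_name(current_device)
    print(f"Current CUDA device name: {device_name}")
    # CUDA version this PyTorch build was compiled against
    cuda_version = torch.version.cuda
    print(f"CUDA version: {cuda_version}")
else:
    print("CUDA is not available.")

CUDA is available!
Number of available CUDA devices: 1
Current CUDA device index: 0
Current CUDA device name: NVIDIA GeForce RTX 4070 Laptop GPU
CUDA version: 12.4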

Data Preparation

# Load the 3-class Iris dataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np

iris = load_iris()
X = iris.data
y = iris.target

# Hold out 20% of the samples for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

(120, 4)
(120,)
(30, 4)
(30,)
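The split above is random but not stratified. As an optional check that is not part of the original pipeline, a minimal sketch can count how many samples of each class end up in the test set and compare that with `train_test_split`'s `stratify` argument, which preserves the class proportions. The variable names here are illustrative only.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# Split as in this article (random, not stratified)
_, _, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

# Optional variant: stratify=y keeps the 50/50/50 class ratio in both parts
_, _, y_tr_s, y_te_s = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print("test class counts (plain):     ", np.bincount(y_te, minlength=3))
print("test class counts (stratified):", np.bincount(y_te_s, minlength=3))  # [10 10 10]
```

With only 30 test samples, a small class imbalance in the plain split can shift the test accuracy noticeably, which is the usual motivation for stratifying.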
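# Neural networks are sensitive to the scale of their inputs, so normalize the features.
# Fit the scaler on the training set only, then apply the same transform to the test set.
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)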
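MinMaxScaler rescales each feature independently as x' = (x - min) / (max - min), where min and max are taken from the training set only. That is why the code calls `fit_transform` on `X_train` but only `transform` on `X_test`, and why test values can fall slightly outside [0, 1]. A small sketch with made-up numbers (hypothetical toy data) reproduces the computation by hand:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical toy data: 3 training rows, 1 test row, 2 features
X_tr = np.array([[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]])
X_te = np.array([[6.0, 15.0]])

toy_scaler = MinMaxScaler()
X_tr_scaled = toy_scaler.fit_transform(X_tr)   # learns per-column min/max from X_tr
X_te_scaled = toy_scaler.transform(X_te)       # reuses those statistics

# Same computation by hand, using min/max of the TRAINING data only
col_min, col_max = X_tr.min(axis=0), X_tr.max(axis=0)
manual_te = (X_te - col_min) / (col_max - col_min)

print(X_te_scaled)                          # [[1.25 0.25]]
print(np.allclose(X_te_scaled, manual_te))  # True
```

The first test feature (6.0) exceeds the training maximum (5.0), so it maps to 1.25 rather than being clipped to 1.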
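import torch

# Convert the data to tensors: PyTorch trains on tensors, which can be thought of as specialized arrays.
# Features become float32; labels become int64 class indices.
X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)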
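The choice of constructor fixes the dtypes that the rest of the pipeline expects. `nn.Linear` weights are float32 by default, so the features should be float32 (MinMaxScaler returns float64 NumPy arrays here). `nn.CrossEntropyLoss` expects class-index targets as int64. A small sketch with hypothetical toy values compares the legacy constructors used above with the more explicit `torch.tensor(..., dtype=...)` form:

```python
import numpy as np
import torch

# Hypothetical toy data, shaped like one row of the scaled Iris features
X_np = np.array([[0.1, 0.2, 0.3, 0.4]])   # NumPy defaults to float64
y_np = np.array([2], dtype=np.int64)       # a class index

# Legacy constructors, as used in the article
X_a = torch.FloatTensor(X_np)   # float64 -> torch.float32
y_a = torch.LongTensor(y_np)    # -> torch.int64

# Equivalent explicit-dtype form
X_b = torch.tensor(X_np, dtype=torch.float32)
y_b = torch.tensor(y_np, dtype=torch.long)

print(X_a.dtype, y_a.dtype)                          # torch.float32 torch.int64
print(torch.equal(X_a, X_b), torch.equal(y_a, y_b))  # True True
```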

Model Design

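import torch
import torch.nn as nn
import torch.optim as optim

# Define the MLP model
model = nn.Sequential(
    nn.Linear(4, 10),
    nn.ReLU(),
    nn.Linear(10, 3),
)

# Alternatively, define it as a class
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)  # input layer -> hidden layer
        self.relu = nn.ReLU()        # introduce non-linearity
        self.fc2 = nn.Linear(10, 3)  # hidden layer -> output layer
        # No activation on the output layer: nn.CrossEntropyLoss applies log-softmax
        # to the raw outputs (logits) internally, turning them into log-probabilities.

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# MLP_model = MLP()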
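Both definitions describe the same 4 → 10 → 3 network. One quick way to confirm this, shown here as a sketch that is not from the original article, is to count trainable parameters. The first layer has 4 × 10 weights + 10 biases = 50, the second has 10 × 3 + 3 = 33, so the total is 83. `count_params` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

seq_model = nn.Sequential(nn.Linear(4, 10), nn.ReLU(), nn.Linear(10, 3))

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def count_params(m: nn.Module) -> int:
    # Sum the element counts of every trainable tensor (weights and biases)
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

class_model = MLP()
print(count_params(seq_model), count_params(class_model))  # 83 83

x = torch.randn(5, 4)  # hypothetical batch: 5 samples, 4 features
print(seq_model(x).shape, class_model(x).shape)  # torch.Size([5, 3]) torch.Size([5, 3])
```

`nn.Sequential` is enough for a plain layer stack like this one. The subclass form becomes useful when `forward` needs custom logic.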
# Classification task: use the cross-entropy loss
criterion = nn.CrossEntropyLoss()
# Use the Adam optimizer with a learning rate of 0.01
optimizer = optim.Adam(model.parameters(), lr=0.01)
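The model deliberately outputs raw logits because `nn.CrossEntropyLoss` applies log-softmax internally. A sketch with random illustrative logits checks that the built-in loss equals the mean negative log-probability of each sample's true class. This is also why the network should not end with its own Softmax layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)             # illustrative raw outputs: 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])   # true class indices (int64)

# Built-in loss, as used in the article (default reduction='mean')
ce = nn.CrossEntropyLoss()(logits, targets)

# Manual equivalent: log-softmax over the class dimension,
# then the mean negative log-probability of each sample's true class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(torch.allclose(ce, manual))  # True
```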

Training

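Training uses cross-entropy loss with the Adam optimizer, full-batch: every epoch feeds all 120 training samples through the model at once. Before training, move the model and the data to the same device. Inside the loop, append the loss and accuracy of every epoch to lists. Note that the "validation" metrics below are computed on the 30-sample test split.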

# Use the GPU if available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Module.to() moves parameters in place, so the optimizer created earlier still refers to them
model = model.to(device)
X_train = X_train.to(device)
y_train = y_train.to(device)
X_test = X_test.to(device)
y_test = y_test.to(device)

num_epochs = 20_000
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

# Accuracy: fraction of samples whose highest-scoring class matches the label
def calculate_accuracy(logits, labels):
    preds = torch.argmax(logits.detach(), dim=1)
    return (preds == labels).float().mean().item()

for epoch in range(1, num_epochs + 1):
    # Training step on the full training set
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    train_loss = criterion(outputs, y_train)
    train_loss.backward()
    optimizer.step()

    train_losses.append(train_loss.item())
    train_accuracies.append(calculate_accuracy(outputs, y_train))

    # Evaluation on the held-out split (no gradients needed)
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_test)
        val_loss = criterion(val_outputs, y_test).item()
        val_acc = calculate_accuracy(val_outputs, y_test)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    if epoch % 400 == 0:
        print(f'Epoch [{epoch}/{num_epochs}] '
              f'train_loss={train_loss.item():.4f} val_loss={val_loss:.4f} '
              f'train_acc={train_accuracies[-1]:.4f} val_acc={val_acc:.4f}')
Epoch [400/20000] train_loss=0.0629 val_loss=0.0538 train_acc=0.9750 val_acc=0.9667
Epoch [800/20000] train_loss=0.0497 val_loss=0.0292 train_acc=0.9833 val_acc=1.0000
Epoch [1200/20000] train_loss=0.0473 val_loss=0.0203 train_acc=0.9833 val_acc=1.0000
Epoch [1600/20000] train_loss=0.0468 val_loss=0.0173 train_acc=0.9833 val_acc=1.0000
Epoch [2000/20000] train_loss=0.0467 val_loss=0.0161 train_acc=0.9833 val_acc=1.0000
Epoch [2400/20000] train_loss=0.0466 val_loss=0.0157 train_acc=0.9833 val_acc=1.0000
Epoch [2800/20000] train_loss=0.0466 val_loss=0.0156 train_acc=0.9833 val_acc=1.0000
Epoch [3200/20000] train_loss=0.0466 val_loss=0.0155 train_acc=0.9833 val_acc=1.0000
Epoch [3600/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [4000/20000] train_loss=0.0466 val_loss=0.0154 train_acc=0.9833 val_acc=1.0000
Epoch [4400/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [4800/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [5200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [5600/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [6000/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [6400/20000] train_loss=0.0466 val_loss=0.0154 train_acc=0.9833 val_acc=1.0000
Epoch [6800/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [7200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [7600/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [8000/20000] train_loss=0.0466 val_loss=0.0154 train_acc=0.9833 val_acc=1.0000
Epoch [8400/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [8800/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [9200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [9600/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [10000/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [10400/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [10800/20000] train_loss=0.0466 val_loss=0.0152 train_acc=0.9833 val_acc=1.0000
Epoch [11200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [11600/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [12000/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [12400/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [12800/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [13200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [13600/20000] train_loss=0.0466 val_loss=0.0150 train_acc=0.9833 val_acc=1.0000
Epoch [14000/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [14400/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [14800/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [15200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [15600/20000] train_loss=0.0466 val_loss=0.0152 train_acc=0.9833 val_acc=1.0000
Epoch [16000/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [16400/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [16800/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [17200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [17600/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [18000/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [18400/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [18800/20000] train_loss=0.0466 val_loss=0.0155 train_acc=0.9833 val_acc=1.0000
Epoch [19200/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
Epoch [19600/20000] train_loss=0.0466 val_loss=0.0152 train_acc=0.9833 val_acc=1.0000
Epoch [20000/20000] train_loss=0.0466 val_loss=0.0153 train_acc=0.9833 val_acc=1.0000
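`calculate_accuracy` treats the index of the largest logit in each row as the predicted class and averages the matches. A tiny hand-made example with hypothetical values makes this concrete. By the same arithmetic, train_acc=0.9833 in the log above corresponds to 118 of 120 training samples, and val_acc=1.0000 to all 30 held-out samples.

```python
import torch

def calculate_accuracy(logits, labels):
    # Same helper as in the training loop: predicted class = argmax over classes
    preds = torch.argmax(logits.detach(), dim=1)
    return (preds == labels).float().mean().item()

# Hypothetical logits for 4 samples and 3 classes
logits = torch.tensor([
    [2.0, 0.1, -1.0],  # argmax 0
    [0.3, 0.2, 1.5],   # argmax 2
    [0.0, 3.0, 0.5],   # argmax 1
    [1.0, 0.9, 0.8],   # argmax 0
])
labels = torch.tensor([0, 2, 1, 1])  # the last sample is misclassified

print(calculate_accuracy(logits, labels))  # 0.75
```

Because argmax of the logits equals argmax of the softmax probabilities, the helper does not need to apply softmax first.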

Visualization

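With the loss and accuracy arrays recorded, a pair of side-by-side subplots shows at a glance whether the model is overfitting or underfitting. In practice, it is worth recording experiment notes at this point, such as the number of epochs, the learning rate, and whether a GPU was used, so that future runs are easy to compare.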

import matplotlib.pyplot as plt

epochs = range(1, num_epochs + 1)
plt.figure(figsize=(12, 5))

# Left subplot: training vs. validation loss
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Epochs')
plt.legend()

# Right subplot: training vs. validation accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, train_accuracies, label='Train Accuracy')
plt.plot(epochs, val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy over Epochs')
plt.legend()

plt.tight_layout()
plt.show()
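For the experiment notes suggested above, one lightweight option is to append one JSON line per run to a log file. This is a minimal sketch, not the author's method. `make_run_notes` and the file name `run_notes.jsonl` are hypothetical, and the example values are the settings and final numbers from this article's run.

```python
import json
from datetime import datetime

def make_run_notes(num_epochs, lr, used_gpu, final_train_loss, final_val_acc):
    # Hypothetical helper: bundle the run settings and final metrics into one record
    return {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "num_epochs": num_epochs,
        "lr": lr,
        "used_gpu": used_gpu,
        "final_train_loss": round(final_train_loss, 4),
        "final_val_acc": round(final_val_acc, 4),
    }

notes = make_run_notes(20_000, 0.01, True, 0.0466, 1.0)

# Append one JSON line per run so runs can be compared later
with open("run_notes.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(notes, ensure_ascii=False) + "\n")

print(notes["num_epochs"], notes["lr"], notes["used_gpu"])
```

In the notebook itself, the arguments could come from `num_epochs`, `optimizer.param_groups[0]['lr']`, `device.type == 'cuda'`, `train_losses[-1]` and `val_accuracies[-1]` rather than being typed in by hand.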

@浙大疏锦行
