news 2026/6/23 18:51:11

day37简单的神经网络@浙大疏锦行

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
day37简单的神经网络@浙大疏锦行

day37简单的神经网络@浙大疏锦行

使用 sklearn 的 load_digits 数据集 (8x8 像素的手写数字) 进行 MLP 训练。

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromsklearn.datasetsimportload_digitsfromsklearn.model_selectionimporttrain_test_splitfromsklearn.preprocessingimportMinMaxScalerimportnumpyasnpimportmatplotlib.pyplotasplt# 1. 加载数据digits=load_digits()X=digits.data y=digits.targetprint(f"数据形状:{X.shape}")print(f"标签形状:{y.shape}")# 查看一张图片plt.imshow(digits.images[0],cmap='gray')plt.title(f"Label:{y[0]}")plt.show()

数据形状: (1797, 64) 标签形状: (1797,)

# 2. 数据预处理# 划分训练集和测试集X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=42)# 归一化scaler=MinMaxScaler()X_train=scaler.fit_transform(X_train)X_test=scaler.transform(X_test)# 转换为 TensorX_train=torch.FloatTensor(X_train)y_train=torch.LongTensor(y_train)X_test=torch.FloatTensor(X_test)y_test=torch.LongTensor(y_test)print("训练集 Tensor 形状:",X_train.shape)print("测试集 Tensor 形状:",X_test.shape)

训练集 Tensor 形状: torch.Size([1437, 64])

测试集 Tensor 形状: torch.Size([360, 64])

# 3. 定义模型classMLP(nn.Module):def__init__(self):super(MLP,self).__init__()# 输入层 64 (8*8像素) -> 隐藏层 32 -> 输出层 10 (0-9数字)self.fc1=nn.Linear(64,32)self.relu=nn.ReLU()self.fc2=nn.Linear(32,10)defforward(self,x):out=self.fc1(x)out=self.relu(out)out=self.fc2(out)returnout model=MLP()print(model)

MLP(

(fc1): Linear(in_features=64, out_features=32, bias=True) (relu): ReLU()

(fc2): Linear(in_features=32, out_features=10, bias=True)

)

# 4. 定义损失函数和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.SGD(model.parameters(),lr=0.1)# 学习率稍微调大一点,或者增加epoch
# 5. 训练模型num_epochs=2000losses=[]forepochinrange(num_epochs):# 前向传播outputs=model(X_train)loss=criterion(outputs,y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())if(epoch+1)%100==0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss:{loss.item():.4f}')

# 6. 可视化损失plt.plot(range(num_epochs),losses)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss')plt.show()

# 7. 模型评估withtorch.no_grad():# 训练集准确率outputs_train=model(X_train)_,predicted_train=torch.max(outputs_train,1)accuracy_train=(predicted_train==y_train).sum().item()/y_train.size(0)# 测试集准确率outputs_test=model(X_test)_,predicted_test=torch.max(outputs_test,1)accuracy_test=(predicted_test==y_test).sum().item()/y_test.size(0)print(f'训练集准确率:{accuracy_train:.4f}')print(f'测试集准确率:{accuracy_test:.4f}')

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/23 13:42:06

JAVA的平凡之路——此峰乃是最高峰JVM-附加小菜-04

图1.1每台机器300/s,每个订单对象假设1KB,300KB/s可能会涉及其他对象放大20倍,并且可能涉及其他操作情况,再放大10 300*20*10 大约每秒60MB/s 当前堆内存 3072 MB,新生代占1/3,大约 1g ,并且ede…

作者头像 李华
网站建设 2026/6/22 22:44:18

【电力系统】电力系统优化与控制热液调度附Matlab代码和报告

✅作者简介:热爱科研的Matlab仿真开发者,擅长数据处理、建模仿真、程序设计、完整代码获取、论文复现及科研仿真。 🍎 往期回顾关注个人主页:Matlab科研工作室 🍊个人信条:格物致知,完整Matlab代码获取及仿…

作者头像 李华
网站建设 2026/6/22 16:57:43

Golang实战:构建综合多头(逾期+反欺诈)风险查询的高性能客户端

一、用 Go 构建毫秒级风控“熔断器” 在实时信贷审批场景中&#xff0c;风控系统需要在极短的时间内&#xff08;通常 < 200ms&#xff09;做出决策。如果一个申请人当前存在信贷逾期或属于欺诈团伙成员&#xff0c;系统必须立即“熔断”流程&#xff0c;直接拒单&#xff0…

作者头像 李华
网站建设 2026/6/22 19:16:19

【TSP问题】基于蜣螂算法DBO和改进的蜣螂算法FADBO求解旅行商TSP问题(可根据自己的经纬度设置自己想要到达的地区)附Matlab代码

✅作者简介&#xff1a;热爱科研的Matlab仿真开发者&#xff0c;擅长数据处理、建模仿真、程序设计、完整代码获取、论文复现及科研仿真。&#x1f34e; 往期回顾关注个人主页&#xff1a;Matlab科研工作室&#x1f34a;个人信条&#xff1a;格物致知,完整Matlab代码获取及仿真…

作者头像 李华