news 2026/6/25 3:46:13

CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CIFAR10彩色图片识别

CIFAR10彩色图片识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • CIFAR10彩色图片识别
    • 前言
    • 一、准备工作
      • 1.项目环境
      • 2.设置CPU
      • 3.下载数据集
      • 3.加载数据集
      • 4.从数据集里取一个批次的图片和标签,进行数据可视化
    • 二、构建CNN网络
    • 三、加载并打印模型
    • 四、训练模型
      • 1.设置超参数
      • 2.编写训练函数
      • 3.编写测试函数
      • 4.正式训练
    • 五、结果可视化

前言

本文章旨在基于CIFAR10彩色图片识别学习构建CNN网络。


一、准备工作

1.项目环境

#语言环境:python 3.9
#编译环境:jupyter lab
#深度学习环境 :torch=1.12.1+cu113 torchvision=0.13.1+cu113

tips:如何查看python版本以及相关库版本?

!python--version#输出python版本#在jupyter lab运行该指令需要加!,vscode或者pycharm则不用。importtorch;print(torch.__version__)#输出torch版本importtorchvision;print(torchvision.__version__)#输出torchvision版本

运行结果如下:

Python3.9.251.12.1+cu1130.13.1+cu113

2.设置CPU

importtorchimporttorch.nnasnn#nn = Neural Network(神经网络)#用来快速搭建层:卷积层、全连接层、激活函数、损失函数等importmatplotlib.pyplotasplt#画图工具:用来显示图片、绘制损失曲线、可视化结果importtorchvision#作用:计算机视觉专用库#提供现成数据集#提供常用模型#提供图片预处理工具device=torch.device("cpu")device

运行结果如下:

device(type='cpu')

3.下载数据集

使用torchvision.datasets函数下载CIFAR10数据集:

train_ds=torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),# 将数据类型转化为Tensordownload=True)test_ds=torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),# 将数据类型转化为Tensordownload=True)

3.加载数据集

使用torch.utils.data.DataLoader函数加载CIFAR10数据集:

batch_size=32#每个批次取得的样本数train_dl=torch.utils.data.DataLoader(train_ds,#训练集batch_size=batch_size,#批次大小shuffle=True)#整个数据集整体随机打乱顺序后按批次大小取数test_dl=torch.utils.data.DataLoader(test_ds,batch_size=batch_size)

4.从数据集里取一个批次的图片和标签,进行数据可视化

importnumpyasnp#从数据集里取一个批次的图片和标签imgs,labels=next(iter(train_dl))# 指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)plt.figure(figsize=(20,5))fori,imgsinenumerate(imgs[:20]):# 维度缩减npimg=np.squeeze(imgs.numpy())# 将整个figure分成4行5列,绘制第i+1个子图。plt.subplot(4,5,i+1)plt.imshow(npimg,cmap=plt.cm.binary)plt.axis('off')

运行结果如下:

二、构建CNN网络

importtorch.nn.functionalasF num_classes=10# 图片的类别数classModel(nn.Module):def__init__(self):super().__init__()# 特征提取网络self.conv1=nn.Conv2d(3,64,kernel_size=3)# 第一层卷积,卷积核大小为3*3self.pool1=nn.MaxPool2d(kernel_size=2)# 设置池化层,池化核大小为2*2self.conv2=nn.Conv2d(64,64,kernel_size=3)# 第二层卷积,卷积核大小为3*3self.pool2=nn.MaxPool2d(kernel_size=2)self.conv3=nn.Conv2d(64,128,kernel_size=3)# 第二层卷积,卷积核大小为3*3self.pool3=nn.MaxPool2d(kernel_size=2)# 分类网络self.fc1=nn.Linear(512,256)self.fc2=nn.Linear(256,num_classes)# 前向传播defforward(self,x):x=self.pool1(F.relu(self.conv1(x)))x=self.pool2(F.relu(self.conv2(x)))x=self.pool3(F.relu(self.conv3(x)))x=torch.flatten(x,start_dim=1)x=F.relu(self.fc1(x))x=self.fc2(x)returnx

网络数据shape变化过程为:
3, 32, 32(输入数据)
-> 64, 30, 30(经过卷积层1)-> 64, 15, 15(经过池化层1)
-> 64, 13, 13(经过卷积层2)-> 64, 6, 6(经过池化层2)
-> 128, 4, 4(经过卷积层3) -> 128, 2, 2(经过池化层3)
-> 512 -> 256 -> num_classes(10)

三、加载并打印模型

fromtorchinfoimportsummary# 将模型转移到CPU中model=Model().to(device)summary(model)

运行结果如下:

四、训练模型

1.设置超参数

loss_fn=nn.CrossEntropyLoss()# 创建损失函数learn_rate=1e-2# 学习率opt=torch.optim.SGD(model.parameters(),lr=learn_rate)

2.编写训练函数

# 训练循环deftrain(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)# 训练集的大小,一共60000张图片num_batches=len(dataloader)# 批次数目,1875(60000/32)train_loss,train_acc=0,0# 初始化训练损失和正确率forX,yindataloader:# 获取图片及其标签X,y=X.to(device),y.to(device)# 计算预测误差pred=model(X)# 网络输出loss=loss_fn(pred,y)# 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()# grad属性归零loss.backward()# 反向传播optimizer.step()# 每一步自动更新# 记录acc与losstrain_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()train_loss+=loss.item()train_acc/=size train_loss/=num_batchesreturntrain_acc,train_loss

3.编写测试函数

deftest(dataloader,model,loss_fn):size=len(dataloader.dataset)# 测试集的大小,一共10000张图片num_batches=len(dataloader)# 批次数目,313(10000/32=312.5,向上取整)test_loss,test_acc=0,0# 当不进行训练时,停止梯度更新,节省计算内存消耗withtorch.no_grad():forimgs,targetindataloader:imgs,target=imgs.to(device),target.to(device)# 计算losstarget_pred=model(imgs)loss=loss_fn(target_pred,target)test_loss+=loss.item()test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()test_acc/=size test_loss/=num_batchesreturntest_acc,test_loss

4.正式训练

epochs=10train_loss=[]train_acc=[]test_loss=[]test_acc=[]forepochinrange(epochs):model.train()epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template=('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss))print('Done')

运行结果如下:

五、结果可视化

importmatplotlib.pyplotasplt#隐藏警告importwarnings warnings.filterwarnings("ignore")#忽略警告信息plt.rcParams['font.sans-serif']=['SimHei']# 用来正常显示中文标签plt.rcParams['axes.unicode_minus']=False# 用来正常显示负号plt.rcParams['figure.dpi']=100#分辨率fromdatetimeimportdatetime current_time=datetime.now()# 获取当前时间epochs_range=range(epochs)plt.figure(figsize=(12,3))plt.subplot(1,2,1)plt.plot(epochs_range,train_acc,label='Training Accuracy')plt.plot(epochs_range,test_acc,label='Test Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.xlabel(current_time)# 打卡请带上时间戳,否则代码截图无效plt.subplot(1,2,2)plt.plot(epochs_range,train_loss,label='Training Loss')plt.plot(epochs_range,test_loss,label='Test Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()

运行结果如下:

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 22:47:05

AI监管政策分析框架:从技术不确定性到全球治理的合规导航

1. 项目概述:当AI撞上“红绿灯”最近和几个做AI产品落地的朋友聊天,大家不约而同地提到了同一个词:合规。以前我们聊的都是模型精度、算力成本、用户增长,现在话题的焦点变成了“这个功能会不会触发监管红线”、“数据跨境怎么处理…

作者头像 李华
网站建设 2026/5/9 22:46:47

CANN电力负荷预测算子库

【免费下载链接】elec-ops-prediction elec-ops-prediction 是 CANN 社区 Electrical Engineering SIG(电力行业兴趣小组)旗下的电力负荷预测算子库, 聚焦于电力系统运行、调度、规划与市场交易中的预测核心需求,面向华为昇腾&…

作者头像 李华
网站建设 2026/5/9 22:46:37

CANN/hccl:华为集合通信库

HCCL 【免费下载链接】hccl 集合通信库(Huawei Collective Communication Library,简称HCCL)是基于昇腾AI处理器的高性能集合通信库,为计算集群提供高性能、高可靠的通信方案 项目地址: https://gitcode.com/cann/hccl &am…

作者头像 李华
网站建设 2026/5/9 22:45:52

CANN/opbase AI CPU任务接口

aicpu_task 【免费下载链接】opbase 本项目是CANN算子库的基础框架库,为算子提供公共依赖文件和基础调度能力。 项目地址: https://gitcode.com/cann/opbase 本章接口为预留接口,后续有可能变更或废弃,不建议开发者使用,开…

作者头像 李华
网站建设 2026/5/9 22:45:35

SeCAM:融合Grad-CAM与LIME优势的可解释AI新方法

1. 项目概述:为什么我们需要“新”的可解释性?在图像分类任务里,模型预测的准确率早已不是唯一的衡量标准。一个能告诉你“为什么”的模型,其价值正变得和它“是什么”一样重要。想象一下,在医疗影像诊断中&#xff0c…

作者头像 李华
网站建设 2026/5/9 22:45:24

BUUCTF [SUCTF 2019]CheckIn1

一 查看标题和源码标题中文翻译过来是:办理入住。确实也联想不到什么,但是打开靶场可以看见显而易见是一道上传漏洞题目,也不需要看源码了(一般不会在源码上给提示)。我们可以先上传一个普通图片试试可以明显的看到上传…

作者头像 李华