news 2026/7/5 1:40:46

【图像分类】实战ResNet——从零构建到CIFAR-10分类(Pytorch)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【图像分类】实战ResNet——从零构建到CIFAR-10分类(Pytorch)

1. 初识ResNet:为什么它能解决深度神经网络的瓶颈问题

第一次接触ResNet是在处理一个图像分类项目时,当时我遇到了所有深度学习工程师都会面临的经典问题:随着网络层数增加,模型性能不升反降。这就像给小孩子叠积木,叠得越高反而越容易倒塌。ResNet的提出者何恺明团队用"残差学习"的概念完美解决了这个问题。

残差块(Residual Block)的设计其实非常巧妙。想象你在学习骑自行车,如果直接学习完整的骑行动作很困难,但如果你已经会骑三轮车,现在只需要学习"保持平衡"这个差异部分,学习难度就大大降低了。ResNet正是采用了这种思想,通过shortcut connection让网络只需要学习当前输出与输入之间的残差(差异部分)。

我常用的ResNet-18和ResNet-34都采用基础残差块(BasicBlock),它们的结构对比如下:

组件ResNet-18ResNet-34
卷积层总数1834
残差块类型BasicBlockBasicBlock
参数量(M)11.721.8
ImageNet Top1准确率69.76%73.30%

在实际项目中,当计算资源有限时,我通常会先尝试ResNet-18。它的参数量只有11.7M,在CIFAR-10这种小规模数据集上训练速度很快,而且准确率也能达到不错的效果。记得第一次在Colab上跑ResNet-18时,只用15分钟就完成了训练,测试准确率轻松突破85%,这让我深刻体会到好模型不在于复杂,而在于设计巧妙。

2. 环境准备与数据加载:打造高效的PyTorch工作流

搭建环境就像准备厨房,工具齐全才能做出好菜。我习惯用conda创建独立环境,避免包版本冲突。以下是完整的安装步骤:

conda create -n resnet python=3.8 conda activate resnet pip install torch torchvision torchaudio pip install matplotlib tqdm numpy

加载CIFAR-10数据集时,有几个细节需要特别注意。第一次使用时我犯了个错误:直接下载的图片没有做归一化,导致模型难以收敛。正确的做法是使用torchvision提供的标准化参数:

from torchvision import transforms from torchvision.datasets import CIFAR10 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) train_set = CIFAR10(root='./data', train=True, download=True, transform=transform_train) test_set = CIFAR10(root='./data', train=False, download=True, transform=transform_test)

数据增强是提升模型泛化能力的关键。在CIFAR-10这种小数据集上,我通常会加入随机水平翻转和随机裁剪。曾经对比过使用和不使用数据增强的效果,在ResNet-18上准确率相差近5个百分点!

创建数据加载器时,batch_size的设置很有讲究。在我的RTX 3060显卡上,32-64是比较理想的范围。太大会导致显存不足,太小则无法充分利用GPU并行计算能力:

from torch.utils.data import DataLoader train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4) test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)

3. 构建ResNet模型:从残差块到完整网络

实现ResNet的核心在于正确构建残差块。我第一次实现时犯了个典型错误:忘记在shortcut连接中添加1x1卷积当维度不匹配时。这导致模型根本无法训练,损失值居高不下。正确的BasicBlock实现应该是这样的:

import torch.nn as nn class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_channels) self.shortcut = nn.Sequential() if stride != 1 or in_channels != self.expansion * out_channels: self.shortcut = nn.Sequential( nn.Conv2d( in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(self.expansion * out_channels) ) def forward(self, x): out = nn.ReLU()(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = nn.ReLU()(out) return out

完整ResNet的构建需要特别注意层与层之间的通道数变化。ResNet-18的结构可以分为以下几个部分:

  1. 初始卷积层:7x7卷积 -> 这个在CIFAR-10上我改成了3x3,因为图像尺寸较小
  2. 四个残差层:每层包含多个BasicBlock
  3. 全局平均池化和全连接层
class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super().__init__() self.in_channels = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_channels, out_channels, stride)) self.in_channels = out_channels * block.expansion return nn.Sequential(*layers) def forward(self, x): x = nn.ReLU()(self.bn1(self.conv1(x))) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc(x) return x

对于CIFAR-10,我们需要调整网络输入部分,因为原始ResNet是为ImageNet设计的。主要修改包括:

  • 将第一个7x7卷积改为3x3卷积
  • 去掉初始的max pooling层
  • 最后的平均池化改为自适应池化

4. 模型训练与调优:从基础训练到高级技巧

训练神经网络就像教小朋友学习,既要有耐心又要讲究方法。我总结了一套有效的训练流程:

基础训练配置

import torch.optim as optim model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device) # ResNet-18 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

学习率设置很关键,我习惯用学习率预热策略。初期用小学习率(如0.01),5个epoch后再调到0.1:

for epoch in range(5): # Warmup train(..., lr=0.01) for epoch in range(5, 200): train(..., lr=0.1)

训练过程中的重要技巧

  1. 混合精度训练:可以显著减少显存占用,加快训练速度
scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs = model(inputs) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
  1. 标签平滑:缓解过拟合,提升模型泛化能力
class LabelSmoothingCrossEntropy(nn.Module): def __init__(self, epsilon=0.1): super().__init__() self.epsilon = epsilon def forward(self, outputs, targets): n_classes = outputs.size(-1) log_preds = F.log_softmax(outputs, dim=-1) loss = -log_preds.mean() return loss * self.epsilon + (1 - self.epsilon) * F.nll_loss(log_preds, targets)
  1. 模型EMA:保持模型参数的滑动平均,提升最终性能
class ModelEMA: def __init__(self, model, decay=0.999): self.ema = deepcopy(model).eval() self.decay = decay def update(self, model): with torch.no_grad(): for ema_p, model_p in zip(self.ema.parameters(), model.parameters()): ema_p.mul_(self.decay).add_(model_p, alpha=1 - self.decay)

训练监控

我习惯用TensorBoard记录训练过程,关键指标包括:

  • 训练/验证损失曲线
  • 学习率变化
  • 分类准确率
  • 参数分布直方图
from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter() for epoch in range(200): train_loss, train_acc = train(...) val_loss, val_acc = validate(...) writer.add_scalar('Loss/train', train_loss, epoch) writer.add_scalar('Loss/val', val_loss, epoch) writer.add_scalar('Accuracy/train', train_acc, epoch) writer.add_scalar('Accuracy/val', val_acc, epoch)

5. 模型评估与可视化:深入理解模型行为

训练完成后,我们需要全面评估模型性能。基础的准确率指标远远不够,我通常会从以下几个维度进行分析:

1. 混淆矩阵分析

from sklearn.metrics import confusion_matrix import seaborn as sns conf_mat = confusion_matrix(all_labels, all_preds) plt.figure(figsize=(10, 8)) sns.heatmap(conf_mat, annot=True, fmt='d', xticklabels=classes, yticklabels=classes) plt.xlabel('Predicted') plt.ylabel('Actual')

2. 特征可视化

使用t-SNE降维可视化最后一层特征:

from sklearn.manifold import TSNE features = [] # 收集模型最后一层前的特征 labels = [] with torch.no_grad(): for data, target in test_loader: data = data.to(device) feature = model.conv1(data) # ... 通过所有层直到最后一层前 features.append(feature.cpu()) labels.append(target) features = torch.cat(features).numpy() labels = torch.cat(labels).numpy() tsne = TSNE(n_components=2, random_state=42) features_2d = tsne.fit_transform(features) plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, alpha=0.6) plt.colorbar()

3. 类激活图(CAM)可视化

理解模型关注哪些区域:

class CamExtractor: def __init__(self, model): self.model = model self.gradients = None def save_gradient(self, grad): self.gradients = grad def forward_pass(self, x): conv_output = None for name, module in self.model.named_children(): x = module(x) if name == 'layer4': # 最后一个卷积层 x.register_hook(self.save_gradient) conv_output = x return conv_output, x # 使用示例 extractor = CamExtractor(model) conv_output, model_output = extractor.forward_pass(input_img) model_output = model_output[:, target_class] conv_output.backward() gradients = extractor.gradients pooled_gradients = torch.mean(gradients, dim=[0, 2, 3]) for i in range(conv_output.shape[1]): conv_output[:, i, :, :] *= pooled_gradients[i] heatmap = torch.mean(conv_output, dim=1).squeeze().cpu().numpy() heatmap = np.maximum(heatmap, 0) heatmap /= np.max(heatmap) # 叠加到原图 img = input_img.squeeze().permute(1, 2, 0).cpu().numpy() heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) heatmap = np.uint8(255 * heatmap) heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) superimposed_img = heatmap * 0.4 + img * 255

4. 错误分析

收集模型预测错误的样本,分析共同特征:

error_indices = np.where(all_preds != all_labels)[0] error_samples = test_set.data[error_indices] error_preds = all_preds[error_indices] error_labels = all_labels[error_indices] plt.figure(figsize=(15, 10)) for i in range(25): plt.subplot(5, 5, i+1) plt.imshow(error_samples[i]) plt.title(f'P:{classes[error_preds[i]]} A:{classes[error_labels[i]]}') plt.axis('off')

通过这些分析,我们可以发现模型在哪些类别上容易混淆,哪些特征被过度关注等问题。比如在CIFAR-10上,猫和狗、汽车和卡车常常是模型容易混淆的类别。针对这些问题,可以采取数据增强、类别重加权等改进措施。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/7/5 1:40:12

Agent记忆系统设计与实现

一个 Agent 有没有记忆,很大程度上决定了它只是用完即弃的工具,还是能越用越懂你的搭档。没有记忆的 Agent,每次会话都从零开始,你得反复交代项目背景、代码风格、踩过的坑;有记忆的 Agent,能跨会话保留对你…

作者头像 李华
网站建设 2026/7/5 1:39:31

别把知识图谱做成高级文档库——定制化做企业级知识图谱

别把知识图谱做成高级文档库 知识图谱的价值,不是把文档连成网,而是让知识可以被治理 最近我们在做一个知识图谱项目,越做越觉得,很多人对图谱的期待其实放错了地方。 大家一听“知识图谱”,脑子里很容易出现一张很…

作者头像 李华
网站建设 2026/7/5 1:38:06

【面板数据模型实战】从理论到Stata/R/Python实现与选择

1. 面板数据模型入门:从超市会员卡说起想象你是一家连锁超市的数据分析师,手上有过去三年每位会员的月度消费记录。这些数据既有横向维度(不同会员),又有纵向维度(不同月份),这就是典…

作者头像 李华
网站建设 2026/7/5 1:34:20

Rmarkdown动态文档创作与数据科学报告实战指南

1. Rmarkdown核心价值解析Rmarkdown是数据科学领域革命性的文档创作工具,它将代码执行、文本叙述和可视化输出完美融合在一个可重复的工作流中。我使用Rmarkdown五年多来,它彻底改变了我的分析报告产出方式——从枯燥的代码截图拼接模式,升级…

作者头像 李华
网站建设 2026/7/5 1:32:14

【HarmonyOS NEXT】error: failed to install bundle. code:9568322...

🎯 核心原因一:手动签名配置了发布证书(Release Profile)这是最常见的原因之一。发布证书签名的应用,无法直接通过hdc命令安装到真机进行调试。现象:你按照文档配置了生产环境的Profile,设备也添…

作者头像 李华