news 2026/1/10 12:25:58

Day 39 模型可视化与推理

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 39 模型可视化与推理

@浙大疏锦行

一、nn.Module核心自带方法

nn.Module封装了模型的核心逻辑,以下是高频使用的自带方法,按功能分类:

1. 模型状态控制(训练 / 评估模式)

方法作用
model.train()切换为训练模式:启用 Dropout、BatchNorm 等层的训练行为(默认模式)
model.eval()切换为评估模式:关闭 Dropout、固定 BatchNorm 均值 / 方差,用于推理 / 验证
model.training属性,返回布尔值:True= 训练模式,False= 评估模式

示例

import torch import torch.nn as nn class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 16, 3) self.dropout = nn.Dropout(0.5) # 训练时随机失活,评估时关闭 def forward(self, x): x = self.conv(x) x = self.dropout(x) return x model = SimpleCNN() print(model.training) # True(默认训练模式) model.eval() print(model.training) # False(评估模式,dropout失效) model.train() print(model.training) # True(切回训练模式)

2. 设备迁移(CPU/GPU)

方法作用
model.to(device)将模型所有参数 / 缓冲区移到指定设备(cuda/cpu/mps),返回模型实例
model.cuda()快捷方式:移到默认 GPU(等价于model.to('cuda')
model.cpu()快捷方式:移到 CPU(等价于model.to('cpu')

示例

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # 模型移到GPU/CPU # 验证设备 print(next(model.parameters()).device) # 输出:cuda:0 或 cpu

3. 参数管理(查看 / 遍历参数)

方法作用
model.parameters()返回生成器:包含所有可训练参数(nn.Parameter类型)
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.state_dict()返回字典:{参数名:参数值},用于保存模型参数
model.load_state_dict()加载参数字典,用于恢复模型

示例

# 查看所有参数名称和形状 for name, param in model.named_parameters(): print(f"参数名:{name},形状:{param.shape},设备:{param.device}") # 统计总参数量(手动实现,无第三方库时用) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"总参数:{total_params},可训练参数:{trainable_params}")

4. 结构遍历(查看模型层)

方法作用
model.children()返回生成器:仅包含直接子层(如 Sequential 内的第一层),不递归
model.named_children()返回生成器:(层名,子层),仅直接子层
model.modules()返回生成器:递归包含所有层(包括嵌套层)
model.named_modules()返回生成器:(层名,层),递归所有层

示例

# 定义嵌套模型 class NestedModel(nn.Module): def __init__(self): super().__init__() self.block1 = nn.Sequential( nn.Conv2d(3, 16, 3), nn.ReLU() ) self.block2 = nn.Linear(16*30*30, 10) model = NestedModel() # children():仅直接子层(block1、block2) print("=== children() ===") for name, layer in model.named_children(): print(name, layer) # modules():递归所有层(包括Sequential内的Conv2d、ReLU) print("\n=== modules() ===") for name, layer in model.named_modules(): print(name, layer)

5. 前向传播与梯度

方法作用
model.forward(x)手动调用前向传播(不推荐),建议直接model(x)(调用__call__
model(x)等价于model.__call__(x),自动执行 forward + 钩子(hook)逻辑
model.zero_grad()清空所有参数的梯度(训练时反向传播前必须调用)

示例

x = torch.randn(1, 3, 32, 32).to(device) output = model(x) # 推荐:调用__call__,等价于model.forward(x) + 钩子 model.zero_grad() # 清空梯度 output.sum().backward() # 反向传播计算梯度

二、torchsummary库的summary方法

torchsummary是早期轻量库,核心功能是快速打印模型层结构、输出形状、总参数量,仅支持单输入模型,对嵌套模型 / 多输入支持差,维护较少。

1. 安装与基本用法

pip install torchsummary
from torchsummary import summary # 定义模型(输入:3通道32×32图像) class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc1 = nn.Linear(32 * 8 * 8, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.pool(nn.functional.relu(self.conv1(x))) x = self.pool(nn.functional.relu(self.conv2(x))) x = x.view(-1, 32 * 8 * 8) x = nn.functional.relu(self.fc1(x)) x = self.fc2(x) return x # 设备配置 + 模型初始化 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 调用summary:参数(模型,输入形状(通道,高,宽),batch_size可选) summary(model, input_size=(3, 32, 32), batch_size=1)

2. 输出解读

---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [1, 16, 32, 32] 448 MaxPool2d-2 [1, 16, 16, 16] 0 Conv2d-3 [1, 32, 16, 16] 4,640 MaxPool2d-4 [1, 32, 8, 8] 0 Linear-5 [1, 128] 262,272 Linear-6 [1, 10] 1,290 ================================================================ Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ----------------------------------------------------------------

3. 优缺点

优点缺点
极简、无多余依赖仅支持单输入模型
输出简洁、易理解对嵌套模型 / 多分支模型支持差
快速查看参数量 / 形状无批次维度、无内存占用细分
支持 GPU/CPU维护停滞,仅兼容 PyTorch 旧版本

三、torchinfo库的summary方法(推荐)

torchinfotorchsummary的升级版(原torchsummaryX),解决了多输入、嵌套模型、维度展示不清晰的问题,功能更全面,是当前 PyTorch 模型可视化的首选。

1. 安装与基本用法

pip install torchinfo
from torchinfo import summary # 复用上面的SimpleCNN模型 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 核心参数:model, input_size, batch_dim, device, col_width等 summary( model, input_size=(1, 3, 32, 32), # (batch_size, 通道, 高, 宽) batch_dim=0, # 批次维度的位置(默认0) device=device, # 模型设备 col_width=20, # 列宽 col_names=["input_size", "output_size", "num_params", "trainable"], # 显示列 row_settings=["var_names"] # 显示层变量名 )

2. 输出解读

========================================================================================== Layer (type (var_name)) Input Shape Output Shape Param # Trainable ========================================================================================== SimpleCNN (SimpleCNN) [1, 3, 32, 32] [1, 10] -- -- ├─Conv2d (conv1) [1, 3, 32, 32] [1, 16, 32, 32] 448 True ├─MaxPool2d (pool) [1, 16, 32, 32] [1, 16, 16, 16] -- -- ├─Conv2d (conv2) [1, 16, 16, 16] [1, 32, 16, 16] 4,640 True ├─MaxPool2d (pool) [1, 32, 16, 16] [1, 32, 8, 8] -- -- ├─Linear (fc1) [1, 2048] [1, 128] 262,272 True ├─Linear (fc2) [1, 128] [1, 10] 1,290 True ========================================================================================== Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 Total mult-adds (M): 2.15 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ==========================================================================================

四、推理的写法:评估模式

def evaluate_classification(model, dataloader, device): """ 分类模型评估:计算准确率、F1-score(宏平均)、混淆矩阵 """ # 1. 切换到评估模式(必须!) model.eval() # 2. 初始化指标容器 all_preds = [] all_labels = [] # 3. 关闭梯度计算(加速+省显存) with torch.no_grad(): for batch_idx, (x, y) in enumerate(dataloader): # 数据移到设备 x = x.to(device, dtype=torch.float32) y = y.to(device, dtype=torch.long) # 4. 推理(前向传播) outputs = model(x) # 输出:(batch_size, num_classes) preds = torch.argmax(outputs, dim=1) # 取概率最大的类别 # 5. 收集预测结果和真实标签(转回CPU便于计算指标) all_preds.extend(preds.cpu().numpy()) all_labels.extend(y.cpu().numpy()) # 可选:打印进度 if (batch_idx + 1) % 10 == 0: print(f"Batch [{batch_idx+1}/{len(dataloader)}] 完成") # 6. 计算评估指标 accuracy = accuracy_score(all_labels, all_preds) f1_macro = f1_score(all_labels, all_preds, average="macro") # 宏平均F1(适合类别均衡) f1_weighted = f1_score(all_labels, all_preds, average="weighted") # 加权F1(适合类别不均衡) # 7. 打印结果 print("="*50) print(f"分类模型评估结果:") print(f"准确率 (Accuracy): {accuracy:.4f}") print(f"宏平均F1-score: {f1_macro:.4f}") print(f"加权F1-score: {f1_weighted:.4f}") print("="*50) return { "accuracy": accuracy, "f1_macro": f1_macro, "f1_weighted": f1_weighted, "preds": all_preds, "labels": all_labels } # 执行评估 eval_results = evaluate_classification(model, test_loader, device)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/1/5 14:28:18

Object.defineProperty和Proxy实现拦截的区别

1.Object.definedProperty的实现拦截必须得声明一个额外的变量,例如下面这样 const obj {}; let _data "这是一些数据"; Object.defineProperty(obj, "data", {get() {console.log("读取data的操作被拦截了");return _data;}, }); …

作者头像 李华
网站建设 2026/1/6 5:28:43

若依物联网

物联网平台 - Thinglinks-iot ## 🌟 项目简介 一个功能完备、高可扩展的物联网平台,提供完整的设备接入、管理和数据处理解决方案。支持多种网络协议,具备强大的消息解析和实时告警能力,帮助企业快速构建物联网应用。 该项目现已纳…

作者头像 李华
网站建设 2026/1/8 15:04:26

PSEN1抗体:如何揭示阿尔茨海默病致病机制与治疗新靶点?

一、PSEN1基因为何成为神经退行性疾病研究的关键靶点? PSEN1(早老素1)基因位于人类14号染色体q24.2区域,全长87kb,包含14个外显子,编码由467个氨基酸组成的跨膜蛋白,分子量约为53kD。该基因在进…

作者头像 李华
网站建设 2026/1/4 12:48:47

Docker Engine 升级指南:保障容器安全的关键步骤

无论是为了获得新功能、性能优化,还是更关键的——为了修复重大的安全漏洞(如 runc 漏洞 CVE-2024-21626),定期升级 Docker Engine 都是容器基础设施运维中的一项重要任务。 本篇文章将为您提供一个通用的升级流程,确保…

作者头像 李华