news 2026/2/2 20:46:02

Day 37 - 早停策略与模型权重的保存

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 37 - 早停策略与模型权重的保存

在深度学习的训练过程中,我们经常面临两个核心问题:“训练到什么时候停止?”“训练好的模型怎么存?”

如果训练时间太短,模型欠拟合;训练时间太长,模型过拟合。手动盯着Loss曲线决定何时停止既累人又不精确。早停策略 (Early Stopping)就是为了解决这个问题而生的自动化工具。而模型保存则是将我们消耗算力炼出的“丹”(模型参数)持久化存储的关键步骤。


一、 过拟合与监控机制

1.1 什么是过拟合的信号?

在训练过程中,我们通常会观察到以下现象:

  • 训练集 Loss:持续下降(因为模型在死记硬背训练数据)。
  • 测试集/验证集 Loss:先下降,达到一个最低点后,开始震荡甚至反弹上升。

关键点:当训练集 Loss 下降但测试集 Loss 不再下降(甚至上升)时,就是过拟合的开始。这就是我们应该停止训练的最佳时机。

1.2 如何监控?

我们需要在训练循环中,每隔一定的 Epoch(例如每1个或每100个Epoch),暂停训练模式,切换到评估模式 (model.eval()),计算测试集上的 Loss。

# 伪代码逻辑 for epoch in range(num_epochs): train(...) # 训练 if epoch % check_interval == 0: model.eval() test_loss = validate(...) # 验证 print(f"Train Loss: {train_loss}, Test Loss: {test_loss}")

二、 早停策略 (Early Stopping) 实战

早停策略的核心思想是:给模型几次机会(Patience),如果它在验证集上的表现连续几次都没有提升,那就强制停止。

2.1 核心参数

  • best_score/best_loss: 记录历史最好的指标。
  • patience(耐心值): 允许模型连续多少次没有提升。比如设为 10,意味着即使 Loss 上升了,我也再等你 10 轮,万一后面又降了呢?
  • counter: 计数器,记录连续没有提升的次数。
  • min_delta: 只有当提升幅度超过这个阈值时,才算作“提升”(防止微小的抖动被误判)。

2.2 代码实现模板

这是一个可以直接套用的标准早停逻辑代码块:

# 初始化早停参数 best_test_loss = float('inf') # 初始最佳Loss设为无穷大 patience = 20 # 耐心值:20轮不降就停 counter = 0 # 计数器 early_stopped = False # 停止标志 for epoch in range(num_epochs): # ... (训练代码省略) ... # --- 验证阶段 --- if (epoch + 1) % 10 == 0: # 假设每10轮验证一次 model.eval() with torch.no_grad(): outputs = model(X_test) test_loss = criterion(outputs, y_test) model.train() # 切回训练模式 current_loss = test_loss.item() # --- 早停核心逻辑 --- if current_loss < best_test_loss: # 情况1:Loss 创新低(表现更好) best_test_loss = current_loss counter = 0 # 重置计数器 # 【关键】保存当前最好的模型,防止后面训练这就“烂”了 torch.save(model.state_dict(), 'best_model.pth') print(f"Epoch {epoch}: New best loss {best_test_loss:.4f}, model saved.") else: # 情况2:Loss 没创新低(表现变差或持平) counter += 1 print(f"Epoch {epoch}: No improvement. Counter {counter}/{patience}") if counter >= patience: print("早停触发!停止训练。") early_stopped = True break # 跳出训练循环

重要提示

早停触发后,模型当前的状态通常已经过拟合了(因为最后patience轮都在变差)。所以,必须在训练结束后,重新加载我们中间保存的那个best_model.pth,那才是真正的最佳模型。

if early_stopped: model.load_state_dict(torch.load('best_model.pth')) print("已回滚至最佳模型参数。")

三、 模型的保存与加载

PyTorch 提供了多种保存方式,但在工业界和学术界,只保存参数(state_dict)是绝对的主流和最佳实践。

3.1 方式一:仅保存模型参数 (推荐) ⭐⭐⭐⭐⭐

这是最轻量级、最灵活的方式。它只保存模型的权重(Tensor数据),不保存模型的类定义。

  • 保存
# model.state_dict() 是一个字典,包含所有层的权重 torch.save(model.state_dict(), "model_weights.pth")
  • 加载

需要先实例化模型对象(代码中必须有class MLP(...)的定义),然后把参数填进去。

model = MLP() # 1. 先实例化结构 model.load_state_dict(torch.load("model_weights.pth")) # 2. 填充参数 model.eval() # 3. 如果用于推理,记得切到eval模式

3.2 方式二:保存整个模型 (不推荐) ⭐

这种方式会把模型结构和参数打包一起存。

  • 保存torch.save(model, "full_model.pth")
  • 加载model = torch.load("full_model.pth")
  • 缺点:它严重依赖代码目录结构。如果你把代码发给别人,或者把模型类定义的 py 文件改了个名字/移了个位置,加载就会直接报错 (AttributeError)。它是基于 Python 的pickle序列化的,非常脆弱。

3.3 方式三:保存 Checkpoint (断点续训) ⭐⭐⭐⭐

如果你跑一个大模型需要训练几天几夜,你肯定不希望电脑死机后重头再来。这时需要保存所有训练状态

  • 保存
checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), # 优化器也有参数(如动量),必须存! 'loss': loss, } torch.save(checkpoint, 'checkpoint.pth')
  • 加载与恢复
model = MLP() optimizer = optim.Adam(model.parameters()) checkpoint = torch.load('checkpoint.pth') model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] # 从断掉的下一轮开始 # 继续训练... for epoch in range(start_epoch, num_epochs): ...

四、 总结

  1. 不要盲目训练:始终监控测试集 Loss,它是过拟合的“报警器”。
  2. 早停是标配:设置合理的patience(通常 10-50,视数据波动情况而定),配合best_model保存机制,可以让你获得泛化能力最好的模型。
  3. 只存参数:养成使用model.state_dict()的好习惯,避免使用torch.save(model)
  4. 断点保护:对于长时训练,务必定期保存 Checkpoint,包含优化器状态。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/2/3 7:40:08

计算机毕业设计springboot基于Spark++Vue.js的学生管理系统 Spark+Vue 高校学生综合信息管理平台 基于 SpringBoot+Spark+Vue 的全链路学生事务中心

计算机毕业设计springboot基于SparkVue.js的学生管理系统i2kn7p36 &#xff08;配套有源码 程序 mysql数据库 论文&#xff09; 本套源码可以在文本联xi,先看具体系统功能演示视频领取&#xff0c;可分享源码参考。在“数据即资产”的校园时代&#xff0c;传统 Excel 与人工流转…

作者头像 李华
网站建设 2026/1/31 17:36:55

为什么 C盘空间会莫名其妙减少(即使没装新软件)?

为什么 C盘空间会莫名其妙减少&#xff08;即使没装新软件&#xff09;&#xff1f;你有没有注意到c盘空间在减少&#xff0c;即使你没有安装新程序, 这个常见问题可能让人担心, 但通常有明确原因, windows和其他软件会定期创建临时文件、系统备份和更新, 占用磁盘空间而不会每…

作者头像 李华
网站建设 2026/1/21 18:14:58

17、深入理解 Linux 文件系统机制与结构

深入理解 Linux 文件系统机制与结构 1. 理解长格式文件列表 在 Linux 中,使用 ls -la 命令可以查看详细的文件列表信息,示例输出如下: drwx------ 2 dee dee 4096 Jul 29 07:48 . drwxr-xr-x 5 root root 4096 Jul 27 11:57 .. -rw-r--r-- 1 dee dee 24 Jul 27 …

作者头像 李华
网站建设 2026/1/27 9:43:43

29、Linux 软件使用与故障排除指南

Linux 软件使用与故障排除指南 1. VMWare 和 Wine 软件介绍 VMWare : 缺点 :运行 VMWare 需要系统有额外的性能支持,使用前需查看其系统要求,并尽量让系统配置高于该要求。 优点 :它在独立窗口中运行,几乎等同于拥有另一台计算机。 Wine : 简介 :Wine(www.wi…

作者头像 李华
网站建设 2026/1/30 23:24:41

从入门到转行:网络安全自学与跳槽的终极建议

目录 为什么写这篇文章 为什么我更合适回答这个问题 先问自己3个问题 1.一定要明确自己是否是真喜欢&#xff0c;还是一时好奇。 2.自学的习惯 3.选择网安、攻防这行的目标是什么&#xff1f; 确认无误后&#xff0c;那如何进入这个行业&#xff1f; 1.选择渗透测试集中…

作者头像 李华