YOLOv5中使用torch加载自定义模型进行目标检测
在智能安防、工业质检和机器人视觉等实际场景中,我们常常面临这样一个问题:训练好的YOLOv5模型如何快速部署到真实环境中?很多开发者卡在“训练完成却不会用”的尴尬阶段——明明.pt文件就躺在runs/train/exp/weights/目录下,但一到推理环节就报错不断。
其实,PyTorch提供了一个极其简洁的解决方案:torch.hub.load。它不仅能从GitHub一键拉取预训练模型,更重要的是支持本地加载自定义权重,让我们跳过繁琐的模型重建过程,直接进入推理验证阶段。本文将带你彻底打通这一关键链路。
深入理解YOLOv5的加载机制
YOLOv5之所以能在工业界广泛落地,除了其出色的精度-速度平衡外,一个核心优势就是极简的部署接口设计。不同于传统做法需要手动定义网络结构再加载权重,YOLOv5通过hubconf.py暴露统一入口,使得外部调用变得异常简单。
以torch.hub.load为例,它的本质是动态导入指定仓库中的模型构造函数,并执行初始化。对于本地自定义模型,最关键的是三个参数的配合:
model = torch.hub.load( repo_or_dir='./yolov5', # 必须指向项目根目录 model='custom', # 告诉hub这是用户训练的权重 path='runs/train/exp/weights/best.pt', # 权重路径 source='local' # 禁止联网下载 )这里有个容易被忽视的细节:repo_or_dir不是随便填个路径就行,它必须包含完整的YOLOv5代码结构(如models/,utils/等),否则会因找不到模块而抛出ModuleNotFoundError。这也就是为什么很多人直接传.pt路径会失败——PyTorch并不知道这个权重对应什么网络结构。
更进一步,model='custom'这个设定非常巧妙。它触发了内部的custom()函数,该函数会根据权重文件自动推断输入尺寸、类别数等信息,避免了硬编码带来的兼容性问题。这也是为何即使你修改了data.yaml里的类别名称,也能正确映射到输出结果中的原因。
实战:摄像头实时检测实现
下面是一个经过生产环境验证的完整脚本,适用于大多数基于PC或边缘设备(如Jetson)的实时检测任务。
import cv2 import torch import sys import os # === 添加项目路径 === YOLOV5_ROOT = 'D:\\projects\\yolov5' # ⚠️ 修改为你的实际路径 if not os.path.exists(YOLOV5_ROOT): raise FileNotFoundError(f"未找到YOLOv5项目目录: {YOLOV5_ROOT}") sys.path.insert(0, YOLOV5_ROOT) # === 加载模型 === MODEL_WEIGHTS = os.path.join(YOLOV5_ROOT, 'runs', 'train', 'exp', 'weights', 'best.pt') if not os.path.exists(MODEL_WEIGHTS): raise FileNotFoundError(f"模型权重不存在: {MODEL_WEIGHTS}") print("🔧 正在加载模型...") model = torch.hub.load( repo_or_dir=YOLOV5_ROOT, model='custom', path=MODEL_WEIGHTS, source='local' ) model.conf = 0.4 # 设置全局置信度阈值 print("✅ 模型加载成功") # === 启动摄像头 === cap = cv2.VideoCapture(0) if not cap.isOpened(): print("❌ 无法打开摄像头") exit() print("🎥 摄像头已启动,按 'q' 键退出") while True: ret, frame = cap.read() if not ret: print("⚠️ 视频流中断") break # 注意:YOLOv5期望RGB输入 img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = model(img_rgb) # 使用内置render方法绘制结果(省去cv2.rectangle循环) annotated_frame = results.render()[0] annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR) cv2.imshow('YOLOv5 Detection', annotated_frame) if cv2.waitKey(1) & 0xFF == ord('q'): break cap.release() cv2.destroyAllWindows()有几个工程实践中总结的经验值得强调:
- 绝对路径优先:虽然相对路径看起来更优雅,但在多级目录或打包发布时极易出错。建议始终使用
os.path.join拼接路径。 - 前置校验:在加载前检查路径是否存在,能极大提升调试效率。比起运行时报错,提前发现问题显然更友好。
- 动态阈值设置:通过
model.conf = 0.4可全局调整检测灵敏度,无需重新训练模型即可适应不同光照或遮挡场景。
进阶技巧:获取结构化检测结果
可视化只是第一步,真正有价值的是对检测结果做进一步处理。比如在自动化产线上,我们需要统计缺陷产品的数量;在监控系统中,则可能要记录陌生人出现的时间段。
这时就要用到results.pandas().xyxy[0],它返回一个Pandas DataFrame,结构清晰且易于操作:
| xmin | ymin | xmax | ymax | confidence | class | name |
|---|---|---|---|---|---|---|
| 120 | 80 | 280 | 200 | 0.95 | 0 | defect |
| 300 | 100 | 400 | 180 | 0.87 | 1 | normal |
我们可以轻松实现各种逻辑:
df = results.pandas().xyxy[0] # 筛选高置信度结果 high_conf = df[df['confidence'] > 0.8] # 统计各类别数量 counts = high_conf['name'].value_counts() # 保存到CSV用于后续分析 df.to_csv('detection_log.csv', index=False, mode='a', header=not os.path.exists('detection_log.csv'))甚至可以结合OpenCV做精准裁剪:
for _, row in high_conf.iterrows(): x1, y1, x2, y2 = int(row['xmin']), int(row['ymin']), int(row['xmax']), int(row['ymax']) crop = frame[y1:y2, x1:x2] cv2.imwrite(f"crop_{row['name']}_{row['confidence']:.2f}.jpg", crop)这种“检测+分析+存储”的流水线模式,在实际项目中极为常见。
常见陷阱与避坑指南
尽管流程看似简单,但在真实部署中仍有不少“坑”需要注意:
1. 路径问题导致的导入失败
最常见的错误是:
ModuleNotFoundError: No module named 'models'根源在于Python找不到models/common.py这类模块。解决方法只有两个:
- 确保sys.path包含了YOLOv5根目录
- 或者将整个项目作为包安装(pip install -e .)
推荐前者,轻量且可控。
2. GPU/CPU不匹配
如果你在GPU上训练但想在无显卡设备上运行,记得加上设备参数:
model = torch.hub.load(..., device='cpu')否则会出现类似RuntimeError: Attempting to deserialize object on a CUDA device的错误。
3. 自定义类名未生效
有些用户发现输出还是显示class 0而不是person。这是因为data.yaml没有随模型一起保存。正确的做法是在训练时确保save_dir中保留原始配置文件,或者手动指定names列表。
当一切配置妥当后,你会看到这样的画面:
📷实时检测效果示意:
+-------------------------------------------+ | | | 🔧 defect ✅ normal | | ┌─────────┐ ┌──────────────┐ | | │ │ │ │ | | │ │ │ │ | | └─────────┘ └──────────────┘ | | | | YOLOv5 Real-time Detection | +-------------------------------------------+同时终端输出详细信息:
🔍 当前帧检测结果: 📦 defect | 置信度: 0.963 | 位置: (120, 80, 280, 200) 📦 normal | 置信度: 0.891 | 位置: (300, 100, 400, 180)在GTX 1660上,YOLOv5s模型可稳定达到50+ FPS,完全满足多数实时应用需求。
这套基于torch.hub.load的加载方式,已经成为连接训练与部署的事实标准。它不仅简化了开发流程,更重要的是提升了系统的可维护性——只需替换.pt文件即可更新模型,无需改动任何推理代码。
未来你可以在此基础上做更多扩展:
- 将模型导出为ONNX/TensorRT格式,部署到嵌入式设备
- 使用Flask封装成REST API,供其他系统调用
- 集成DeepSORT实现多目标跟踪
技术的魅力就在于此:一个简单的接口背后,蕴藏着通往无限可能的大门。掌握它,你就掌握了将AI想法变为现实的第一把钥匙。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考