news 2026/6/23 18:18:07

逻辑回归(Logistic Regression)进行多分类的实战

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
逻辑回归(Logistic Regression)进行多分类的实战

一、多分类策略
逻辑回归处理多分类主要有三种策略:

  1. OvR(One-vs-Rest)
    为每个类别训练一个二分类器

预测时选择概率最高的类别

Scikit-learn默认使用此方法

  1. OvO(One-vs-One)
    为每对类别训练一个分类器

适合类别较少但样本均衡的情况

  1. Softmax回归(Multinomial)
    直接输出多个类别的概率分布

使用交叉熵损失函数

二、完整实战代码示例

importnumpyasnpimportpandasaspdimportmatplotlib.pyplotaspltimportseabornassnsfromsklearnimportdatasetsfromsklearn.model_selectionimporttrain_test_split,cross_val_score,GridSearchCVfromsklearn.preprocessingimportStandardScalerfromsklearn.linear_modelimportLogisticRegressionfromsklearn.metricsimport(classification_report,confusion_matrix,accuracy_score,roc_curve,auc,roc_auc_score)fromsklearn.multiclassimportOneVsRestClassifierimportwarnings warnings.filterwarnings('ignore')# 设置中文显示plt.rcParams['font.sans-serif']=['SimHei']plt.rcParams['axes.unicode_minus']=False

步骤1:加载和准备数据

# 加载鸢尾花数据集(3个类别)iris=datasets.load_iris()X=iris.data y=iris.target feature_names=iris.feature_names target_names=iris.target_namesprint(f"特征形状:{X.shape}")print(f"标签形状:{y.shape}")print(f"类别:{target_names}")print(f"特征名:{feature_names}")# 查看数据分布print("\n类别分布:")fori,nameinenumerate(target_names):print(f"{name}:{np.sum(y==i)}个样本")# 划分训练集和测试集X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=42,stratify=y)# 标准化特征scaler=StandardScaler()X_train_scaled=scaler.fit_transform(X_train)X_test_scaled=scaler.transform(X_test)

步骤2:模型训练与评估
方法1:使用默认的OvR策略

# 创建逻辑回归模型(默认使用OvR)model_ovr=LogisticRegression(multi_class='ovr',# One-vs-Restsolver='lbfgs',# 适用于小数据集max_iter=1000,random_state=42,C=1.0# 正则化强度,越小正则化越强)# 训练模型model_ovr.fit(X_train_scaled,y_train)# 预测y_pred_ovr=model_ovr.predict(X_test_scaled)y_pred_proba_ovr=model_ovr.predict_proba(X_test_scaled)# 评估print("=== OvR策略评估 ===")print(f"准确率:{accuracy_score(y_test,y_pred_ovr):.4f}")print("\n分类报告:")print(classification_report(y_test,y_pred_ovr,target_names=target_names))

方法2:使用Softmax回归

# 创建Softmax回归模型model_softmax=LogisticRegression(multi_class='multinomial',# Softmax回归solver='lbfgs',max_iter=1000,random_state=42,C=1.0)# 训练模型model_softmax.fit(X_train_scaled,y_train)# 预测y_pred_softmax=model_softmax.predict(X_test_scaled)# 评估print("\n=== Softmax回归评估 ===")print(f"准确率:{accuracy_score(y_test,y_pred_softmax):.4f}")print("\n分类报告:")print(classification_report(y_test,y_pred_softmax,target_names=target_names))

方法3:使用OneVsRestClassifier包装器

# 使用包装器实现OvRmodel_ovr_wrapper=OneVsRestClassifier(LogisticRegression(solver='lbfgs',max_iter=1000,random_state=42))model_ovr_wrapper.fit(X_train_scaled,y_train)y_pred_ovr_wrapper=model_ovr_wrapper.predict(X_test_scaled)print("\n=== OvR包装器评估 ===")print(f"准确率:{accuracy_score(y_test,y_pred_ovr_wrapper):.4f}")

步骤3:可视化分析

defplot_confusion_matrix(y_true,y_pred,class_names,title):"""绘制混淆矩阵"""cm=confusion_matrix(y_true,y_pred)plt.figure(figsize=(8,6))sns.heatmap(cm,annot=True,fmt='d',cmap='Blues',xticklabels=class_names,yticklabels=class_names)plt.title(f'混淆矩阵 -{title}',fontsize=14)plt.ylabel('真实标签')plt.xlabel('预测标签')plt.tight_layout()plt.show()# 绘制混淆矩阵plot_confusion_matrix(y_test,y_pred_ovr,target_names,"OvR策略")# 绘制特征重要性defplot_feature_importance(model,feature_names,target_names):"""绘制特征重要性(权重)"""ifhasattr(model,'coef_'):weights=model.coef_ fig,axes=plt.subplots(1,len(target_names),figsize=(15,5))fori,(ax,class_name)inenumerate(zip(axes,target_names)):ax.barh(feature_names,weights[i])ax.set_title(f'类别:{class_name}')ax.set_xlabel('权重')plt.suptitle('逻辑回归特征权重(每个类别的决策边界)',fontsize=14)plt.tight_layout()plt.show()plot_feature_importance(model_ovr,feature_names,target_names)

步骤4:概率可视化

# 绘制预测概率分布defplot_probability_distribution(y_pred_proba,y_true,target_names):"""绘制预测概率分布"""fig,axes=plt.subplots(1,3,figsize=(15,5))fori,(ax,class_name)inenumerate(zip(axes,target_names)):# 获取属于当前类别的样本的概率true_class_mask=(y_true==i)prob_for_class=y_pred_proba[true_class_mask,i]ax.hist(prob_for_class,bins=20,alpha=0.7,color='skyblue',edgecolor='black')ax.set_title(f'{class_name}的预测概率分布')ax.set_xlabel('预测概率')ax.set_ylabel('样本数')ax.grid(True,alpha=0.3)plt.suptitle('各类别预测概率分布',fontsize=14)plt.tight_layout()plt.show()plot_probability_distribution(y_pred_proba_ovr,y_test,target_names)

步骤5:模型调优

# 使用网格搜索寻找最佳参数param_grid={'C':[0.001,0.01,0.1,1,10,100],# 正则化强度'solver':['lbfgs','liblinear','saga'],'max_iter':[100,500,1000]}# 创建网格搜索grid_search=GridSearchCV(LogisticRegression(multi_class='ovr',random_state=42),param_grid,cv=5,scoring='accuracy',n_jobs=-1,verbose=1)# 执行网格搜索grid_search.fit(X_train_scaled,y_train)print("\n=== 网格搜索结果 ===")print(f"最佳参数:{grid_search.best_params_}")print(f"最佳交叉验证准确率:{grid_search.best_score_:.4f}")print(f"测试集准确率:{grid_search.score(X_test_scaled,y_test):.4f}")# 使用最佳模型best_model=grid_search.best_estimator_ y_pred_best=best_model.predict(X_test_scaled)print("\n=== 最佳模型评估 ===")print(classification_report(y_test,y_pred_best,target_names=target_names))

步骤6:交叉验证评估

# 交叉验证评估模型稳定性cv_scores=cross_val_score(best_model,X_train_scaled,y_train,cv=5,scoring='accuracy')print("\n=== 交叉验证结果 ===")print(f"交叉验证准确率:{cv_scores.mean():.4f}(+/-{cv_scores.std()*2:.4f})")print(f"各折准确率:{cv_scores}")# 绘制交叉验证结果plt.figure(figsize=(10,6))plt.plot(range(1,6),cv_scores,marker='o',linewidth=2,markersize=8)plt.axhline(y=cv_scores.mean(),color='r',linestyle='--',label=f'均值:{cv_scores.mean():.4f}')plt.fill_between(range(1,6),cv_scores.mean()-cv_scores.std(),cv_scores.mean()+cv_scores.std(),alpha=0.2,color='gray')plt.title('5折交叉验证准确率',fontsize=14)plt.xlabel('折数')plt.ylabel('准确率')plt.legend()plt.grid(True,alpha=0.3)plt.ylim([0.8,1.0])plt.show()

三、关键要点总结
1.策略选择:

类别较少且均衡:考虑OvO

类别较多:使用OvR或Softmax

Softmax通常更直接,但需要计算所有类别的概率

2.特征工程:

逻辑回归对特征缩放敏感,务必标准化

特征间的多重共线性会影响结果

3.正则化:

参数C控制正则化强度(C越小,正则化越强)

防止过拟合的重要工具

4.模型评估:

多分类使用准确率、混淆矩阵、分类报告

考虑使用宏平均和微平均

5.注意事项:

逻辑回归假设特征与log odds线性相关

对于非线性问题,需要特征工程或使用核方法

类别不平衡时需要调整class_weight参数

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/22 17:22:20

RNN(循环神经网络)原理

一、RNN基本思想与核心概念 1.1 为什么需要RNN? 传统神经网络(如全连接网络、CNN)无法处理序列数据,因为它们: 无记忆性:每个输入独立处理,忽略序列中元素的时间/顺序关系 固定输入尺寸&#xf…

作者头像 李华
网站建设 2026/6/16 17:09:50

人机协同重构创作生态——生成式AI赋能内容产业的变革与思考

当内容生产遭遇“产能焦虑”与“创意枯竭”的双重困境,生成式AI正以不可逆转的态势重构行业规则。2025年一季度数据显示,国内72%的内容团队已将AI工具纳入核心工作流,电商文案、短视频脚本等场景的AI渗透率超85%。这场变革不仅是生产效率的提…

作者头像 李华
网站建设 2026/6/23 15:48:59

V助手舆情分析智能体:重塑舆情分析,从“人找信息”到“信息为人”

V助手舆情分析智能体:重塑舆情分析,从“人找信息”到“信息为人”在信息爆炸的时代,舆情分析工作常常面临数据繁杂、流程冗长、响应迟缓等挑战。传统方式不仅耗时耗力,更易错失关键信息与应对先机。如今,随着蜜度V助手…

作者头像 李华
网站建设 2026/6/23 15:56:00

连接2026:十款远程控制软件真实力横评与选择指南

目录引📈 选择前必读:明确你的核心需求🏆 综合王者:ToDesk(评分 9.6/10)🎯 细分领域佼佼者🎮 为游戏而生:网易UU远程(评分 8.4/10)🎬 …

作者头像 李华
网站建设 2026/6/23 15:46:34

计算机毕业设计springboot基于Spark++Vue.js的学生管理系统 Spark+Vue 高校学生综合信息管理平台 基于 SpringBoot+Spark+Vue 的全链路学生事务中心

计算机毕业设计springboot基于SparkVue.js的学生管理系统i2kn7p36 (配套有源码 程序 mysql数据库 论文) 本套源码可以在文本联xi,先看具体系统功能演示视频领取,可分享源码参考。在“数据即资产”的校园时代,传统 Excel 与人工流转…

作者头像 李华