news 2026/1/22 22:10:30

【机器学习】案例1.2——决策树进行鸢尾花分类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【机器学习】案例1.2——决策树进行鸢尾花分类

1. 项目背景及解决问题的方案

1.1 项目背景

鸢尾花(Iris)数据集是机器学习领域的经典基准数据集,由统计学家Fisher于1936年提出,是多分类任务的入门级数据集。该数据集包含150个样本,对应3类鸢尾花(山鸢尾/Iris-setosa、变色鸢尾/Iris-versicolor、维吉尼亚鸢尾/Iris-virginica),每类各50个样本;每个样本包含4个数值型特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度。

从技术角度,决策树是一种基于树状结构做决策的分类/回归算法,具有可解释性强、无需特征归一化、直观易懂等优点,但核心痛点是:当决策树的深度过深时,模型会过度拟合训练数据的细节(如噪声),导致在测试集上的泛化能力下降(过拟合)。

本项目的核心目标:

  • 基于鸢尾花数据集,使用决策树分类器实现鸢尾花种类的精准分类;
  • 探究决策树深度对模型泛化能力(测试集错误率)的影响,验证“深度过深导致过拟合”的现象;
  • 掌握决策树模型的训练、评估、可视化及超参数(深度)调优的核心流程。
1.2 解决问题的方案(分步骤)
步骤核心动作具体实现
数据准备加载+预处理1. 加载sklearn内置的Iris数据集;
2. 转换为Pandas DataFrame,命名特征列并添加目标列;
3. 选择“花瓣长度、花瓣宽度”两个核心特征(区分度更高)。
数据集划分训练/测试集拆分按75%(训练):25%(测试)划分数据,设置random_state=42保证结果可复现。
基础模型训练决策树训练+评估1. 初始化决策树分类器(max_depth=8,gini准则);
2. 训练模型并预测测试集;
3. 计算测试集准确率、输出特征重要性;
4. 导出决策树可视化文件(dot格式)。
单样本验证自定义样本预测对花瓣长度=5、宽度=1.5的样本,预测分类概率和最终结果。
超参数探究深度对性能的影响1. 遍历深度1~14,训练不同深度的决策树;
2. 计算每个深度的测试集错误率;
3. 可视化深度与错误率的关系,验证过拟合。
可视化展示结果可视化设置中文字体,绘制深度-错误率折线图,直观展示规律。

2. 代码详细注释版

# 导入必要的库importpandasaspd# 数据处理库,用于结构化数据操作importnumpyasnp# 数值计算库,用于数组/矩阵操作fromsklearn.datasetsimportload_iris# 加载sklearn内置的鸢尾花数据集fromsklearn.treeimportDecisionTreeClassifier# 决策树分类器fromsklearn.treeimportexport_graphviz# 导出决策树为dot格式(可视化用)fromsklearn.treeimportDecisionTreeRegressor# 决策树回归器(本项目未使用,保留注释)fromsklearn.model_selectionimporttrain_test_split# 划分训练集/测试集fromsklearn.metricsimportaccuracy_score# 计算分类准确率importmatplotlib.pyplotasplt# 绘图库importmatplotlibasmpl# 绘图配置库# ===================== 步骤1:加载并预处理鸢尾花数据集 =====================# 加载鸢尾花数据集iris=load_iris()# 将特征数据转换为DataFrame(方便查看和处理)data=pd.DataFrame(iris.data)# 为特征列命名(对应数据集的4个特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度)data.columns=iris.feature_names# 添加目标列(鸢尾花种类,0=setosa,1=versicolor,2=virginica)data['Species']=load_iris().target# 打印数据集前几行(默认5行),查看数据结构print(data)# 特征选择:仅选取花瓣长度(第3列)和花瓣宽度(第4列)作为输入特征(iloc[:,2:4]表示行全选,列选2-3索引)x=data.iloc[:,2:4]# 目标变量:选取最后一列(Species)作为分类目标y=data.iloc[:,-1]# ===================== 步骤2:划分训练集和测试集 =====================# train_size=0.75:训练集占75%,测试集25%;random_state=42:固定随机种子,保证结果可复现x_train,x_test,y_train,y_test=train_test_split(x,y,train_size=0.75,random_state=42)# ===================== 步骤3:训练基础决策树模型并评估 =====================# 初始化决策树分类器:max_depth=8(树最大深度),criterion='gini'(基尼系数作为分裂准则)tree_clf=DecisionTreeClassifier(max_depth=8,criterion='gini')# 用训练集数据训练模型tree_clf.fit(x_train,y_train)# 用训练好的模型预测测试集y_test_hat=tree_clf.predict(x_test)# 计算并打印测试集准确率print("acc score:",accuracy_score(y_test,y_test_hat))# 打印特征重要性:数值越大表示该特征对分类的贡献越大print("特征重要性(花瓣长度/花瓣宽度):",tree_clf.feature_importances_)# 导出决策树为dot格式文件(可通过dot命令转换为PNG图片查看树结构)export_graphviz(tree_clf,# 训练好的决策树模型out_file="./iris_tree.dot",# 输出文件路径feature_names=iris.feature_names[2:4],# 特征名(仅花瓣长度/宽度)class_names=iris.target_names,# 类别名(setosa/versicolor/virginica)rounded=True,# 节点边框圆角filled=True# 节点填充颜色)# 备注:转换命令(需安装graphviz):./dot -Tpng ~/PycharmProjects/mlstudy/bjsxt/iris_tree.dot -o ~/PycharmProjects/mlstudy/bjsxt/iris_tree.png# ===================== 步骤4:单样本预测 =====================# 预测花瓣长度=5,宽度=1.5的样本属于各类的概率print("单样本分类概率:",tree_clf.predict_proba([[5,1.5]]))# 预测该样本的最终分类结果(输出类别索引)print("单样本分类结果:",tree_clf.predict([[5,1.5]]))# ===================== 步骤5:探究决策树深度对错误率的影响 =====================# 生成深度范围:1到14(包含14)depth=np.arange(1,15)# 存储每个深度对应的错误率err_list=[]# 遍历每个深度,训练模型并计算错误率fordindepth:print(f"当前训练的决策树深度:{d}")# 初始化对应深度的决策树分类器(基尼系数准则)clf=DecisionTreeClassifier(criterion='gini',max_depth=d)# 训练模型clf.fit(x_train,y_train)# 预测测试集y_test_hat=clf.predict(x_test)# 计算预测正确的样本(True/False数组)result=(y_test_hat==y_test)# 仅当深度=1时打印预测正确与否的结果(用于调试)ifd==1:print(f"深度=1时的预测正确结果:{result}")# 计算错误率:1 - 正确样本的均值err=1-np.mean(result)# 打印错误率(百分比)print(f"深度={d}时的错误率(百分比):{100*err:.2f}%")# 将错误率加入列表err_list.append(err)# ===================== 步骤6:可视化深度与错误率的关系 =====================# 设置matplotlib的中文字体(SimHei=黑体,避免中文乱码)mpl.rcParams['font.sans-serif']=['SimHei']# 设置图片背景色为白色plt.figure(facecolor='w')# 绘制折线图:红色圆点+实线,线宽=2plt.plot(depth,err_list,'ro-',lw=2)# 设置x轴标签plt.xlabel('决策树深度',fontsize=15)# 设置y轴标签plt.ylabel('错误率',fontsize=15)# 设置标题plt.title('决策树深度和过拟合',fontsize=18)# 显示网格线plt.grid(True)# 展示图片plt.show()# 决策树回归器(本项目未使用,保留代码注释)# tree_reg = DecisionTreeRegressor(max_depth=2)# tree_reg.fit(X, y)

3. 代码简洁版(核心逻辑,精简注释/打印)

importpandasaspdimportnumpyasnpfromsklearn.datasetsimportload_irisfromsklearn.treeimportDecisionTreeClassifier,export_graphvizfromsklearn.model_selectionimporttrain_test_splitfromsklearn.metricsimportaccuracy_scoreimportmatplotlib.pyplotaspltimportmatplotlibasmpl# 数据加载与预处理iris=load_iris()data=pd.DataFrame(iris.data,columns=iris.feature_names)data['Species']=iris.target x=data.iloc[:,2:4]# 花瓣长度/宽度y=data.iloc[:,-1]# 划分数据集x_train,x_test,y_train,y_test=train_test_split(x,y,train_size=0.75,random_state=42)# 基础模型训练tree_clf=DecisionTreeClassifier(max_depth=8,criterion='gini')tree_clf.fit(x_train,y_train)print("准确率:",accuracy_score(y_test,tree_clf.predict(x_test)))# 导出决策树可视化文件export_graphviz(tree_clf,out_file="./iris_tree.dot",feature_names=iris.feature_names[2:4],class_names=iris.target_names,rounded=True,filled=True)# 单样本预测print("单样本概率:",tree_clf.predict_proba([[5,1.5]]))print("单样本结果:",tree_clf.predict([[5,1.5]]))# 探究深度对错误率的影响depth=np.arange(1,15)err_list=[]fordindepth:clf=DecisionTreeClassifier(criterion='gini',max_depth=d)clf.fit(x_train,y_train)err=1-np.mean(clf.predict(x_test)==y_test)err_list.append(err)# 可视化mpl.rcParams['font.sans-serif']=['SimHei']plt.figure(facecolor='w')plt.plot(depth,err_list,'ro-',lw=2)plt.xlabel('决策树深度',fontsize=15)plt.ylabel('错误率',fontsize=15)plt.title('决策树深度和过拟合',fontsize=18)plt.grid(True)plt.show()

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/1/22 9:16:33

AutoGPT镜像适用于科研场景吗?高校团队已投入使用

AutoGPT镜像在科研中的落地实践:高校团队如何用它加速研究 在人工智能技术快速迭代的今天,一场静悄悄的变革正在实验室和学术办公室中发生。越来越多的高校研究团队不再满足于将大模型当作问答工具,而是开始尝试让AI真正“动起来”——自主完…

作者头像 李华
网站建设 2026/1/18 8:38:12

GitHub热门项目推荐:基于Qwen3-14B开发的企业级AI助手

基于Qwen3-14B构建企业级AI助手:性能与落地的完美平衡 在当前企业智能化转型的浪潮中,一个现实问题反复浮现:我们是否真的需要动辄千亿参数的大模型来处理日常业务?对于大多数中小企业而言,部署超大规模语言模型不仅成…

作者头像 李华
网站建设 2026/1/18 13:57:12

从零到网络安全专家:一张全景路线图(2025版)

本文适用于:在校大学生(计算机/非计算机专业)想转行网络安全的职场人士网络安全爱好者想系统提升的安全从业者📊 网络安全人才市场现状(数据说话)维度数据解读人才缺口​2024年缺口超327万供需比1:10&#…

作者头像 李华
网站建设 2026/1/21 9:31:29

LeetCode 46/51 排列型回溯题笔记-全排列 / N 皇后

目录 一、题目 1:全排列(LeetCode 46) 题目描述 核心思路 重难点 & 易错点 Java 实现(标准版) 回溯过程演示(以nums[1,2]为例) 二、题目 2:N 皇后(LeetCode 5…

作者头像 李华
网站建设 2026/1/22 8:38:15

一周回顾:勒索飙升、AI上阵、人形机器人被盯上

一周回顾:勒索飙升、AI上阵、人形机器人被盯上 本周全球网络安全态势呈现显著的“多线高压”:勒索软件赎金在过去三年累计突破 21 亿美元,显示产业化、专业化趋势持续加速;AI 被进一步武器化,日本出现高中生借助 Chat…

作者头像 李华
网站建设 2026/1/21 12:41:41

嵌入式FOTA进阶:文件系统直接升级+串口分段传输深度指南!

随着嵌入式设备对FOTA升级效率与稳定性的要求提升,文件系统直写与串口分段传输已成为核心进阶技术。前者通过精简数据写入路径,降低存储开销与升级耗时;后者利用串口的稳定通道,以分段方式保障升级包传输的可靠性。本文将系统拆解…

作者头像 李华