1. 项目背景及解决问题的方案
1.1 项目背景
鸢尾花(Iris)数据集是机器学习领域的经典基准数据集,由统计学家Fisher于1936年提出,是多分类任务的入门级数据集。该数据集包含150个样本,对应3类鸢尾花(山鸢尾/Iris-setosa、变色鸢尾/Iris-versicolor、维吉尼亚鸢尾/Iris-virginica),每类各50个样本;每个样本包含4个数值型特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度。
从技术角度,决策树是一种基于树状结构做决策的分类/回归算法,具有可解释性强、无需特征归一化、直观易懂等优点,但核心痛点是:当决策树的深度过深时,模型会过度拟合训练数据的细节(如噪声),导致在测试集上的泛化能力下降(过拟合)。
本项目的核心目标:
- 基于鸢尾花数据集,使用决策树分类器实现鸢尾花种类的精准分类;
- 探究决策树深度对模型泛化能力(测试集错误率)的影响,验证“深度过深导致过拟合”的现象;
- 掌握决策树模型的训练、评估、可视化及超参数(深度)调优的核心流程。
1.2 解决问题的方案(分步骤)
| 步骤 | 核心动作 | 具体实现 |
|---|---|---|
| 数据准备 | 加载+预处理 | 1. 加载sklearn内置的Iris数据集; 2. 转换为Pandas DataFrame,命名特征列并添加目标列; 3. 选择“花瓣长度、花瓣宽度”两个核心特征(区分度更高)。 |
| 数据集划分 | 训练/测试集拆分 | 按75%(训练):25%(测试)划分数据,设置random_state=42保证结果可复现。 |
| 基础模型训练 | 决策树训练+评估 | 1. 初始化决策树分类器(max_depth=8,gini准则);2. 训练模型并预测测试集; 3. 计算测试集准确率、输出特征重要性; 4. 导出决策树可视化文件(dot格式)。 |
| 单样本验证 | 自定义样本预测 | 对花瓣长度=5、宽度=1.5的样本,预测分类概率和最终结果。 |
| 超参数探究 | 深度对性能的影响 | 1. 遍历深度1~14,训练不同深度的决策树; 2. 计算每个深度的测试集错误率; 3. 可视化深度与错误率的关系,验证过拟合。 |
| 可视化展示 | 结果可视化 | 设置中文字体,绘制深度-错误率折线图,直观展示规律。 |
2. 代码详细注释版
# 导入必要的库importpandasaspd# 数据处理库,用于结构化数据操作importnumpyasnp# 数值计算库,用于数组/矩阵操作fromsklearn.datasetsimportload_iris# 加载sklearn内置的鸢尾花数据集fromsklearn.treeimportDecisionTreeClassifier# 决策树分类器fromsklearn.treeimportexport_graphviz# 导出决策树为dot格式(可视化用)fromsklearn.treeimportDecisionTreeRegressor# 决策树回归器(本项目未使用,保留注释)fromsklearn.model_selectionimporttrain_test_split# 划分训练集/测试集fromsklearn.metricsimportaccuracy_score# 计算分类准确率importmatplotlib.pyplotasplt# 绘图库importmatplotlibasmpl# 绘图配置库# ===================== 步骤1:加载并预处理鸢尾花数据集 =====================# 加载鸢尾花数据集iris=load_iris()# 将特征数据转换为DataFrame(方便查看和处理)data=pd.DataFrame(iris.data)# 为特征列命名(对应数据集的4个特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度)data.columns=iris.feature_names# 添加目标列(鸢尾花种类,0=setosa,1=versicolor,2=virginica)data['Species']=load_iris().target# 打印数据集前几行(默认5行),查看数据结构print(data)# 特征选择:仅选取花瓣长度(第3列)和花瓣宽度(第4列)作为输入特征(iloc[:,2:4]表示行全选,列选2-3索引)x=data.iloc[:,2:4]# 目标变量:选取最后一列(Species)作为分类目标y=data.iloc[:,-1]# ===================== 步骤2:划分训练集和测试集 =====================# train_size=0.75:训练集占75%,测试集25%;random_state=42:固定随机种子,保证结果可复现x_train,x_test,y_train,y_test=train_test_split(x,y,train_size=0.75,random_state=42)# ===================== 步骤3:训练基础决策树模型并评估 =====================# 初始化决策树分类器:max_depth=8(树最大深度),criterion='gini'(基尼系数作为分裂准则)tree_clf=DecisionTreeClassifier(max_depth=8,criterion='gini')# 用训练集数据训练模型tree_clf.fit(x_train,y_train)# 用训练好的模型预测测试集y_test_hat=tree_clf.predict(x_test)# 计算并打印测试集准确率print("acc score:",accuracy_score(y_test,y_test_hat))# 打印特征重要性:数值越大表示该特征对分类的贡献越大print("特征重要性(花瓣长度/花瓣宽度):",tree_clf.feature_importances_)# 导出决策树为dot格式文件(可通过dot命令转换为PNG图片查看树结构)export_graphviz(tree_clf,# 训练好的决策树模型out_file="./iris_tree.dot",# 输出文件路径feature_names=iris.feature_names[2:4],# 特征名(仅花瓣长度/宽度)class_names=iris.target_names,# 类别名(setosa/versicolor/virginica)rounded=True,# 节点边框圆角filled=True# 节点填充颜色)# 备注:转换命令(需安装graphviz):./dot -Tpng ~/PycharmProjects/mlstudy/bjsxt/iris_tree.dot -o ~/PycharmProjects/mlstudy/bjsxt/iris_tree.png# ===================== 步骤4:单样本预测 =====================# 预测花瓣长度=5,宽度=1.5的样本属于各类的概率print("单样本分类概率:",tree_clf.predict_proba([[5,1.5]]))# 预测该样本的最终分类结果(输出类别索引)print("单样本分类结果:",tree_clf.predict([[5,1.5]]))# ===================== 步骤5:探究决策树深度对错误率的影响 =====================# 生成深度范围:1到14(包含14)depth=np.arange(1,15)# 存储每个深度对应的错误率err_list=[]# 遍历每个深度,训练模型并计算错误率fordindepth:print(f"当前训练的决策树深度:{d}")# 初始化对应深度的决策树分类器(基尼系数准则)clf=DecisionTreeClassifier(criterion='gini',max_depth=d)# 训练模型clf.fit(x_train,y_train)# 预测测试集y_test_hat=clf.predict(x_test)# 计算预测正确的样本(True/False数组)result=(y_test_hat==y_test)# 仅当深度=1时打印预测正确与否的结果(用于调试)ifd==1:print(f"深度=1时的预测正确结果:{result}")# 计算错误率:1 - 正确样本的均值err=1-np.mean(result)# 打印错误率(百分比)print(f"深度={d}时的错误率(百分比):{100*err:.2f}%")# 将错误率加入列表err_list.append(err)# ===================== 步骤6:可视化深度与错误率的关系 =====================# 设置matplotlib的中文字体(SimHei=黑体,避免中文乱码)mpl.rcParams['font.sans-serif']=['SimHei']# 设置图片背景色为白色plt.figure(facecolor='w')# 绘制折线图:红色圆点+实线,线宽=2plt.plot(depth,err_list,'ro-',lw=2)# 设置x轴标签plt.xlabel('决策树深度',fontsize=15)# 设置y轴标签plt.ylabel('错误率',fontsize=15)# 设置标题plt.title('决策树深度和过拟合',fontsize=18)# 显示网格线plt.grid(True)# 展示图片plt.show()# 决策树回归器(本项目未使用,保留代码注释)# tree_reg = DecisionTreeRegressor(max_depth=2)# tree_reg.fit(X, y)3. 代码简洁版(核心逻辑,精简注释/打印)
importpandasaspdimportnumpyasnpfromsklearn.datasetsimportload_irisfromsklearn.treeimportDecisionTreeClassifier,export_graphvizfromsklearn.model_selectionimporttrain_test_splitfromsklearn.metricsimportaccuracy_scoreimportmatplotlib.pyplotaspltimportmatplotlibasmpl# 数据加载与预处理iris=load_iris()data=pd.DataFrame(iris.data,columns=iris.feature_names)data['Species']=iris.target x=data.iloc[:,2:4]# 花瓣长度/宽度y=data.iloc[:,-1]# 划分数据集x_train,x_test,y_train,y_test=train_test_split(x,y,train_size=0.75,random_state=42)# 基础模型训练tree_clf=DecisionTreeClassifier(max_depth=8,criterion='gini')tree_clf.fit(x_train,y_train)print("准确率:",accuracy_score(y_test,tree_clf.predict(x_test)))# 导出决策树可视化文件export_graphviz(tree_clf,out_file="./iris_tree.dot",feature_names=iris.feature_names[2:4],class_names=iris.target_names,rounded=True,filled=True)# 单样本预测print("单样本概率:",tree_clf.predict_proba([[5,1.5]]))print("单样本结果:",tree_clf.predict([[5,1.5]]))# 探究深度对错误率的影响depth=np.arange(1,15)err_list=[]fordindepth:clf=DecisionTreeClassifier(criterion='gini',max_depth=d)clf.fit(x_train,y_train)err=1-np.mean(clf.predict(x_test)==y_test)err_list.append(err)# 可视化mpl.rcParams['font.sans-serif']=['SimHei']plt.figure(facecolor='w')plt.plot(depth,err_list,'ro-',lw=2)plt.xlabel('决策树深度',fontsize=15)plt.ylabel('错误率',fontsize=15)plt.title('决策树深度和过拟合',fontsize=18)plt.grid(True)plt.show()