所有文章 > 日积月累 > 一文彻底搞懂机器学习 - 决策树(Decision Tree)

一文彻底搞懂机器学习 - 决策树(Decision Tree)

决策树(Decision Tree)是一种基础且直观的分类与回归技术,它借鉴了人类的决策过程。该技术通过逐步拆分数据集,并根据特征选择建立规则,从而实现对数据的分类和预测。以下是对决策树的深入解析,内容涵盖其训练与可视化方法、类概率的估计,以及CART(分类与回归树)训练算法等关键方面。

一、训练和可视化

决策树(Decision Tree)是什么?决策树是一种用于分类和回归任务的监督学习算法。它通过一个树状结构,将输入的数据(特征)映射到输出(类别或数值)。

在决策树中,每个内部节点表示一个特征上的判断条件,每个分支代表一个判断条件的输出,而每个叶子节点则代表一个类别(对于分类任务)或一个具体的数值(对于回归任务)。

如何训练和可视化决策树?

使用scikit-learn库的DecisionTreeClassifier训练决策树模型,并利用plot_tree函数结合matplotlib库实现决策树的可视化。

  1. 训练决策树:通过DecisionTreeClassifier类创建一个决策树分类器实例,并使用fit方法将训练数据(特征集X_train和目标标签y_train)输入到模型中,从而训练出决策树。
  2. 可视化决策树:利用plot_tree函数,将训练好的决策树模型(clf)进行可视化。这个函数接受多个参数,如是否填充节点颜色(filled=True)、特征名称(feature_names=iris.feature_names)、类别名称(class_names=iris.target_names)以及节点形状(rounded=True)等,以生成一个直观易懂的决策树图形。最后,通过matplotlib库的show方法显示这个图形。
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树分类器
clf = DecisionTreeClassifier(random_state=42)

# 训练模型
clf.fit(X_train, y_train)

# 可视化决策树
plt.figure(figsize=(20,10)) # 设置图形大小
plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True)
plt.show() # 显示图形

二、估计类概率

决策树如何获取样本属于每个类别的预测概率?决策树通过训练过程形成叶子节点,并记录每个叶子节点中的类别分布。在预测时,根据样本到达的叶子节点来计算它属于每个类别的预测概率。

Python中,使用scikit-learn库可以很方便地实现决策树,并获取样本属于每个类别的预测概率

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树分类器
clf = DecisionTreeClassifier(random_state=42)

# 训练模型
clf.fit(X_train, y_train)

# 获取测试集中样本的类概率
probabilities = clf.predict_proba(X_test)

# 打印前5个样本的类概率
for i in range(5):
print(f"Sample {i+1} class probabilities: {probabilities[i]}")

# 注意:probabilities[i] 是一个数组,包含了样本i属于每个类别的概率。
# 例如,如果probabilities[0] = [0.1, 0.7, 0.2],则表示第一个样本属于类别0的概率为0.1,属于类别1的概率为0.7,属于类别2的概率为0.2。

三、SVM回归

CART训练算法是什么?CART(Classification and Regression Trees)训练算法是一种用于分类和回归任务的决策树学习技术。CART算法采用贪心策略递归地划分数据集,以构建一棵二叉决策树。

它既可以用于分类任务,也可以用于回归任务。在分类任务中,CART算法通常使用基尼指数(Gini Index)作为划分标准;在回归任务中,则使用均方误差(MSE)作为划分标准。

CART训练算法用于分类任务?使用Iris数据集训练一个最大深度为3的CART决策树分类器,并通过plot_tree函数结合matplotlib库可视化了训练好的决策树。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# 加载Iris数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 初始化并训练CART分类器
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# 使用plot_tree函数可视化决策树
plt.figure(figsize=(20,10)) # 设置图形大小
plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True)
plt.show()

CART训练算法用于回归任务?使用Iris数据集(除去目标变量作为特征),通过划分训练集和测试集,初始化并训练了一个最大深度为3的CART回归器,然后可视化了回归树,并最后评估了模型的均方误差(MSE)性能。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor, plot_tree
import matplotlib.pyplot as plt

# 加载Iris数据集
iris = load_iris()
X = iris.data[:, :-1] # 使用除了最后一个特征之外的所有特征作为预测变量
y = iris.data[:, -1] # 使用最后一个特征(花瓣长度)作为目标变量

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 初始化并训练CART回归器
regressor = DecisionTreeRegressor(criterion='squared_error', max_depth=3, random_state=42)
regressor.fit(X_train, y_train)

# 使用plot_tree函数可视化回归树
plt.figure(figsize=(20, 10)) # 设置图形大小
plot_tree(regressor, filled=True, feature_names=iris.feature_names[:-1], # 排除目标变量作为特征名
rounded=True, impurity=False) # impurity为False因为回归树不使用不纯度
plt.show()

# 评估模型
# 由于这是回归任务,我们可以使用均方误差(MSE)等指标来评估模型的性能。
from sklearn.metrics import mean_squared_error
y_pred = regressor.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

本文章转载微信公众号@架构师带你玩转AI