所有文章 > AI驱动 > 利用XGBoost模型进行多分类任务下的SHAP解释附代码讲解及GUI展示

利用XGBoost模型进行多分类任务下的SHAP解释附代码讲解及GUI展示

目标

在这篇文章中,我们将介绍如何利用XGBoost模型进行多分类任务,并使用SHAP对模型进行解释,并生成SHAP解释图、依赖图、力图和热图,从而直观地理解模型的决策过程和特征的重要性

二分类模型和多分类模型在SHAP上的差异

二分类模型

在二分类任务中,模型的目标是将数据划分为两个类别(例如,0和1),SHAP值用于解释每个特征对模型输出的贡献,在二分类模型中,每个样本的SHAP值只有一个,表示该特征对预测结果(通常是正类概率)的贡献

多分类模型

在多分类任务中,模型需要将数据划分为三个或更多类别,每个样本的预测结果不仅包含一个类别,还包括每个类别的概率,SHAP值在多分类任务中的应用需要分别计算每个类别的SHAP值,因此,对于每个样本,SHAP值将是一个矩阵,其中每个元素表示一个特征对某个类别的贡献

代码实现

数据读取处理

from sklearn.datasets import load_iris
import pandas as pd

# Load the iris dataset
iris = load_iris()

# Create a DataFrame from the iris dataset
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['target'] = iris.target

from sklearn.model_selection import train_test_split
X = iris_df.drop(['target'],axis=1)
y = iris_df['target']

X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# 然后将训练集进一步划分为训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.125,stratify=y_temp, random_state=42) # 0.125 x 0.8 = 0.1

加载鸢尾花数据集,将鸢尾花数据集分割为训练集、验证集和测试集,具体过程是:从整个数据集中抽取20%作为测试集;剩余的80%数据中抽取12.5%作为验证集,最终验证集占整个数据集的10%,训练集占整个数据集的70%,这一步骤为后续使用XGBoost进行多分类模型的训练和评估奠定基础

模型建立

import xgboost as xgb

# 更新后的多分类模型参数
params_xgb = {
'learning_rate': 0.02, # 学习率
'booster': 'gbtree', # 提升方法
'objective': 'multi:softprob', # 损失函数,多分类使用softmax
'num_class': 3, # 类别数,鸢尾花数据集有三类
'max_leaves': 127, # 每棵树的叶子节点数量
'verbosity': 1, # 输出信息的详细程度
'seed': 42, # 随机种子
'nthread': -1, # 并行运算的线程数量
'colsample_bytree': 0.6, # 每棵树随机选择的特征比例
'subsample': 0.7, # 每次迭代时随机选择的样本比例
'early_stopping_rounds': 100, # 早停轮数
'eval_metric': 'mlogloss' # 评估指标,多分类使用mlogloss
}

# 创建并训练多分类模型
model_xgb = xgb.XGBClassifier(**params_xgb)
model_xgb.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)

配置并训练一个XGBoost多分类模型来预测鸢尾花数据集的类别,使用特定的参数设置和早停机制,比如这里是三分类,如果需要更改为其他多分类(如四分类),只需修改参数num_class,并相应调整其他参数以达到最优模型效果

模型评价指标输出

评价报告

from sklearn.metrics import classification_report
# 预测测试集
y_pred = model_xgb.predict(X_test)

# 输出模型报告, 查看评价指标
print(classification_report(y_test, y_pred))

混淆矩阵热力图

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# 输出混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)

# 绘制热力图
plt.figure(figsize=(10, 7))
sns.heatmap(conf_matrix, annot=True, annot_kws={'size':15}, fmt='d', cmap='YlGnBu')
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion matrix heat map', fontsize=15)
plt.show()

Shap实现

创建Shap解释器

import shap
# 创建SHAP解释器
explainer = shap.Explainer(model_xgb)
# 计算SHAP值
shap_values = explainer(X_test)
print("shap值维度;",shap_values.shape)
shap_values

可以看见针对测试集的shap值的维度为(30,4,3),也就是计算的每个类别的SHAP值,对于每个样本,SHAP值将是一个矩阵,其中每个元素表示一个特征对某个类别的贡献

绘制Shap解释图

# 特征标签
labels = X_train.columns

# 设置 matplotlib 的全局字体配置
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13

# 提取每个类别的 SHAP 值
shap_values_class_1 = shap_values.values[:, :, 0]
shap_values_class_2 = shap_values.values[:, :, 1]
shap_values_class_3 = shap_values.values[:, :, 2]
shap_values_class_1

# 绘制 SHAP 总结图,使用viridis配色方案
plt.figure()
plt.title('class_1')
shap.summary_plot(shap_values_class_1, X_val, feature_names=labels, plot_type="dot", cmap="viridis")
plt.show()

这里针对鸢尾花的测试集第一个类别0进行shap解释图绘制

绘制Shap依赖图

shap.dependence_plot('sepal length (cm)', shap_values_class_1, X_val, interaction_index='sepal width (cm)')
plt.show()

针对鸢尾花的测试集第一个类别0的特征sepal length (cm)、sepal width (cm)进行shap依赖图绘制

绘制Shap力图

# 选择一个样本索引进行解释
sample_index = 1
expected_value = explainer.expected_value[0] # 需要指定个类别的基准值,这里是第一个类别
# 获取单个样本的 SHAP 值
sample_shap_values = shap_values_class_1[sample_index]

# 绘制 SHAP 解释力图 (Force Plot)
shap.force_plot(expected_value, sample_shap_values, X_val.iloc[sample_index], matplotlib=True)
# 显示绘图
plt.show()

shap力图解释同样会选择数据集以及类别,但是还会多一个应选择的基准值,比如这里选择的第一个类比那基准值也要选择第一个类比的基准值

生成Shap交互作用图

shap_interaction_values = explainer.shap_interaction_values(X_val)
# 提取每个类别的值
shap_interaction_values_class_1 = shap_interaction_values[:, :, :, 0] # 类别1
shap_interaction_values_class_2 = shap_interaction_values[:, :, :, 1] # 类别2
shap_interaction_values_class_3 = shap_interaction_values[:, :, :, 2] # 类别3
# 绘制 SHAP 交互值的总结图
plt.figure()
shap.summary_plot(shap_interaction_values_class_1, X_val, feature_names=labels)
plt.show()

计算的交互值比Shap值维度多一,同理得提取每一个类比的交互值,具体怎么提取参考这个三分类代码

生成Shap热图

expected_value = explainer.expected_value[0]  # 需要指定个类别的基准值,这里是第一个类别 
# 创建 shap.Explanation 对象
shap_explanation = shap.Explanation(values=shap_values_class_1[0:10, :],
base_values= expected_value,
data=X_val.iloc[0:10, :],
feature_names=X_val.columns)

# 绘制热图
plt.figure()
shap.plots.heatmap(shap_explanation)
plt.show()

代码绘制的是第一个类比测试集前10个样本的Shap热图,和力图一样得确定一个基准值,对哪一个类比做就采用哪一个类比的基准值

项目GUI实现

创建一个使用Tkinter库构建的GUI应用程序,旨在通过按钮、标签、组合框和文本框等组件实现数据上传、选择目标特征、设置分类任务的类别数、选择数据集、选择颜色方案、选择特征、输入样本索引、输入样本范围等功能,从而对XGBoost分类模型进行训练并生成相关的解释图,并确保将这些图保存为高DPI的PDF文件,以保证可视化效果不受损失

本文章转载微信公众号@Python机器学习AI