SHAP进阶解析:机器学习、深度学习模型解释保姆级教程
背景
本篇文章将聚焦SHAP的高级功能与应用技巧,在这个项目中,作者将以一个基于XGBoost的二分类模型为例,展示如何通过SHAP深入剖析模型的内部机制,并结合实际问题,使用SHAP值进行更精细的解释和分析,进而推广到回归预测模型、多分类模型
具体内容包括
- 模型构建与参数优化:将通过网格搜索对XGBoost模型进行超参数调优,并结合SHAP来解释模型预测的特征贡献度
- 特征重要性排序与可视化:通过SHAP条形图、瀑布图等可视化手段,展示特征的重要性,并深入分析特定样本的特征贡献情况
- 分组分析与对比:使用SHAP的分组功能,按性别(分类数据)、年龄(连续性数据)等特征对模型的解释结果进行分类讨论,挖掘不同群体的模型行为差异
- 错误分类样本的深入解释:分析模型在错误分类样本上的表现,利用SHAP决策图可视化错分类的原因,帮助理解模型的局限性
- 群集分析与特征聚类:借助SHAP值和层次聚类技术,探索特征之间的关联性,并展示不同特征对模型决策的影响
- 实例特征解释与SHAP值决策图:通过SHAP值的逐样本分析,展示特定样本如何在特定特征的作用下影响模型的预测结果
通过这篇进阶教程,你将不仅能掌握更复杂场景下的模型解释方法,还能学会使用SHAP工具进行更深层次的模型调试和性能提升,如果你已经熟悉SHAP的基础知识,这篇文章将帮助你更进一步,提升对模型决策过程的洞察力,并为实际项目中的模型优化提供有力支持
代码实现
数据读取与分割
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('Dataset.csv')
# 划分特征和目标变量
X = df.drop(['target'], axis=1)
y = df['target']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['target'])
df.head()
导入数据并预处理:读取数据集,提取特征和目标变量、划分训练集和测试集:将数据集拆分为训练集和测试集,便于后续模型的训练和评估、该数据集中使用了14个特征变量,这些特征包括了患者的基本信息和多项医疗检测指标,以下是每个特征的详细说明:
age(年龄):患者的年龄,单位为年
sex(性别):患者性别,1代表男性,0代表女性
cp(胸痛类型):患者的胸痛类型,有4种可能取值:1:典型心绞痛、2:非典型心绞痛、3:非心绞痛、4:无症状
trestbps(静息血压):患者入院时的静息血压,单位为mm Hg
chol(胆固醇):血清胆固醇值,单位为mg/dl
fbs(空腹血糖):空腹血糖值是否大于120 mg/dl,1为真,0为假
restecg(静息心电图结果):静息心电图的结果,有3个可能取值:0:正常、1:存在ST-T波异常(T波反转或ST段升高或降低超过0.05 mV)、2:显示可能或确定的左心室肥大
thalach(最大心率):运动测试中达到的最大心率
exang(运动诱发型心绞痛):是否在运动时诱发心绞痛,1为是,0为否
oldpeak(运动引发的ST段抑制值):与静息时相比的运动引发的ST段抑制
slope(ST段峰值的斜率):运动高峰时ST段的斜率,有3种取值:1:上升、2:平坦、3:下降
ca(主要血管的数量):使用荧光镜检查染色的主要血管数量,取值为0到3
thal(地中海贫血症):血液中的地中海贫血类型,有3个可能取值:0:正常、1:固定缺陷、2:可逆缺陷
num(诊断结果):目标变量,表示是否诊断出心脏病,0为小于50%直径缩小(无心脏病),1为大于50%直径缩小(有心脏病)
模型构建
import xgboost as xgb
from sklearn.model_selection import GridSearchCV
# XGBoost模型参数
params_xgb = {
'learning_rate': 0.02, # 学习率,控制每一步的步长,用于防止过拟合。典型值范围:0.01 - 0.1
'booster': 'gbtree', # 提升方法,这里使用梯度提升树(Gradient Boosting Tree)
'objective': 'binary:logistic', # 损失函数,这里使用逻辑回归,用于二分类任务
'max_leaves': 127, # 每棵树的叶子节点数量,控制模型复杂度。较大值可以提高模型复杂度但可能导致过拟合
'verbosity': 1, # 控制 XGBoost 输出信息的详细程度,0表示无输出,1表示输出进度信息
'seed': 42, # 随机种子,用于重现模型的结果
'nthread': -1, # 并行运算的线程数量,-1表示使用所有可用的CPU核心
'colsample_bytree': 0.6, # 每棵树随机选择的特征比例,用于增加模型的泛化能力
'subsample': 0.7, # 每次迭代时随机选择的样本比例,用于增加模型的泛化能力
'eval_metric': 'logloss' # 评价指标,这里使用对数损失(logloss)
}
# 初始化XGBoost分类模型
model_xgb = xgb.XGBClassifier(**params_xgb)
# 定义参数网格,用于网格搜索
param_grid = {
'n_estimators': [100, 200, 300, 400, 500], # 树的数量
'max_depth': [3, 4, 5, 6, 7], # 树的深度
'learning_rate': [0.01, 0.02, 0.05, 0.1], # 学习率
}
# 使用GridSearchCV进行网格搜索和k折交叉验证
grid_search = GridSearchCV(
estimator=model_xgb,
param_grid=param_grid,
scoring='neg_log_loss', # 评价指标为负对数损失
cv=5, # 5折交叉验证
n_jobs=-1, # 并行计算
verbose=1 # 输出详细进度信息
)
# 训练模型
grid_search.fit(X_train, y_train)
# 输出最优参数
print("Best parameters found: ", grid_search.best_params_)
print("Best Log Loss score: ", -grid_search.best_score_)
# 使用最优参数训练模型
best_model = grid_search.best_estimator_
使用XGBoost分类器构建一个二分类模型,并通过网格搜索和5折交叉验证(GridSearchCV)来优化模型参数,首先,定义了XGBoost模型的初始参数(如学习率、提升方法、树的叶子节点数等)和参数搜索网格(如树的数量、深度、学习率等),然后,通过网格搜索在不同的参数组合下训练模型,目标是最小化负对数损失(logloss)作为评估指标,最后,输出最优的参数组合,并使用该参数重新训练模型
SHAP值计算
import shap
explainer = shap.TreeExplainer(best_model)
# 计算shap值为numpy.array数组
shap_values_numpy = explainer.shap_values(X)
# 计算shap值为Explanation格式
shap_values_Explanation = explainer(X)
使用shap.TreeExplainer创建一个解释器对象explainer,它用于解释基于树模型(如XGBoost、决策树等)的预测,分别计算SHAP值并返回为numpy数组格式和更复杂的Explanation格式,后者适合更高级的分析和图形展示
XGBoost模型特征重要性可视化
# 获取XGBoost模型的特征贡献度(重要性)
feature_importances = best_model.feature_importances_
# 将特征和其重要性一起排序
sorted_indices = np.argsort(feature_importances)[::-1] # 逆序排列,重要性从高到低
sorted_features = X_train.columns[sorted_indices]
sorted_importances = feature_importances[sorted_indices]
# 绘制按重要性排序的特征贡献性柱状图
plt.figure(figsize=(10, 6), dpi=1200)
plt.barh(sorted_features, sorted_importances, color='steelblue')
plt.xlabel('Importance', fontsize=14)
plt.ylabel('Features', fontsize=14)
plt.title('Sorted Feature Importance', fontsize=16)
plt.gca().invert_yaxis()
plt.savefig("Sorted Feature Importance.pdf", format='pdf',bbox_inches='tight')
# 显示图表
plt.show()
通过提取XGBoost模型的内置特征重要性,生成并展示了基于模型自身计算的特征排名,帮助理解每个特征对模型预测结果的贡献;与后续基于SHAP值的特征排名不同,它仅反映模型的分裂结构
基于SHAP numpy数组格式的特征重要性总结图
# 绘制SHAP值总结图(Summary Plot)
plt.figure(figsize=(10, 5), dpi=1200)
shap.summary_plot(shap_values_numpy, X, plot_type="bar", show=False)
plt.title('SHAP_numpy Sorted Feature Importance')
plt.savefig("SHAP_numpy Sorted Feature Importance.pdf", format='pdf',bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_numpy(即基于SHAP值的numpy数组格式)绘制特征重要性总结图,展示各个特征在整体模型预测中的贡献,细心的读者可发现两种方法的排名不一致,这是因为XGBoost基于树分裂计算特征重要性,而SHAP值根据每个特征对个体预测的影响进行评估,因此两者方法不同导致排名差异
基于SHAP Explanation格式的特征重要性条形图
plt.figure(figsize=(10, 5), dpi=1200)
shap.plots.bar(shap_values_Explanation, show=False)
plt.title('SHAP_Explanation Sorted Feature Importance')
plt.savefig("SHAP_Explanation Sorted Feature Importance.pdf", format='pdf',bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_Explanation(即SHAP值的Explanation格式)绘制特征重要性条形图,展示每个特征对模型预测的贡献大小,并保存为PDF文件,该格式提供了更丰富的特征信息,便于更深入的分析和可视化,
这张图展示了部分特征的SHAP值重要性,但并未完全展示所有特征,较不重要的特征已被合并为“Sum of 4 other features”,如果希望展示更多特征,可以通过设置max_display参数,调整展示的特征数量,以便观察更完整的特征贡献排名
# 设置 max_display 值
max_display = 13
plt.figure(figsize=(10, 5), dpi=1200)
# 创建 SHAP 值条形图,并使用 max_display 参数限制最多显示的特征数量
shap.plots.bar(shap_values_Explanation, max_display=max_display, show=False)
plt.title(f'SHAP Explanation Sorted Feature Importance (Top {max_display})')
plt.savefig(f"SHAP_Explanation_Sorted_Feature_Importance_Top_{max_display}.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
基于SHAP Explanation的单个样本特征重要性与实际数据可视化
plt.figure(figsize=(10, 5), dpi=1200)
# 创建 SHAP 值条形图,展示数据
shap.plots.bar(shap_values_Explanation[1], show_data=True, show=False, max_display=13)
plt.title('SHAP Explanation for Instance (Feature Importance with Data)')
plt.savefig("SHAP_Explanation_Instance_Feature_Importance_with_Data.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_Explanation[1](即第一个样本的SHAP值)绘制特征重要性条形图,并通过show_data=True显示每个特征的具体数值,最后将图表保存为PDF文件,该图展示了某个具体样本的各个特征如何正负向影响模型的预测结果,红色表示正贡献,蓝色表示负贡献
基于SHAP Explanation的单个样本瀑布图可视化
plt.figure(figsize=(10, 5), dpi=1200)
# 绘制第1个样本的 SHAP 瀑布图,并设置 show=False 以避免直接显示
shap.plots.waterfall(shap_values_Explanation[1], show=False, max_display=13)
# 保存图像为 PDF 文件
plt.savefig("SHAP_Waterfall_Plot_Sample_1.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_Explanation[1](即第一个样本的SHAP值)绘制了该样本的瀑布图,通过waterfall函数展示各个特征如何逐步影响最终的模型预测值,设置max_display=13限制最多展示13个特征,该瀑布图其中红色表示正向贡献,蓝色表示负向影响,最终累积影响得出模型的预测值为1.543,其中E[f(X)] = -0.172是模型的基准值,表示在没有特征信息的情况下,模型的平均预测输出,它代表了模型对所有样本的整体预测倾向
基于SHAP Explanation值的按性别分组特征重要性可视化
# 'sex' 列包含性别信息,0 代表女性,1 代表男性
sex = ["Women" if df.loc[i, "sex"] == 0 else "Men" for i in range(df.shape[0])]
plt.figure(figsize=(10, 5), dpi=1200)
# 使用 SHAP 的 cohorts 方法根据 sex 进行分组并绘制条形图,限制显示最多13个特征
shap.plots.bar(shap_values_Explanation.cohorts(sex).abs.mean(0), max_display=13, show=False)
plt.title('SHAP Explanation Sorted Feature Importance by Sex')
plt.savefig("SHAP_Explanation_Sorted_Feature_Importance_by_Sex.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_Explanation按sex(性别,原始特征为分类数据)特征分组,通过SHAP的cohorts方法计算并绘制各个特征对不同性别群体模型预测的平均贡献度,并限制显示13个最重要的特征,结果展示了按性别(男性和女性)分组后,特征对模型预测的重要性。无论是男性还是女性,cp(胸痛类型)、thal(地中海贫血类型)和ca(主要血管数量)都是最重要的特征,且贡献度相近。然而,某些特征的贡献在性别之间存在差异,如oldpeak在男性中的贡献更显著,这表明模型在预测不同性别群体时,部分特征的影响力存在明显差异,右下角数值分别代表各类别的样本量
基于SHAP Explanation值的按自动分组年龄的特征重要性可视化
# 例如,将年龄划分为三类:青年(<30岁),中年(30-60岁),老年(>60岁)
age_groups = ["Young" if df.loc[i, "age"] < 30 else "Middle-aged" if df.loc[i, "age"] <= 60 else "Senior" for i in range(df.shape[0])]
# 使用 SHAP 的 cohorts 方法根据 age_groups 进行分组
v = shap_values_Explanation.cohorts(age_groups).abs.mean(0)
plt.figure(figsize=(10, 5), dpi=1200)
# 绘制 SHAP 条形图
shap.plots.bar(v, show=False, max_display=13)
plt.title('SHAP Explanation Sorted Feature Importance by Automatically Grouped Age')
plt.savefig("SHAP_Explanation_Sorted_Feature_Importance_by_Grouped_Age.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用了shap_values_Explanation,根据年龄(age)这一连续性特征将样本分为青年、中年和老年三类,利用SHAP的cohorts方法计算各组的平均特征贡献度,并绘制条形图。该图展示了在不同年龄段中,特征对模型预测的平均重要性,显示了特征影响随年龄变化的差异,结果显示特征thal和cp在各年龄段的贡献度较为接近,但chol(胆固醇)在老年群体中的影响显著高于中年和青年,说明不同年龄段特征对模型预测的影响存在差异
基于SHAP Explanation值的特征聚类可视化(cutoff=0.5)
# 计算 clustering 结果和 shap_values_Explanation
clustering = shap.utils.hclust(X, y)
plt.figure(figsize=(10, 5), dpi=1200)
shap.plots.bar(shap_values_Explanation,
clustering=clustering,
clustering_cutoff=0.5,
show=False, max_display=13)
plt.title('SHAP Explanation with Clustering (cutoff=0.5)')
plt.savefig("SHAP_Explanation_with_Clustering.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_Explanation,结合层次聚类方法(hclust),根据特征对模型的贡献进行聚类,设置了聚类截断值(clustering_cutoff=0.5),以展示相关特征的分组情况,通过条形图展示了前13个特征的平均SHAP值贡献,显示了模型中每个特征的相对重要性,聚类可视化的作用是通过层次聚类分析特征之间的相关性,帮助识别哪些特征在对模型预测产生类似影响,通过设置聚类截断值(cutoff=0.5),我们可以将影响力相似的特征分组,从而更好地理解模型的决策结构以及特征之间的相互作用,从当前的可视化结果来看,没有明确显示出类似的特征被分组在一起
该图为生成模拟数据基于SHAP值的特征重要性,结合层次聚类(clustering_cutoff=0.5)对相似特征进行了分组,图中的feature_0和feature_3被划分在同一组(有一个灰色的聚类括号),这表明这两个特征在模型中对预测的影响非常相似,属于相关性较强的特征,相比之前未能显示明显相关特征的图,这个图表通过聚类更好地揭示了特征之间的相似性,有助于更清晰地理解哪些特征在模型决策中发挥了类似的作用,当特征被聚类在同一组时,说明它们对模型的预测有类似的贡献,可能在数据上表现出较高的相关性或冗余性,这意味着这些特征可能描述了相似的现象,或者它们在模型中提供的独立信息量较低。在模型优化中,发现这些相关特征可以帮助我们减少冗余特征,简化模型,或者进一步探索这些特征之间的关系
基于SHAP Explanation值的特征散点图可视化
# 指定特征的名称为 'cp'(胸痛类型)
feature_name = 'cp'
# 找到指定特征的索引
feature_index = shap_values_Explanation.feature_names.index(feature_name)
plt.figure(figsize=(10, 5), dpi=1200)
# 使用 SHAP 的 scatter 方法绘制指定特征的散点图
shap.plots.scatter(shap_values_Explanation[:, feature_index], show=False)
plt.title(f'SHAP Scatter Plot for Feature: {feature_name}')
plt.savefig(f"SHAP_Scatter_Plot_{feature_name}.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_Explanation,绘制特定特征cp(胸痛类型)的SHAP散点图,通过shap.plots.scatter展示该特征对模型预测的贡献值,并保存为PDF文件,图表横轴显示特征cp的不同取值,纵轴显示该特征的SHAP值(即它对模型输出的影响),可以发现cp的取值越高,其对应的SHAP值越大,表明较高的cp值(如4)对模型的正向预测影响更大,而较低的cp值(如1)则对模型的负向预测影响更显著
基于SHAP Explanation值的力图汇总可视化
# 初始化 JS 库
shap.initjs()
# 使用 force_plot 方法可视化所有样本的解释
shap.force_plot(explainer.expected_value, shap_values_Explanation.values, X)
首先初始化SHAP的JavaScript库 (shap.initjs()),然后使用force_plot方法可视化所有样本的解释结果,通过传入模型的期望输出值(explainer.expected_value)以及所有样本的shap_values_Explanation.values,展示整体模型的预测解释力图(force plot),并为每个样本的特征贡献进行汇总展示,生成的力图是交互式的,用户可以通过点击和悬停在图中的元素上,查看每个特征对模型预测的正向或负向影响,从而更直观地理解模型的决策过程
基于SHAP Explanation值的单个样本决策图可视化
# 获取模型的期望输出值(平均预测值)
expected_value = explainer.expected_value
# 选择第1个样本的 SHAP 值
shap_values = shap_values_numpy[1]
# 决策图的特征名展示
features_display = X
# 绘制 SHAP 决策图
plt.figure(figsize=(10, 5), dpi=1200)
shap.decision_plot(expected_value, shap_values, features_display, show=False)
# 保存图像为 PDF
plt.savefig("shap_decision_plot_samples.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_numpy[1](即第1个样本的SHAP值),并结合模型的期望输出值(explainer.expected_value)生成了一个SHAP决策图,通过shap.decision_plot展示各个特征对模型预测结果的贡献,并保存为PDF文件,决策图显示了各个特征如何一步步影响模型的最终输出值,红线向右偏表示正向影响,向左偏表示负向影响,帮助理解模型如何基于各特征做出决策
基于SHAP Explanation值的错分类样本决策图可视化
# 获取模型的期望输出值(平均预测值)
expected_value = explainer.expected_value
# 计算预测值:根据 SHAP 值求和加上期望值,再将结果通过阈值判断为分类输出
y_pred = (shap_values_numpy.sum(1) + expected_value) > 0
# 计算错分类的样本
misclassified = y_pred != y
# 决策图展示的特征
features_display = X
# 绘制 SHAP 决策图,带有错分类的样本高亮
plt.figure(figsize=(10, 5), dpi=1200) # 设置画布大小和分辨率
shap.decision_plot(expected_value, shap_values_numpy, features_display,
link='logit', highlight=misclassified, show=False)
# 保存图像为 PDF 文件
plt.savefig("shap_decision_plot_with_misclassified.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_numpy计算模型预测的SHAP值,并结合模型的期望输出值(expected_value)生成了一个决策图。通过shap.decision_plot展示所有样本的特征贡献,并高亮显示了错分类的样本,最终生成的图表展示了分类模型中哪些样本被错误预测,并保存为PDF文件,决策图展示了各特征对模型预测的影响轨迹,蓝色和红色线条分别代表对模型输出的负向和正向贡献,错分类的样本被高亮显示,帮助分析哪些特征对这些错误预测影响最大,从而可以进一步优化模型或调整特征
基于SHAP Explanation值的仅错分类样本决策图可视化
# 获取错分类样本的索引
misclassified_indices = misclassified[misclassified].index
# 过滤出错分类样本的 SHAP 值
shap_values_misclassified = shap_values_numpy[misclassified_indices]
# 保留原始特征名称,过滤出错分类样本的特征
features_display_misclassified = X.loc[misclassified_indices]
# 绘制 SHAP 决策图,显示错分类的样本
plt.figure(figsize=(10, 5), dpi=1200) # 设置画布大小和分辨率
shap.decision_plot(expected_value,
shap_values_misclassified,
link='logit',
highlight=True, # 高亮显示这些样本
feature_names=list(X.columns), # 转换为列表以避免 TypeError
show=False)
plt.savefig("shap_decision_plot_misclassified_only.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()
使用shap_values_numpy,通过提取错分类样本的SHAP值,并结合模型的期望输出值(expected_value),生成了错分类样本的决策图。通过shap.decision_plot高亮显示仅错分类的样本,并展示了每个特征对这些样本模型预测的影响,帮助我们深入分析哪些特征导致了错误预测,并为后续模型调整或特征工程提供参考
- 背景
- 代码实现
- 数据读取与分割
- 模型构建
- SHAP值计算
- XGBoost模型特征重要性可视化
- 基于SHAP numpy数组格式的特征重要性总结图
- 基于SHAP Explanation格式的特征重要性条形图
- 基于SHAP Explanation的单个样本特征重要性与实际数据可视化
- 基于SHAP Explanation的单个样本瀑布图可视化
- 基于SHAP Explanation值的按性别分组特征重要性可视化
- 基于SHAP Explanation值的按自动分组年龄的特征重要性可视化
- 基于SHAP Explanation值的特征聚类可视化(cutoff=0.5)
- 基于SHAP Explanation值的特征散点图可视化
- 基于SHAP Explanation值的力图汇总可视化
- 基于SHAP Explanation值的单个样本决策图可视化
- 基于SHAP Explanation值的错分类样本决策图可视化
- 基于SHAP Explanation值的仅错分类样本决策图可视化