复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性
背景
在机器学习领域,理解各个特征对模型输出的贡献至关重要,尤其是在像环境科学和生物学这样的重要领域中,SHAP是一种强大的解释工具,能够帮助直观地展示特征对模型预测结果的影响,一项研究《基于可解释机器学习模型的浮游植物生物量预测及关键影响因素识别》中,研究人员使用了 SHAP 依赖图来可视化环境因素如何影响模型预测
本文将通过医学数据,使用 Python 演示如何复现 SHAP 依赖图,并详细解释连续性特征对模型预测结果的影响
什么是 SHAP 依赖图?
SHAP 依赖图用于可视化单个特征对机器学习模型预测结果的影响,具体来说,x 轴是特征值,y 轴是 SHAP 值(度量特征对预测结果的重要性),这些图可以直观地显示出某个特征是对模型预测起正向还是负向作用
代码实现
数据集加载
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('Dataset.csv')
# 划分特征和目标变量
X = df.drop(['target'], axis=1)
y = df['target']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['target'])
df.head()
首先,需要加载数据集并将其划分为特征 X 和目标变量 y,然后进行训练集和测试集的划分。目标变量是我们要预测的值,X 是输入的特征,这是一个分类任务,目标是预测患者是否患有心脏病。虽然是分类任务,但无论是分类问题还是回归问题,SHAP 依赖图的使用方式和原理是相同的,都可以用来解释模型中各个特征对预测结果的贡献
训练机器学习模型
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
# GBT模型参数
params_gbt = {
'learning_rate': 0.02, # 学习率,控制每一步的步长,用于防止过拟合。典型值范围:0.01 - 0.1
'max_depth': 3, # 树的深度,控制模型复杂度
'random_state': 42, # 随机种子,用于重现模型的结果
'subsample': 0.7, # 每次迭代时随机选择的样本比例,用于增加模型的泛化能力
}
# 初始化Gradient Boosting分类模型
model_gbt = GradientBoostingClassifier(**params_gbt)
# 定义参数网格,用于网格搜索
param_grid = {
'n_estimators': [100, 200, 300], # 树的数量
'max_depth': [3, 4, 5], # 树的深度
'learning_rate': [0.01, 0.1], # 学习率
}
# 使用GridSearchCV进行网格搜索和k折交叉验证
grid_search = GridSearchCV(
estimator=model_gbt,
param_grid=param_grid,
scoring='neg_log_loss', # 评价指标为负对数损失
cv=5, # 5折交叉验证
n_jobs=-1, # 并行计算
verbose=1 # 输出详细进度信息
)
# 训练模型
grid_search.fit(X_train, y_train)
# 使用最优参数训练模型
best_model = grid_search.best_estimator_
这里使用了梯度提升树(GBT),这是一个强大且常用的机器学习算法,通过网格搜索进行参数优化
计算 SHAP 值
import shap
explainer = shap.TreeExplainer(best_model)
# 计算shap值为numpy.array数组
shap_values_numpy = explainer.shap_values(X)
# 计算shap值为Explanation格式
shap_values_Explanation = explainer(X)
模型训练完毕后,可以使用 shap 包来计算 SHAP 值,SHAP 值用于衡量特定特征对模型输出的影响,这里分别通过 explainer.shap_values(X) 计算 SHAP 值为数组格式以便自定义绘制,和通过 explainer(X) 计算为 Explanation 格式,直接使用 SHAP 自带的绘图函数进行可视化
SHAP自带绘图函数实现依赖图
默认参数下绘制
# 绘制 'age' 特征的SHAP依赖图
shap.dependence_plot('age', shap_values_Explanation.values, X, show=False)
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight',dpi=1200)
图展示了 age(年龄) 特征对模型预测结果的 SHAP 值的依赖关系,说明不同年龄段如何影响模型的预测
- X 轴(age):表示年龄的取值范围,从 30 到 75 岁
- Y 轴(SHAP value for age):表示年龄对模型预测的影响。SHAP 值为正时,表示该年龄段增加了模型预测的概率;SHAP 值为负时,表示该年龄段降低了预测的概率
从图中可以看到:
- 年龄在 50 到 60 岁之间 对模型预测结果有显著的正面影响,SHAP 值较高,说明模型在这个年龄段倾向于预测目标事件的发生
- 70 岁左右,SHAP 值开始变为负数,意味着在这个年龄段,模型预测发生的概率降低
- 颜色代表了 thal(地中海贫血类型) 这一交互特征的影响,红色表示更高的值,蓝色表示较低的值,可以看到,thal 的不同取值对 SHAP 值的分布有一定影响,尤其是在 SHAP 值较大的区域,红色点较为集中
展示了年龄对模型预测的非线性影响,同时揭示了另一个特征(thal)如何与年龄共同作用,影响预测结果,然而,与文献中的图表样式相比,仍存在一些细微的差别绘制无颜色条的年龄 SHAP 依赖图
# 绘制 'age' 特征的 SHAP 依赖图,不显示颜色条
shap.dependence_plot('age', shap_values_Explanation.values, X, interaction_index=None, show=False)
# 添加 SHAP=0 的横线
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.savefig("SHAP Dependence Plot_2.pdf", format='pdf',bbox_inches='tight',dpi=1200)
plt.show()
在这里,通过设置 interaction_index=None 可以关闭颜色条,不显示交互特征的影响。不过,该函数目前没有内置参数可以直接在 SHAP 值为 0 的位置添加一条横线。为了实现这一功能,可以利用 matplotlib 的 plt.axhline() 方法,在绘制依赖图后手动添加横线
接下来,还可以通过 explainer.shap_values(X) 格式绘制这个shap依赖图,以便实现自定义绘图
自定义绘图
将 SHAP 值转换为 DataFrame 格式以便于自定义绘图
shap_values_df = pd.DataFrame(shap_values_numpy, columns=X.columns)
shap_values_df.head()
单个shap依赖图绘制
# 绘制散点图,x轴是'age'特征,y轴是SHAP值
plt.figure(figsize=(6, 4),dpi=1200)
plt.scatter(df['age'], shap_values_df['age'], s=10)
# 添加shap=0的横线
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.xlabel('Age', fontsize=12)
plt.ylabel('SHAP value for\nAge', fontsize=12)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.savefig("SHAP Dependence Plot_3.pdf", format='pdf',bbox_inches='tight')
plt.show()
代码生成一个 SHAP 值依赖图,其中展示了特征 age 对模型输出的贡献,同时对图表进行了一些格式上的优化,比如隐藏不必要的边框线条、在 SHAP=0 处添加一条基准线,并最终将图像保存为高分辨率的 PDF 文件。相比于直接使用 shap.dependence_plot() 的默认作图方式,这种方法提供了更高的灵活性,特别是在定制化绘图方面,可以根据不同场景、需求对图表进行高度定制,从而提高可视化的效果和表达的准确性
多个shap依赖图绘制
# 定义绘制 SHAP 依赖图的函数
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
plt.subplots_adjust(hspace=0.4, wspace=0.4)
# 循环绘制每个特征的 SHAP 依赖图
for i, feature in enumerate(feature_list):
row = i // 3 # 行号
col = i % 3 # 列号
ax = axs[row, col]
# 绘制散点图,x轴是特征值,y轴是SHAP值
ax.scatter(df[feature], shap_values_df[feature], s=10)
# 添加shap=0的横线
ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
# 设置x和y轴标签
ax.set_xlabel(feature, fontsize=12)
ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
# 隐藏顶部和右侧的脊柱
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# 隐藏最后一个空图表的坐标轴 (若画布未关闭)
axs[1, 2].axis('off')
plt.savefig(file_name, format='pdf', bbox_inches='tight')
plt.show()
# 使用函数绘制age、trestbps、chol、thalach、oldpeak的shap依赖图
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
plot_shap_dependence(feature_list, df, shap_values_df)
这段代码定义一个函数 plot_shap_dependence,用于绘制给定特征列表的 SHAP 依赖图,生成 2 行 3 列的图表布局,并在 SHAP=0 处添加基准线,最后保存为高分辨率 PDF,该图的样式基本上与文献中的 SHAP 依赖图形式一致,包括散点图、SHAP 值为 0 的基准线、去掉顶部和右侧脊柱的简洁图形设计等
本文章转载微信公众号@Python机器学习AI