文献复现——优化SHAP依赖图拟合曲线与交点标注的新应用
背景
在这篇文章中,将带读者深入探讨SHAP值解释图的优化与可视化手段,并结合之前的研究及应用——复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性,展示如何通过在图中引入拟合曲线以及标注SHAP值为0时的交点,进一步提升对机器学习模型解释性的理解,本文的灵感主要来源于对《建成环境对街道活力的非线性影响和交互效应》这篇研究的解读与延展
拟合曲线的引入与优化
为了更好地揭示特征与目标变量之间的复杂非线性关系,在SHAP散点图中引入了LOWESS拟合曲线,这条曲线是通过局部加权回归法生成的,能够平滑数据点之间的变化,帮助直观地捕捉数据的趋势走向
SHAP值为0时的交点标注
在对SHAP解释图的进一步优化中,特别添加了SHAP值为0时拟合曲线的交点标注。这一点非常重要,因为SHAP值为0时意味着该特征在该点附近对模型的预测结果没有显著影响。标注这一交点,可以帮助识别出特征值在哪些区间对目标变量的影响是从无到有或从正到负的转变
代码实现
数据集加载
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('Dataset.csv')
# 划分特征和目标变量
X = df.drop(['target'], axis=1)
y = df['target']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['target'])
df.head()
首先,需要加载数据集并将其划分为特征 X 和目标变量 y,然后进行训练集和测试集的划分。目标变量是我们要预测的值,X 是输入的特征,这是一个分类任务,目标是预测患者是否患有心脏病。虽然是分类任务,但无论是分类问题还是回归问题,SHAP 依赖图的使用方式和原理是相同的,都可以用来解释模型中各个特征对预测结果的贡献
训练机器学习模型
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
# GBT模型参数
params_gbt = {
'learning_rate': 0.02, # 学习率,控制每一步的步长,用于防止过拟合。典型值范围:0.01 - 0.1
'max_depth': 3, # 树的深度,控制模型复杂度
'random_state': 42, # 随机种子,用于重现模型的结果
'subsample': 0.7, # 每次迭代时随机选择的样本比例,用于增加模型的泛化能力
}
# 初始化Gradient Boosting分类模型
model_gbt = GradientBoostingClassifier(**params_gbt)
# 定义参数网格,用于网格搜索
param_grid = {
'n_estimators': [100, 200, 300], # 树的数量
'max_depth': [3, 4, 5], # 树的深度
'learning_rate': [0.01, 0.1], # 学习率
}
# 使用GridSearchCV进行网格搜索和k折交叉验证
grid_search = GridSearchCV(
estimator=model_gbt,
param_grid=param_grid,
scoring='neg_log_loss', # 评价指标为负对数损失
cv=5, # 5折交叉验证
n_jobs=-1, # 并行计算
verbose=1 # 输出详细进度信息
)
# 训练模型
grid_search.fit(X_train, y_train)
# 使用最优参数训练模型
best_model = grid_search.best_estimator_
代码通过网格搜索 (GridSearchCV) 对Gradient Boosting分类模型的超参数进行优化,并通过5折交叉验证选择出最优的模型参数
shap值计算整理
import shap
explainer = shap.TreeExplainer(best_model)
# 计算shap值为numpy.array数组
shap_values_numpy = explainer.shap_values(X)
shap_values_df = pd.DataFrame(shap_values_numpy, columns=X.columns)
shap_values_df.head()
计算模型的SHAP值,并将其转换为DataFrame格式,方便后续进行自定义绘图分析
基础绘图
# 绘制散点图,x轴是'age'特征,y轴是SHAP值
plt.figure(figsize=(6, 4),dpi=1200)
plt.scatter(df['age'], shap_values_df['age'], s=10)
# 添加shap=0的横线
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.xlabel('Age', fontsize=12)
plt.ylabel('SHAP value for\nAge', fontsize=12)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight')
plt.show()
绘制了一个基础的SHAP依赖图,其中x轴代表特征“age”(年龄),y轴代表该特征的SHAP值,即年龄对模型预测的影响大小,散点图展示了不同年龄的SHAP值,黑色虚线表示SHAP值为0的基准线,表示在该点年龄对预测没有显著正负影响,此图帮助直观地理解特征“age”对模型预测结果的影响方向和程度,当然更具体的解释参考文章——复现SCI文章 SHAP 依赖图可视化以增强机器学习模型的可解释性
通过拟合曲线与交点标注绘图
import seaborn as sns
from scipy.optimize import fsolve
# 绘制散点图
plt.figure(figsize=(8, 6), dpi=300)
plt.scatter(df['age'], shap_values_df['age'], s=20, label='SHAP values', alpha=0.7)
# 添加LOWESS拟合曲线
sns.regplot(x=df['age'], y=shap_values_df['age'], scatter=False, lowess=True, color='lightcoral', label='LOWESS Curve')
# 使用 LOWESS 数据生成拟合曲线
lowess_data = sns.regplot(x=df['age'], y=shap_values_df['age'], scatter=False, lowess=True, color='lightcoral')
line = lowess_data.get_lines()[0] # 拟合线条对象
x_fit = line.get_xdata() # LOWESS 拟合线的 x 轴数据
y_fit = line.get_ydata() # LOWESS 拟合线的 y 轴数据
# 找出所有与 y=0 相交的 x 值
def find_zero_crossings(x_fit, y_fit):
crossings = []
for i in range(1, len(y_fit)):
if (y_fit[i-1] < 0 and y_fit[i] > 0) or (y_fit[i-1] > 0 and y_fit[i] < 0):
# 使用插值法找到 x_fit 和 y_fit 中 y 值接近 0 的 x 值
crossing = fsolve(lambda x: np.interp(x, x_fit, y_fit), x_fit[i])[0]
crossings.append(crossing)
return crossings
x_intercepts = find_zero_crossings(x_fit, y_fit)
# 在图中标注所有的 x_intercepts
for x_intercept in x_intercepts:
plt.axvline(x=x_intercept, color='blue', linestyle='--', label=f'Intersection at Age = {x_intercept:.2f}')
plt.text(x_intercept, 0.2, f'Age = {x_intercept:.2f}', color='blue', fontsize=10, verticalalignment='bottom')
# 添加shap=0的横线
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1, label='SHAP = 0')
# 添加图例
plt.legend()
# 设置标签和标题
plt.xlabel('Age', fontsize=12)
plt.ylabel('SHAP value for\nAge', fontsize=12)
# 去除上和右的边框线
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# 保存并显示图像
plt.savefig("SHAP Dependence Plot_with_Multiple_Intersections.pdf", format='pdf', bbox_inches='tight')
plt.show()
在这幅图中,通过LOWESS拟合曲线和SHAP解释图来深入分析年龄(Age)对模型预测结果的影响。下面着重解释拟合线与交点的含义:LOWESS拟合曲线LOWESS曲线(红色曲线)是局部加权回归曲线,它用来平滑数据中的非线性趋势。在这幅图中,它表示了年龄对目标变量的平均影响趋势,从曲线中可以看出,随着年龄的变化,SHAP值也随之波动,通过这条拟合曲线,可以识别出不同年龄区间对模型预测的不同贡献
- 在年龄较低时(约40岁以下),SHAP值为负,表示年龄对预测结果的负向影响较为明显
- 之后,随着年龄的增加,SHAP值开始逐渐上升,到了大约53岁附近时,SHAP值变为正,说明该特征对模型开始有正向的影响
- 过了约64岁之后,SHAP值再次呈现下降趋势,表明年龄对预测的正向影响逐渐减弱甚至变为负向
交点
图中用蓝色虚线标注了两个交点,分别表示SHAP值曲线与y=0的交点,这两个交点表示特定年龄时,SHAP值为零,即在这些点上,年龄对模型的影响由负向或正向逐渐转变
- 第一个交点(Age = 53.92):这是年龄从负向影响转变为正向影响的点,当年龄大于53.92时,SHAP值开始为正,意味着年龄对模型的正向贡献逐渐增大
- 第二个交点(Age = 64.35):这是年龄从正向影响转为负向影响的点,当年龄超过64.35时,SHAP值再次变为负,说明此时年龄对模型的预测影响逐渐减弱
通过拟合曲线和交点,可以更直观地理解特征“年龄”对模型预测结果的非线性影响,尤其是这些交点,它们揭示了年龄在特定区间中对目标变量的关键变化点,有助于理解模型如何处理年龄这个特征,以及如何做出更精准的解释
多特征SHAP依赖图:通过拟合曲线与交点分析模型特征贡献
# 定义绘制 SHAP 依赖图的函数
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
plt.subplots_adjust(hspace=0.4, wspace=0.4)
# 循环绘制每个特征的 SHAP 依赖图
for i, feature in enumerate(feature_list):
row = i // 3 # 行号
col = i % 3 # 列号
ax = axs[row, col]
# 绘制散点图,x轴是特征值,y轴是SHAP值
ax.scatter(df[feature], shap_values_df[feature], s=10, alpha=0.7)
# 添加 LOWESS 拟合曲线
sns.regplot(x=df[feature], y=shap_values_df[feature], scatter=False, lowess=True, color='lightcoral', ax=ax)
# 使用 LOWESS 数据生成拟合曲线
lowess_data = sns.regplot(x=df[feature], y=shap_values_df[feature], scatter=False, lowess=True, color='lightcoral', ax=ax)
line = lowess_data.get_lines()[0] # 拟合线条对象
x_fit = line.get_xdata() # LOWESS 拟合线的 x 轴数据
y_fit = line.get_ydata() # LOWESS 拟合线的 y 轴数据
# 找出所有与 y=0 相交的 x 值
def find_zero_crossings(x_fit, y_fit):
crossings = []
for i in range(1, len(y_fit)):
if (y_fit[i-1] < 0 and y_fit[i] > 0) or (y_fit[i-1] > 0 and y_fit[i] < 0):
crossing = fsolve(lambda x: np.interp(x, x_fit, y_fit), x_fit[i])[0]
crossings.append(crossing)
return crossings
x_intercepts = find_zero_crossings(x_fit, y_fit)
# 在图中标注所有的 x_intercepts
for x_intercept in x_intercepts:
ax.axvline(x=x_intercept, color='blue', linestyle='--') # 标注虚线
ax.text(x_intercept, 0.1, f'{x_intercept:.2f}', color='black', fontsize=10, verticalalignment='bottom') # 将文本标注颜色改为淡红色
# 添加shap=0的横线
ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
# 设置x和y轴标签
ax.set_xlabel(feature, fontsize=12)
ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
# 隐藏顶部和右侧的脊柱
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# 隐藏最后一个空图表的坐标轴 (若画布未关闭)
axs[1, 2].axis('off')
# 保存为 PDF 文件
plt.savefig(file_name, format='pdf', bbox_inches='tight')
plt.show()
# 使用函数绘制 age、trestbps、chol、thalach、oldpeak 的 shap 依赖图
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
plot_shap_dependence(feature_list, df, shap_values_df)
通过绘制多个特征的SHAP依赖图,结合LOWESS拟合曲线与交点标注,分析各特征对模型预测的影响,当然,也可以采用其他拟合曲线,而不仅限于LOWESS,这里主要是基于参考文献中所使用的LOWESS拟合曲线进行分析,这里的解释同前文Age解释原理相同
本文章转载微信公众号@Python机器学习AI