在当今数据量巨大的时代,使用机器学习模型进行决策已成为医疗、金融、市场营销等行业的关键。许多机器学习模型如同黑箱,训练后很难完全理解其工作原理。这使得理解并解释模型行为变得困难,但为了确保对其准确性的信任,这样做至关重要。那么,如何建立对黑箱预测的信任呢?解决方案是可解释人工智能(XAI)。XAI旨在为人类难以感知的复杂AI模型开发解释,即系统能够理解AI算法正在做什么以及为何做出该决策。这些信息可以提高模型性能,帮助机器学习工程师进行故障排除,使AI系统更具说服力且易于理解。
OmniXAI是一个简化可解释AI的库,适用于需要在多个机器学习阶段获得解释的用户,包括数据分析、特征提取、模型构建和模型评估。它通过使用卡方分析和互信息计算等技术来查看输入特征与目标变量之间的相关性,帮助特征选择,识别关键特征。使用它提供的数据分析师,可以轻松进行相关性分析并发现类别不平衡。OmniXAI可以用于表格、图像、自然语言处理和时间序列数据。
OmniXAI提供了多种解释,使用户能够详细了解模型的行为。这些解释可以借助该库轻松可视化。它使用Plotly创建交互式图表,只需几行代码,就可以创建一个仪表板,简化了同时比较多个解释的过程。在本文的后续部分,将构建这样一个仪表板来描述模型的结果。
主要有两种类型的解释:本地解释和全局解释。本地解释解释了某个特定决策背后的原因。这种解释是使用LIME和SHAP等技术产生的。全局解释检查模型的整体行为。要生成全局解释,可以使用部分依赖图。
该库使用多种模型无关的技术,包括LIME、SHAP和L2X。这些方法可以在不了解模型复杂性的情况下有效地描述模型所做的决策。此外,它还使用模型特定的方法如Grad-CAM为给定模型生成解释。
将使用中风预测数据集来创建一个分类器模型。该数据集根据包括性别、年龄、不同疾病和吸烟状况在内的输入特征,用于确定患者是否可能中风。不能依赖医疗领域的“黑箱”做出判断;必须有选择的理由。为了实现这一点,将使用OmniXAI来分析数据集并理解模型的行为。
# 导入库
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from omnixai.data.tabular import Tabular
from omnixai.explainers.data import DataAnalyzer
from omnixai.visualization.dashboard import Dashboard
from omnixai.preprocessing.tabular import TabularTransform
from omnixai.explainers.tabular import TabularExplainer
from omnixai.explainers.prediction import PredictionAnalyzer
import seaborn as sns
import matplotlib.pyplot as plt
# 加载数据集
df = pd.read_csv("healthcare-dataset-stroke-data.csv")
df = df.drop('id', axis=1)
df = df.dropna()
print(df.head(10))
将使用OmniXAI来分析表格数据集,使用Pandas数据框。需要指定数据框、分类特征名称和目标列名称来构建给定Pandas数据框的Tabular实例。
feature_names = df.columns
categorical_columns = ['gender','ever_married','work_type','Residence_type','smoking_status']
tabular_data = Tabular(
df,
feature_columns=feature_names,
categorical_columns=categorical_columns,
target_column='stroke'
)
现在将创建一个DataAnalyzer解释器来分析数据。
explainer = DataAnalyzer(
explainers=["correlation", "imbalance#0", "imbalance#1", "mutual", "chi2"],
mode="classification",
data=tabular_data
)
explanations = explainer.explain_global(
params={"imbalance#0": {"features": ["gender"]},
"imbalance#1": {"features": ["ever_married"]}
}
)
dashboard = Dashboard(global_explanations=explanations)
dashboard.show()
Dash正在运行在 http://127.0.0.1:8050/。这个仪表板显示了特征相关性分析、性别和已婚特征的特征不平衡图以及特征重要性图。
现在将对数据应用TabularTransform,将表格实例转换为NumPy数组,并将分类特征转换为独热编码。接下来,将使用SMOTE通过从类别1中过采样数据来解决类别不平衡问题。然后,使用StandardScaler和LogisticRegression模型,将数据拟合到管道中。
transformer = TabularTransform().fit(tabular_data)
x = transformer.transform(tabular_data)
train, test, train_labels, test_labels = train_test_split(x[:, :-1], x[:, -1], train_size=0.80)
# 在训练集中平衡类别
oversample = SMOTE()
X_train_balanced, y_train_balanced = oversample.fit_resample(train, train_labels)
model = Pipeline(steps = [('scale',StandardScaler()),('lr',LogisticRegression())])
model.fit(X_train_balanced, y_train_balanced)
print('Test accuracy: {}'.format(accuracy_score(test_labels, model.predict(test))))
print(classification_report(test_labels,model.predict(test)))
print(confusion_matrix(test_labels,model.predict(test)))
train_data = transformer.invert(X_train_balanced)
test_data = transformer.invert(test)
在训练模型之后,可以进一步创建对其行为的解释。
现在将定义一个TabularExplainer,并在代码中给出参数。参数“explainers”提到了要使用的解释器的名称。预处理将原始数据转换为模型输入。本地解释由LIME、SHAP和MACE生成,而PDP生成全局解释。为了计算这个分类器模型的性能指标,将定义一个PredictionAnalyzer,并为其提供测试数据。
preprocess = lambda z: transformer.transform(z)
explainers = TabularExplainer(
explainers=["lime", "shap", "mace", "pdp"],
mode="classification",
data=train_data,
model=model,
preprocess=preprocess,
params={
"lime": {"kernel_width": 4},
"shap": {"nsamples": 200},
}
)
test_instances = test_data[10:15]
local_explanations = explainers.explain(X=test_instances)
global_explanations = explainers.explain_global(
params={"pdp": {"features": ['age', 'hypertension', 'heart_disease', 'ever_married', 'bmi','work_type']}}
)
analyzer = PredictionAnalyzer(
mode="classification",
test_data=test_data,
test_targets=test_labels,
model=model,
preprocess=preprocess
)
prediction_explanations = analyzer.explain()
在创建解释之后,将定义仪表板的参数,然后创建Plotly dash应用程序。可以通过将本地地址复制到浏览器的地址栏来运行此仪表板。
dashboard = Dashboard(
instances=test_instances,
local_explanations=local_explanations,
global_explanations=global_explanations,
prediction_explanations=prediction_explanations,
class_names=class_names
)
dashboard.show()