多标签分类:分类器链的使用

机器学习领域,多标签分类问题是指模型需要预测一个实例可能属于多个类别的情况。传统的方法是对每个标签独立训练一个二分类器,但在预测时不考虑标签间的相关性。为了解决这个问题,可以使用分类器链(ClassifierChain)这种更高级的策略。分类器链通过将一个分类器的预测结果作为下一个分类器的输入特征,从而在链中利用标签间的相关性。

酵母数据集的加载与预处理

本例中,使用了一个包含2,417个数据点,每个数据点有103个特征和14个可能的标签的酵母数据集。每个数据点至少有一个标签。首先,对14个标签中的每一个都训练了一个逻辑回归分类器作为基线模型。然后,在保留的测试集上进行预测,并计算每个样本的Jaccard相似度分数以评估这些分类器的性能。

from sklearn.datasets import fetch_openml from sklearn.model_selection import train_test_split import numpy as np # 从OpenML加载多标签数据集 X, Y = fetch_openml("yeast", version=4, return_X_y=True) Y = Y == "TRUE" X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

模型拟合与性能评估

首先使用逻辑回归模型,该模型被OneVsRestClassifier包装,以处理多目标数据。然后,计算Jaccard相似度分数来评估模型性能。接着,构建了一个由多个分类器链组成的集成模型。每个链中的模型都是随机排列的,因此链之间的性能存在显著差异。通过平均链的二元预测并应用0.5的阈值来构建投票集成,集成的Jaccard相似度分数通常高于独立模型,并且往往超过集成中每个链的分数(尽管这在随机排列的链中不能保证)。

from sklearn.linear_model import LogisticRegression from sklearn.metrics import jaccard_score from sklearn.multiclass import OneVsRestClassifier from sklearn.multioutput import ClassifierChain # 使用OneVsRestClassifier包装逻辑回归模型 base_lr = LogisticRegression() ovr = OneVsRestClassifier(base_lr) ovr.fit(X_train, Y_train) Y_pred_ovr = ovr.predict(X_test) ovr_jaccard_score = jaccard_score(Y_test, Y_pred_ovr, average="samples") # 构建分类器链的集成 chains = [ClassifierChain(base_lr, order="random", random_state=i) for i in range(10)] for chain in chains: chain.fit(X_train, Y_train) Y_pred_chains = np.array([chain.predict_proba(X_test) for chain in chains]) chain_jaccard_scores = [jaccard_score(Y_test, Y_pred_chain >= 0.5, average="samples") for Y_pred_chain in Y_pred_chains] Y_pred_ensemble = Y_pred_chains.mean(axis=0) ensemble_jaccard_score = jaccard_score(Y_test, Y_pred_ensemble >= 0.5, average="samples")

结果可视化

绘制了独立模型、每个链以及集成的Jaccard相似度分数的柱状图。从图中可以看出,独立模型的性能通常不如分类器链的集成,这是因为逻辑回归没有模拟标签之间的关系。分类器链利用了标签之间的相关性,但由于标签排序的随机性,它可能会产生比独立模型更差的结果。而链的集成之所以表现更好,是因为它不仅捕捉了标签之间的关系,而且不对它们的正确顺序做强烈假设。

import matplotlib.pyplot as plt model_scores = [ovr_jaccard_score] + chain_jaccard_scores + [ensemble_jaccard_score] model_names = ("Independent", "Chain 1", "Chain 2", "Chain 3", "Chain 4", "Chain 5", "Chain 6", "Chain 7", "Chain 8", "Chain 9", "Chain 10", "Ensemble") x_pos = np.arange(len(model_names)) fig, ax = plt.subplots(figsize=(7, 4)) ax.grid(True) ax.set_title("Classifier Chain Ensemble Performance Comparison") ax.set_xticks(x_pos) ax.set_xticklabels(model_names, rotation="vertical") ax.set_ylabel("Jaccard Similarity Score") ax.set_ylim([min(model_scores) * 0.9, max(model_scores) * 1.1]) colors = ["r"] + ["b"] * len(chain_jaccard_scores) + ["g"] ax.bar(x_pos, model_scores, alpha=0.5, color=colors) plt.tight_layout() plt.show()
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485