交叉验证行为可视化

在机器学习中,正确选择交叉验证方法对于恰当地拟合模型至关重要。为了避免模型过拟合、标准化测试集中的组数等,数据可以以多种方式被分割成训练集和测试集。本文通过一个示例来可视化scikit-learn中几种常见交叉验证对象的行为,以供比较。

首先,需要理解数据的结构。数据包含100个随机生成的输入数据点,分为3个类别,这些类别在数据点中的分布是不均匀的,同时还有10个“组”,这些组在数据点中分布是均匀的。接下来,将可视化这些数据。

import matplotlib.pyplot as plt import numpy as np from matplotlib.patches import Patch from sklearn.model_selection import ( GroupKFold, GroupShuffleSplit, KFold, ShuffleSplit, StratifiedGroupKFold, StratifiedKFold, StratifiedShuffleSplit, TimeSeriesSplit, ) rng = np.random.RandomState(1338) cmap_data = plt.cm.Paired cmap_cv = plt.cm.coolwarm n_splits = 4

为了可视化数据,首先生成类别/组数据。然后定义一个函数来可视化数据集中的组。这个函数会创建一个散点图,其中每个点代表一个数据点,颜色代表其所属的组或类别。

n_points = 100 X = rng.randn(n_points, 10) percentiles_classes = [0.1, 0.3, 0.6] y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)]) group_prior = rng.dirichlet([2] * 10) groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior)) def visualize_groups(classes, groups, name): fig, ax = plt.subplots() ax.scatter(range(len(groups)), [0.5] * len(groups), c=groups, marker="_", lw=50, cmap=cmap_data) ax.scatter(range(len(groups)), [3.5] * len(groups), c=classes, marker="_", lw=50, cmap=cmap_data) ax.set(ylim=[-1, 5], yticks=[0.5, 3.5], yticklabels=["Data\ngroup", "Data\nclass"], xlabel="Sample index") visualize_groups(y, groups, "no groups")

接下来,定义一个函数来可视化每个交叉验证对象的行为。将对数据进行4次分割,在每次分割中,将可视化训练集(蓝色)和测试集(红色)的索引。

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10): """Create a sample plot for indices of a cross-validation object.""" use_groups = "Group" in type(cv).__name__ groups = group if use_groups else None for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=groups)): indices = np.array([np.nan] * len(X)) indices[tt] = 1 indices[tr] = 0 ax.scatter(range(len(indices)), [ii + 0.5] * len(indices), c=indices, marker="_", lw=lw, cmap=cmap_cv, vmin=-0.2, vmax=1.2) ax.scatter(range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data) ax.scatter(range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data) ax.set(yticks=np.arange(n_splits + 2) + 0.5, yticklabels=list(range(n_splits)) + ["class", "group"], xlabel="Sample index", ylabel="CV iteration", ylim=[n_splits + 2.2, -0.2], xlim=[0, 100]) ax.set_title("{}".format(type(cv).__name__), fontsize=15) return ax

通过这个函数,可以观察到不同交叉验证对象的行为。例如,KFold交叉验证迭代器默认不考虑数据点的类别或组。可以通过使用StratifiedKFold来保持每个类别的样本百分比,或者使用GroupKFold确保同一组不会出现在两个不同的折叠中。

cvs = [StratifiedKFold, GroupKFold, StratifiedGroupKFold] for cv in cvs: fig, ax = plt.subplots(figsize=(6, 3)) plot_cv_indices(cv(n_splits), X, y, groups, ax, n_splits) ax.legend([Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))], ["Testing set", "Training set"], loc=(1.02, 0.8)) plt.tight_layout() fig.subplots_adjust(right=0.7)

最后,将可视化多个交叉验证迭代器的行为。通过循环遍历几个常见的交叉验证对象,可以观察到每个对象的行为,注意有些对象使用组/类别信息,而有些则不使用。

cvs = [ KFold, GroupKFold, ShuffleSplit, StratifiedKFold, StratifiedGroupKFold, GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit, ] for cv in cvs: this_cv = cv(n_splits=n_splits) fig, ax = plt.subplots(figsize=(6, 3)) plot_cv_indices(this_cv, X, y, groups, ax, n_splits) ax.legend([Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))], ["Testing set", "Training set"], loc=(1.02, 0.8)) plt.tight_layout() fig.subplots_adjust(right=0.7) plt.show()
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485