在机器学习中,正确地划分数据集是构建有效模型的关键步骤。为了避免模型过拟合,确保测试集中的组数量标准化等问题,数据集可以以多种方式划分为训练集和测试集。本文将通过可视化比较几种常见的scikit-learn对象的行为,来展示这一过程。
首先,需要理解数据的结构。数据集包含100个随机生成的输入数据点,分为3个类别,这些类别在数据点中的分布是不均匀的。同时,数据点被分为10个“组”,这些组在数据点中的分布是均匀的。
为了更好地理解数据,将可视化数据集。生成了类别/组数据,其中包含100个数据点,并且每个类别和组的分布如下:
n_points = 100
X = rng.randn(n_points, 10)
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))
接下来,定义一个函数来可视化交叉验证的行为。将对数据进行4次划分,在每次划分中,将可视化训练集(蓝色)和测试集(红色)的索引。
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""Create a sample plot for indices of a cross-validation object."""
use_groups = "Group" in type(cv).__name__
groups = group if use_groups else None
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=groups)):
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
ax.scatter(range(len(indices)), [ii + 0.5] * len(indices), c=indices, marker="_", lw=lw, cmap=cmap_cv, vmin=-0.2, vmax=1.2)
# Plot the data classes and groups at the end
ax.scatter(range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data)
ax.scatter(range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data)
# Formatting
yticklabels = list(range(n_splits)) + ["class", "group"]
ax.set(yticks=np.arange(n_splits + 2) + 0.5, yticklabels=yticklabels, xlabel="Sample index", ylabel="CV iteration", ylim=[n_splits + 2.2, -0.2], xlim=[0, 100])
ax.set_title("{}".format(type(cv).__name__), fontsize=15)
return ax
现在,让看看KFold交叉验证对象的表现。通过下面的代码,可以看到KFold交叉验证迭代器默认情况下不考虑数据点的类别或组。可以通过使用StratifiedKFold来保持每个类别的样本百分比,或者使用GroupKFold来确保同一个组不会出现在两个不同的折叠中。
cvs = [StratifiedKFold, GroupKFold, StratifiedGroupKFold]
for cv in cvs:
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(cv(n_splits), X, y, groups, ax, n_splits)
ax.legend([Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))], ["Testing set", "Training set"], loc=(1.02, 0.8))
plt.tight_layout()
fig.subplots_adjust(right=0.7)
接下来,将为多个CV迭代器可视化这种行为。将循环遍历几个常见的交叉验证对象,并可视化它们的行为。注意,有些对象使用组/类别信息,而有些则不使用。
cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold, StratifiedGroupKFold, GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend([Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))], ["Testing set", "Training set"], loc=(1.02, 0.8))
plt.tight_layout()
fig.subplots_adjust(right=0.7)
plt.show()