网格搜索与连续减半搜索比较

机器学习领域,参数优化是一个关键步骤,它直接影响模型的性能。本文将探讨两种流行的参数搜索方法:网格搜索(GridSearchCV)和连续减半搜索(HalvingGridSearchCV)。将通过一个支持向量机(SVC)的例子来比较这两种方法,并分析它们在训练时间和准确性方面的表现。

参数空间的定义

首先,定义了SVC估计器的参数空间。选择了两个参数:'gamma'和'C',它们的候选值分别如下:

gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7] Cs = [1, 10, 100, 1e3, 1e4, 1e5] param_grid = { "gamma": gammas, "C": Cs } clf = SVC(random_state=rng)

这里,'gamma'参数控制着核函数的系数,而'C'参数则控制着错误分类的惩罚程度。为这两个参数分别设置了一系列的候选值,以便在后续的搜索过程中进行评估。

连续减半搜索与网格搜索的实现

接下来,实现了连续减半搜索和网格搜索。连续减半搜索通过逐步减少候选参数组合的数量来提高搜索效率,而网格搜索则尝试所有可能的参数组合。记录了两种方法的训练时间,并进行了比较。

tic = time() gsh = HalvingGridSearchCV(estimator=clf, param_grid=param_grid, factor=2, random_state=rng) gsh.fit(X, y) gsh_time = time() - tic tic = time() gs = GridSearchCV(estimator=clf, param_grid=param_grid) gs.fit(X, y) gs_time = time() - tic

在上述代码中,'factor'参数控制着连续减半搜索中每次迭代减少候选数量的比例。可以看到,连续减半搜索在找到与网格搜索同样准确的参数组合的同时,所需的时间更少。

结果可视化

为了更直观地展示两种搜索方法的效果,绘制了热力图。热力图显示了不同参数组合的平均测试分数,连续减半搜索还显示了参数组合最后一次使用的迭代次数。

def make_heatmap(ax, gs, is_sh=False, make_cbar=False): # Helper to make a heatmap. results = pd.DataFrame(gs.cv_results_) results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(np.float64) if is_sh: # SH dataframe: get mean_test_score values for the highest iter scores_matrix = results.sort_values("iter").pivot_table(index="param_gamma", columns="param_C", values="mean_test_score", aggfunc="last") else: scores_matrix = results.pivot(index="param_gamma", columns="param_C", values="mean_test_score") im = ax.imshow(scores_matrix) ax.set_xticks(np.arange(len(Cs))) ax.set_xticklabels(["{:.0E}".format(x) for x in Cs]) ax.set_xlabel("C", fontsize=15) ax.set_yticks(np.arange(len(gammas))) ax.set_yticklabels(["{:.0E}".format(x) for x in gammas]) ax.set_ylabel("gamma", fontsize=15) plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") if is_sh: iterations = results.pivot_table(index="param_gamma", columns="param_C", values="iter", aggfunc="max").values for i in range(len(gammas)): for j in range(len(Cs)): ax.text(j, i, iterations[i, j], ha="center", va="center", color="w", fontsize=20) if make_cbar: fig.subplots_adjust(right=0.8) cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) fig.colorbar(im, cax=cbar_ax) cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15) fig, axes = plt.subplots(ncols=2, sharey=True) ax1, ax2 = axes make_heatmap(ax1, gsh, is_sh=True) make_heatmap(ax2, gs, make_cbar=True) ax1.set_title("Successive Halving\ntime = {:.3f}s".format(gsh_time), fontsize=15) ax2.set_title("GridSearch\ntime = {:.3f}s".format(gs_time), fontsize=15) plt.show()

通过热力图,可以清楚地看到,连续减半搜索在第一次迭代中只评估了标记为0的参数组合,而标记为5的参数组合被认为是最好的。这表明连续减半搜索能够在更短的时间内找到与网格搜索同样准确的参数组合。

沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485