梯度提升中的早停技术

梯度提升是一种集成技术,它通过组合多个弱学习器(通常是决策树)来创建一个稳健且强大的预测模型。这个过程是迭代进行的,每个新阶段(树)都会纠正前一个阶段的错误。早停技术是梯度提升中的一种方法,它允许找到构建模型所需的最优迭代次数,以便模型能够很好地泛化到未见过的数据并避免过拟合。其概念非常简单:预留一部分数据集作为验证集(通过validation_fraction指定)来在训练过程中评估模型的性能。

随着模型通过额外的阶段(树)迭代构建,其在验证集上的性能会作为步骤数量的函数进行监控。当模型在验证集上的性能趋于平稳或恶化(在tol指定的偏差内)超过一定数量的连续阶段(由n_iter_no_change指定)时,早停技术就会变得有效。这表明模型已经达到了一个点,进一步的迭代可能会导致过拟合,是时候停止训练了。当应用早停时,可以通过n_estimators_属性访问最终模型中的估计器(树)数量。总的来说,早停是梯度提升中平衡模型性能和效率的有价值工具。

数据准备

首先,加载并准备加州房价数据集,用于训练和评估。它对数据集进行子集划分,将其拆分为训练和验证集。

import time import matplotlib.pyplot as plt from sklearn.datasets import fetch_california_housing from sklearn.ensemble import GradientBoostingRegressor from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split data = fetch_california_housing() X, y = data.data[:600], data.target[:600] X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

模型训练与比较

训练了两个GradientBoostingRegressor模型:一个使用早停技术,另一个不使用。目的是比较它们的性能。它还计算了两个模型的训练时间和使用的n_estimators_。

params = dict(n_estimators=1000, max_depth=5, learning_rate=0.1, random_state=42) gbm_full = GradientBoostingRegressor(**params) gbm_early_stopping = GradientBoostingRegressor(**params, validation_fraction=0.1, n_iter_no_change=10) start_time = time.time() gbm_full.fit(X_train, y_train) training_time_full = time.time() - start_time n_estimators_full = gbm_full.n_estimators_ start_time = time.time() gbm_early_stopping.fit(X_train, y_train) training_time_early_stopping = time.time() - start_time estimators_early_stopping = gbm_early_stopping.n_estimators_

误差计算

代码计算了在前一节中训练的模型的均方误差,分别针对训练和验证数据集。它计算了每个提升迭代的误差。目的是评估模型的性能和收敛性。

train_errors_without = [] val_errors_without = [] train_errors_with = [] val_errors_with = [] for i, (train_pred, val_pred) in enumerate(zip(gbm_full.staged_predict(X_train), gbm_full.staged_predict(X_val))): train_errors_without.append(mean_squared_error(y_train, train_pred)) val_errors_without.append(mean_squared_error(y_val, val_pred)) for i, (train_pred, val_pred) in enumerate(zip(gbm_early_stopping.staged_predict(X_train), gbm_early_stopping.staged_predict(X_val))): train_errors_with.append(mean_squared_error(y_train, train_pred)) val_errors_with.append(mean_squared_error(y_val, val_pred))

可视化比较

它包括三个子图:绘制两个模型的训练误差随提升迭代的变化;绘制两个模型的验证误差随提升迭代的变化;创建一个条形图来比较使用和不使用早停技术的模型的训练时间和使用的估计器数量。

fig, axes = plt.subplots(ncols=3, figsize=(12, 4)) axes[0].plot(train_errors_without, label="gbm_full") axes[0].plot(train_errors_with, label="gbm_early_stopping") axes[0].set_xlabel("提升迭代") axes[0].set_ylabel("MSE (训练)") axes[0].set_yscale("log") axes[0].legend() axes[0].set_title("训练误差") axes[1].plot(val_errors_without, label="gbm_full") axes[1].plot(val_errors_with, label="gbm_early_stopping") axes[1].set_xlabel("提升迭代") axes[1].set_ylabel("MSE (验证)") axes[1].set_yscale("log") axes[1].legend() axes[1].set_title("验证误差") training_times = [training_time_full, training_time_early_stopping] labels = ["gbm_full", "gbm_early_stopping"] bars = axes[2].bar(labels, training_times) axes[2].set_ylabel("训练时间 (s)") for bar, n_estimators in zip(bars, [n_estimators_full, estimators_early_stopping]): height = bar.get_height() axes[2].text(bar.get_x() + bar.get_width() / 2, height + 0.001, f"估计器: {n_estimators}", ha="center", va="bottom") plt.tight_layout() plt.show()
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485