梯度提升树的早停技术

梯度提升树是一种集成学习方法,通过迭代地添加决策树来构建一个强大的预测模型,每一棵新树都尝试修正前一棵树的错误。早停技术是梯度提升中的一种策略,它通过设置一个验证集来监控模型在训练过程中的表现,从而找到最佳的迭代次数,避免过拟合,同时提高模型对未见数据的泛化能力。

早停技术的核心思想非常简单:在训练过程中,保留一部分数据作为验证集,用以评估模型的性能。随着模型的迭代构建,监控其在验证集上的表现。当模型在验证集上的表现趋于稳定或变差时(在指定的容忍度范围内),这通常意味着模型已经达到了一个临界点,进一步的迭代可能会导致过拟合,此时就应该停止训练。

在应用早停技术时,可以通过模型的n_estimators_属性来获取最终模型中估计器(树)的数量。总的来说,早停技术是梯度提升中一个非常有价值的工具,它能够在模型性能和效率之间找到一个平衡点。

数据准备

首先,需要加载并准备加州房价数据集,用于训练和评估。对数据集进行子集划分,然后将其分为训练集和验证集。

from sklearn.datasets import fetch_california_housing from sklearn.model_selection import train_test_split data = fetch_california_housing() X, y = data.data[:600], data.target[:600] X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

模型训练与比较

训练了两个GradientBoostingRegressor模型:一个应用了早停技术,另一个没有。目的是为了比较它们的性能。同时,还计算了两个模型的训练时间和使用的估计器数量。

from sklearn.ensemble import GradientBoostingRegressor from time import time params = dict(n_estimators=1000, max_depth=5, learning_rate=0.1, random_state=42) gbm_full = GradientBoostingRegressor(**params) gbm_early_stopping = GradientBoostingRegressor(**params, validation_fraction=0.1, n_iter_no_change=10) start_time = time() gbm_full.fit(X_train, y_train) training_time_full = time() - start_time n_estimators_full = gbm_full.n_estimators_ start_time = time() gbm_early_stopping.fit(X_train, y_train) training_time_early_stopping = time() - start_time estimators_early_stopping = gbm_early_stopping.n_estimators_

误差计算

代码计算了两个模型在训练和验证数据集上的均方误差(MSE),并为每次提升迭代计算了误差。目的是为了评估模型的性能和收敛情况。

from sklearn.metrics import mean_squared_error train_errors_without = [] val_errors_without = [] train_errors_with = [] val_errors_with = [] for i, (train_pred, val_pred) in enumerate(zip(gbm_full.staged_predict(X_train), gbm_full.staged_predict(X_val))): train_errors_without.append(mean_squared_error(y_train, train_pred)) val_errors_without.append(mean_squared_error(y_val, val_pred)) for i, (train_pred, val_pred) in enumerate(zip(gbm_early_stopping.staged_predict(X_train), gbm_early_stopping.staged_predict(X_val))): train_errors_with.append(mean_squared_error(y_train, train_pred)) val_errors_with.append(mean_squared_error(y_val, val_pred))

可视化比较

创建了三个子图:绘制了两个模型在提升迭代过程中的训练误差和验证误差,并创建了一个条形图来比较有无早停技术的模型的训练时间和使用的估计器数量。

import matplotlib.pyplot as plt fig, axes = plt.subplots(ncols=3, figsize=(12, 4)) axes[0].plot(train_errors_without, label="gbm_full") axes[0].plot(train_errors_with, label="gbm_early_stopping") axes[0].set_xlabel("提升迭代次数") axes[0].set_ylabel("MSE (训练)") axes[0].set_yscale("log") axes[0].legend() axes[0].set_title("训练误差") axes[1].plot(val_errors_without, label="gbm_full") axes[1].plot(val_errors_with, label="gbm_early_stopping") axes[1].set_xlabel("提升迭代次数") axes[1].set_ylabel("MSE (验证)") axes[1].set_yscale("log") axes[1].legend() axes[1].set_title("验证误差") training_times = [training_time_full, training_time_early_stopping] labels = ["gbm_full", "gbm_early_stopping"] bars = axes[2].bar(labels, training_times) axes[2].set_ylabel("训练时间 (秒)") for bar, n_estimators in zip(bars, [n_estimators_full, estimators_early_stopping]): height = bar.get_height() axes[2].text(bar.get_x() + bar.get_width() / 2, height + 0.001, f"估计器数量: {n_estimators}", ha="center", va="bottom") plt.tight_layout() plt.show()
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485