机器学习库scikit-learn 1.0版本更新概览

经过一段时间的稳定发展,scikit-learn库发布了1.0版本。这个版本标志着库的成熟,并为用户提供了信号。此版本不包含任何破坏性变更,除了常规的两版本废弃周期。未来,将继续努力保持这种模式。此版本包括一些关键新特性以及许多改进和bug修复。以下是此版本一些主要特性的详细介绍。如需查看所有变更的详尽列表,请参考发布说明。

要安装最新版本的scikit-learn,可以使用pip:

pip install --upgrade scikit-learn

或者使用conda:

conda install -c conda-forgescikit-learn

API参数的变化

scikit-learnAPI公开了许多函数和方法,这些函数和方法有很多输入参数。例如,在本版本发布之前,可以这样实例化一个HistGradientBoostingRegressor:

HistGradientBoostingRegressor("squared_error", 0.1, 100, 31, None, 20, 0.0, 255, None, None, False, "auto", "loss", 0.1, 10, 1e-7, 0, None)

理解上述代码需要读者查阅API文档,并检查每个参数的位置和含义。为了提高基于scikit-learn编写的代码的可读性,现在用户需要提供大多数参数的名称,作为关键字参数,而不是位置参数。例如,上述代码将变为:

HistGradientBoostingRegressor(loss="squared_error", learning_rate=0.1, max_iter=100, max_leaf_nodes=31, max_depth=None, min_samples_leaf=20, l2_regularization=0.0, max_bins=255, categorical_features=None, monotonic_cst=None, warm_start=False, early_stopping="auto", scoring="loss", validation_fraction=0.1, n_iter_no_change=10, tol=1e-7, verbose=0, random_state=None)

这样代码的可读性大大提高了。位置参数自0.23版本以来已被废弃,现在将引发TypeError。在某些情况下,仍然允许使用有限数量的位置参数,例如在PCA中,PCA(10)仍然是允许的,但PCA(10, False)是不允许的。

Spline变换器

向数据集的特征集添加非线性项的一种方法是为连续/数值特征生成样条基函数。新引入的SplineTransformer实现了B样条基。样条是分段多项式,由它们的多项式度数和节点位置参数化。以下代码展示了样条的实际应用,更多信息请参考用户指南。

import numpy as np from sklearn.preprocessing import SplineTransformer X = np.arange(5).reshape(5, 1) spline = SplineTransformer(degree=2, n_knots=3) spline.fit_transform(X) array([[0.5 , 0.5 , 0. , 0. ], [0.125, 0.75 , 0.125, 0. ], [0. , 0.5 , 0.5 , 0. ], [0. , 0.125, 0.75 , 0.125], [0. , 0. , 0.5 , 0.5 ]])

分位数回归器

分位数回归估计的是给定X条件下y的中位数或其他分位数,而普通最小二乘法(OLS)估计的是条件均值。作为一种线性模型,新的QuantileRegressor为第q个分位数提供线性预测\(\hat{y}(w, X) = Xw\),其中\(q \in (0, 1)\)。然后通过以下最小化问题找到权重或系数w:

min_w {(1/n_samples) * sum_i PB_q(y_i - X_i w) + alpha ||w||_1}.

这包括了pinball损失(也称为线性损失),更多信息请参考mean_pinball_loss,

PB_q(t) = q * max(t, 0) + (1 - q) * max(-t, 0) = { q * t, & t > 0, 0, & t = 0, (1-q) * t, & t < 0 }

以及由参数alpha控制的L1惩罚,类似于linear_model.Lasso。

特征名称支持

当估计器在fit过程中接收到一个pandas的dataframe时,估计器将设置一个feature_names_in_属性,包含特征名称。注意,只有当dataframe的列名全部为字符串时,才启用特征名称支持。feature_names_in_用于检查传递给非fit方法(如predict)的dataframe的列名是否与fit中的特征一致:

from sklearn.preprocessing import StandardScaler import pandas as pd X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "b", "c"]) scalar = StandardScaler().fit(X) scalar.feature_names_in_ array(['a', 'b', 'c'], dtype=object)

get_feature_names_out的支持已为已有get_feature_names的变换器以及具有输入和输出一一对应关系的变换器(如StandardScaler)提供。get_feature_names_out的支持将在未来的版本中添加到所有其他变换器中。此外,compose.ColumnTransformer.get_feature_names_out可用于组合其变换器的特征名称:

from sklearn.compose import ColumnTransformer from sklearn.preprocessing import OneHotEncoder import pandas as pd X = pd.DataFrame({ "pet": ["dog", "cat", "fish"], "age": [3, 7, 1] }) preprocessor = ColumnTransformer([ ("numerical", StandardScaler(), ["age"]), ("categorical", OneHotEncoder(), ["pet"]), ], verbose_feature_names_out=False) .fit(X) preprocessor.get_feature_names_out() array(['age', 'pet_cat', 'pet_dog', 'pet_fish'], dtype=object)

当这个preprocessor与pipeline一起使用时,可以通过切片和调用get_feature_names_out来获取分类器使用的特征名称:

from sklearn.linear_model import LogisticRegression from sklearn.pipeline import make_pipeline y = [1, 0, 1] pipe = make_pipeline(preprocessor, LogisticRegression()) pipe.fit(X, y) pipe[:-1].get_feature_names_out() array(['age', 'pet_cat', 'pet_dog', 'pet_fish'], dtype=object)

更灵活的绘图API

metrics.ConfusionMatrixDisplay、metrics.PrecisionRecallDisplay、metrics.DetCurveDisplay和inspection.PartialDependenceDisplay现在公开了两个类方法:from_estimator和from_predictions,允许用户根据预测或估计器创建图表。这意味着相应的plot_*函数已被废弃。请查看示例一和示例二了解如何使用新的绘图功能。

在线单类SVM

新类SGDOneClassSVM实现了使用随机梯度下降的在线线性版本的单类SVM。结合核近似技术,SGDOneClassSVM可以用来近似解决核化单类SVM的解,该解在OneClassSVM中实现,具有线性的拟合时间复杂度。请注意,核化单类SVM的复杂度在最好的情况下是样本数量的二次方。因此,SGDOneClassSVM非常适合用于训练样本数量大(> 10,000)的数据集,其中SGD变体可以快几个数量级。请查看此示例了解其使用方法,以及用户指南了解更多细节。

基于直方图的梯度提升模型现已稳定

HistGradientBoostingRegressor和HistGradientBoostingClassifier不再是实验性的,可以直接导入和使用:

from sklearn.ensemble import HistGradientBoostingClassifier

文档改进

此版本包括许多文档改进。在超过2100个合并的拉取请求中,大约有800个是对文档的改进。

脚本总运行时间:

(0分钟0.015秒)

相关示例

scikit-learn 1.1版本更新亮点

scikit-learn 1.2版本更新亮点

scikit-learn 0.24版本更新亮点

沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485