Polynomial Kernel Approximation and Linear Classifier Training

In machine learning, kernel methods are a powerful technique that allow data to be classified effectively in high-dimensional spaces. However, these methods usually come with a high computational cost. To address this, polynomial kernel approximation techniques such as PolynomialCountSketch can be used to reduce the computational complexity while retaining high classification accuracy. This article shows how to use PolynomialCountSketch to train a linear classifier on the Covtype dataset that approximates the performance of a kernelized classifier.

Preparing the dataset

The Covtype dataset contains 581,012 samples, each with 54 features, distributed among 7 classes. The goal of this dataset is to predict forest cover type from cartographic variables. We convert it into a binary classification problem to match the version of the dataset on the LIBSVM webpage.

from sklearn.datasets import fetch_covtype

X, y = fetch_covtype(return_X_y=True)
y[y != 2] = 0
y[y == 2] = 1  # we will try to separate class 2 from the other 6 classes
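
As a quick sanity check, not part of the original example, one can inspect the class balance of the resulting binary problem; the two classes turn out to be roughly balanced, which is why plain accuracy is a reasonable metric here.

# Optional check (not part of the original example): class balance after relabeling.
import numpy as np

print(np.bincount(y.astype(np.int64)))  # number of samples in class 0 vs class 1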

Splitting the data

We select 5,000 samples for training and 10,000 samples for testing. To reproduce the results from the original Tensor Sketch paper, select 100,000 samples for training instead.

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=5_000, test_size=10_000, random_state=42
)

Feature normalization

Scale the features to the range [0, 1] to match the preprocessing on the LIBSVM webpage, then normalize each sample to unit length, as done in the Tensor Sketch paper.

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler, Normalizer

mm = make_pipeline(MinMaxScaler(), Normalizer())
X_train = mm.fit_transform(X_train)
X_test = mm.transform(X_test)
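
As an optional check, not in the original example, one can verify that the pipeline indeed rescales each sample to (approximately) unit L2 norm:

# Optional check (not part of the original example): per-sample L2 norms
# should be ~1.0 after MinMaxScaler + Normalizer.
import numpy as np

norms = np.linalg.norm(X_train, axis=1)
print(norms.min(), norms.max())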

Establishing a baseline model

As a baseline, train a linear SVM on the original features and print its accuracy. We also measure and store the accuracy and training time so they can be plotted later.

import time

from sklearn.svm import LinearSVC

results = {}

lsvm = LinearSVC()
start = time.time()
lsvm.fit(X_train, y_train)
lsvm_time = time.time() - start
lsvm_score = 100 * lsvm.score(X_test, y_test)

results["LSVM"] = {"time": lsvm_time, "score": lsvm_score}
print(f"Linear SVM score on raw features: {lsvm_score:.2f}%")
# Linear SVM score on raw features: 75.62%

Establishing the kernel approximation model

Next, train linear SVMs on the features generated by PolynomialCountSketch with different values of n_components, showing how these approximate kernel features improve the accuracy of linear classification. In typical application scenarios, n_components should be larger than the number of features in the input representation in order to improve on plain linear classification. As a rule of thumb, the optimum of the score / run-time trade-off is usually reached around n_components = 10 * n_features, although this can depend on the specific dataset at hand. Since the original samples have 54 features, the explicit feature map of a degree-4 polynomial kernel would have roughly 8.5 million features (54^4, to be exact). Thanks to PolynomialCountSketch, most of the discriminative information of that feature space can be condensed into a much more compact representation. While the experiment is run only once here (n_runs = 1), in practice it should be repeated several times to compensate for the stochastic nature of PolynomialCountSketch.

from sklearn.kernel_approximation import PolynomialCountSketch

n_runs = 1
N_COMPONENTS = [250, 500, 1000, 2000]

for n_components in N_COMPONENTS:
    ps_lsvm_time = 0
    ps_lsvm_score = 0
    for _ in range(n_runs):
        pipeline = make_pipeline(
            PolynomialCountSketch(n_components=n_components, degree=4),
            LinearSVC(),
        )
        start = time.time()
        pipeline.fit(X_train, y_train)
        ps_lsvm_time += time.time() - start
        ps_lsvm_score += 100 * pipeline.score(X_test, y_test)

    ps_lsvm_time /= n_runs
    ps_lsvm_score /= n_runs

    results[f"LSVM + PS({n_components})"] = {
        "time": ps_lsvm_time,
        "score": ps_lsvm_score,
    }
    print(f"Linear SVM score on {n_components} PolynomialCountSketch features: {ps_lsvm_score:.2f}%")

# Linear SVM score on 250 PolynomialCountSketch features: 76.55%
# Linear SVM score on 500 PolynomialCountSketch features: 76.92%
# Linear SVM score on 1000 PolynomialCountSketch features: 77.79%
# Linear SVM score on 2000 PolynomialCountSketch features: 78.59%
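
To make the compression mentioned above concrete, the short sketch below, an illustrative addition that is not part of the original example, compares the size of the explicit degree-4 feature map with the size of the sketched representation:

# Dimensionality comparison (an illustrative addition, not part of the original example).
from sklearn.kernel_approximation import PolynomialCountSketch

n_features = X_train.shape[1]  # 54
print(f"Explicit degree-4 feature map: {n_features ** 4:,} dimensions")  # 8,503,056
ps = PolynomialCountSketch(n_components=2000, degree=4)
print(f"PolynomialCountSketch output: {ps.fit_transform(X_train).shape[1]:,} dimensions")  # 2,000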

Training a kernelized SVM model

Now train a kernelized SVM to see how well PolynomialCountSketch is approximating the performance of the kernel. This may take some time, of course, since the SVC class has relatively poor scalability. This is precisely why kernel approximators are so useful:

from sklearn.svm import SVC

ksvm = SVC(C=500.0, kernel="poly", degree=4, coef0=0, gamma=1.0)

start = time.time()
ksvm.fit(X_train, y_train)
ksvm_time = time.time() - start
ksvm_score = 100 * ksvm.score(X_test, y_test)

results["KSVM"] = {"time": ksvm_time, "score": ksvm_score}
print(f"Kernel-SVM score on raw features: {ksvm_score:.2f}%")
# Kernel-SVM score on raw features: 79.78%
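
As a side note, not part of the original example, the SVC above uses the homogeneous degree-4 polynomial kernel (gamma * <x, y> + coef0) ** degree with gamma=1.0 and coef0=0, which is exactly the kernel that PolynomialCountSketch approximates. On a small subset one can compare the exact kernel matrix with the inner products of the sketched features:

# Sanity check (an illustrative addition, not part of the original example):
# inner products of PolynomialCountSketch features approximate the exact
# polynomial kernel (gamma * <x, y> + coef0) ** degree.
import numpy as np
from sklearn.kernel_approximation import PolynomialCountSketch
from sklearn.metrics.pairwise import polynomial_kernel

X_small = X_train[:200]  # keep the exact kernel matrix small
K_exact = polynomial_kernel(X_small, degree=4, gamma=1.0, coef0=0)
Z = PolynomialCountSketch(
    degree=4, gamma=1.0, coef0=0, n_components=2000, random_state=0
).fit_transform(X_small)
K_approx = Z @ Z.T
print(f"Mean absolute approximation error: {np.abs(K_exact - K_approx).mean():.5f}")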

Comparing the results

Finally, plot the results of the different methods against their training times. As we can see, the kernelized SVM achieves higher accuracy, but its training time is much larger and, most importantly, will grow much faster if the number of training samples increases.

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(7, 7))
ax.scatter(
    [results["LSVM"]["time"]],
    [results["LSVM"]["score"]],
    label="Linear SVM",
    c="green",
    marker="^",
)
ax.scatter(
    [results["LSVM + PS(250)"]["time"]],
    [results["LSVM + PS(250)"]["score"]],
    label="Linear SVM + PolynomialCountSketch",
    c="blue",
)
for n_components in N_COMPONENTS:
    ax.scatter(
        [results[f"LSVM + PS({n_components})"]["time"]],
        [results[f"LSVM + PS({n_components})"]["score"]],
        c="blue",
    )
    ax.annotate(
        f"n_comp.={n_components}",
        (
            results[f"LSVM + PS({n_components})"]["time"],
            results[f"LSVM + PS({n_components})"]["score"],
        ),
        xytext=(-30, 10),
        textcoords="offset pixels",
    )

ax.scatter(
    [results["KSVM"]["time"]],
    [results["KSVM"]["score"]],
    label="Kernel SVM",
    c="red",
    marker="x",
)

ax.set_xlabel("Training time (s)")
ax.set_ylabel("Accuracy (%)")
ax.legend()
plt.show()