使用L1正则化的逻辑回归进行MNIST数字分类

机器学习领域,逻辑回归是一种广泛使用的分类算法。当数据集的特征数量较多时,为了防止模型过拟合,通常会采用正则化技术。L1正则化是其中一种,它通过在损失函数中添加一个L1范数项来实现,从而使得模型的权重向量更加稀疏,即许多权重会变为0,这有助于提高模型的可解释性。

在本例中,使用L1正则化的逻辑回归模型来处理MNIST数据集中的手写数字分类问题。MNIST数据集包含了大量的手写数字图片,每个图片被标记为0到9之间的一个数字。目标是训练一个模型,使其能够准确地识别出新的手写数字图片所代表的数字。

采用了SAGA算法来优化带有L1正则化项的逻辑回归模型。SAGA算法特别适合于样本数量远大于特征数量的情况,并且能够有效地处理非光滑的目标函数,这在L1正则化中是常见的。通过这种优化方法,得到的测试准确率超过了80%,同时保持了权重向量的稀疏性,这使得模型更加容易解释。

值得注意的是,虽然L1正则化模型的准确率低于L2正则化模型或非线性的多层感知机模型,但其稀疏性的特点使其在某些应用场景下更具优势。例如,在特征选择或者模型解释性要求较高的场合,L1正则化模型可能更加合适。

在实验中,首先加载了MNIST数据集的一个子集,并对其进行了预处理,包括数据的标准化和划分训练集与测试集。然后,使用SAGA算法训练了L1正则化逻辑回归模型,并计算了模型的稀疏性和测试得分。实验结果显示,L1正则化模型的稀疏性达到了74.57%,测试得分为0.8253。整个实验的运行时间为8.354秒。

此外,还绘制了模型的权重向量,以直观地展示每个类别的分类向量。这些分类向量可以被看作是模型学习到的数字特征,它们在空间上呈现出一定的稀疏性,这与L1正则化的目标相一致。

代码示例

import time import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import fetch_openml from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.utils import check_random_state # 开始计时 t0 = time.time() # 设置训练样本数量 train_samples = 5000 # 从OpenML加载MNIST数据集 X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False) random_state = check_random_state(0) permutation = random_state.permutation(X.shape[0]) X = X[permutation] y = y[permutation] X = X.reshape((X.shape[0], -1)) X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_samples, test_size=10000) # 数据标准化 scaler = StandardScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) # 初始化逻辑回归模型,使用L1正则化和SAGA算法 clf = LogisticRegression(C=50.0 / train_samples, penalty="l1", solver="saga", tol=0.1) clf.fit(X_train, y_train) # 计算模型的稀疏性和测试得分 sparsity = np.mean(clf.coef_ == 0) * 100 score = clf.score(X_test, y_test) # 打印稀疏性和测试得分 print("Sparsity with L1 penalty: %.2f%%" % sparsity) print("Test score with L1 penalty: %.4f" % score) # 绘制分类向量 coef = clf.coef_.copy() plt.figure(figsize=(10, 5)) scale = np.abs(coef).max() for i in range(10): l1_plot = plt.subplot(2, 5, i + 1) l1_plot.imshow(coef[i].reshape(28, 28), interpolation="nearest", cmap=plt.cm.RdBu, vmin=-scale, vmax=scale) l1_plot.set_xticks(()) l1_plot.set_yticks(()) l1_plot.set_xlabel("Class %i" % i) plt.suptitle("Classification vector for...") run_time = time.time() - t0 print("Example run in %.3f s" % run_time) plt.show()
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485