半监督学习在手写数字识别中的应用

机器学习领域,半监督学习是一种强大的技术,它能够在只有少量标签数据的情况下进行有效的学习。本文将通过一个具体的案例——手写数字识别,来展示半监督学习模型Label Spreading的性能。尽管使用的是完整的数据集,但模型训练时只使用了少量的标签数据。通过混淆矩阵和一系列分类指标,可以评估模型的性能,并且最后还会展示模型最不确定的10个预测结果。

数据生成

使用了手写数字数据集,该数据集包含了1797个样本点。模型将使用所有点进行训练,但只有30个点会被标记。随机选择了340个样本,其中只有40个样本会被赋予已知的标签。因此,存储了300个不应该知道的标签的样本的索引。

import numpy as np from sklearn import datasets digits = datasets.load_digits() rng = np.random.RandomState(2) indices = np.arange(len(digits.data)) rng.shuffle(indices) # 选择340个样本,其中40个有已知标签 X = digits.data[indices[:340]] y = digits.target[indices[:340]] images = digits.images[indices[:340]] n_total_samples = len(y) n_labeled_points = 40 indices = np.arange(n_total_samples) unlabeled_set = indices[n_labeled_points:]

接下来,将所有数据打乱,以便在训练过程中使用。将未标记的样本的标签设置为-1,表示它们是未标记的。

y_train = np.copy(y) y_train[unlabeled_set] = -1

半监督学习

使用LabelSpreading模型来预测未知的标签。模型使用gamma=0.25和max_iter=20的参数进行训练。训练完成后,使用模型来预测未标记样本的标签,并打印出模型的性能报告。

from sklearn.metrics import classification_report from sklearn.semi_supervised import LabelSpreading lp_model = LabelSpreading(gamma=0.25, max_iter=20) lp_model.fit(X, y_train) predicted_labels = lp_model.transduction_[unlabeled_set] true_labels = y[unlabeled_set] print("Label Spreading model: %d labeled & %d unlabeled points (%d total)" % ( n_labeled_points, n_total_samples - n_labeled_points, n_total_samples))

模型的性能报告如下:

print(classification_report(true_labels, predicted_labels))

从报告中可以看到,模型在各个类别上的精确度、召回率和F1分数都相当高,整体的准确率也达到了90%。

混淆矩阵

混淆矩阵是评估分类模型性能的一个重要工具,它展示了模型预测的准确性。使用ConfusionMatrixDisplay来生成混淆矩阵的可视化表示。

from sklearn.metrics import ConfusionMatrixDisplay ConfusionMatrixDisplay.from_predictions(true_labels, predicted_labels, labels=lp_model.classes_)

混淆矩阵的可视化结果将帮助更直观地理解模型在不同类别上的表现。

最不确定的预测

最后,将展示模型最不确定的10个预测结果。通过计算预测标签分布的熵来确定不确定性,并选择熵值最高的10个样本进行展示。

from scipy import stats pred_entropies = stats.distributions.entropy(lp_model.label_distributions_.T) uncertainty_index = np.argsort(pred_entropies)[-10:] import matplotlib.pyplot as plt f = plt.figure(figsize=(7, 5)) for index, image_index in enumerate(uncertainty_index): image = images[image_index] sub = f.add_subplot(2, 5, index + 1) sub.imshow(image, cmap=plt.cm.gray_r) plt.xticks([]) plt.yticks([]) sub.set_title("predict: %i\ntrue: %i" % ( lp_model.transduction_[image_index], y[image_index])) f.suptitle("Learning with small amount of labeled data") plt.show()

通过这些最不确定的预测,可以进一步分析模型在哪些类型的手写数字上存在困难,并探索可能的改进方法。

沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485