线性与二次判别分析比较

线性判别分析(LDA)和二次判别分析(QDA)是两种常用的分类方法。本文通过生成不同特征的合成数据集,比较了这两种方法在不同情况下的表现。首先定义了一个函数来生成合成数据,这些数据被分为两个类别,每个类别的分布由特定的协方差矩阵控制。生成了三个数据集,分别对应不同的协方差结构:第一个数据集的两个类别共享相同的球面对称协方差矩阵;第二个数据集的协方差矩阵固定但不是球面对称的;第三个数据集为每个类别分配了不同的非球面对称协方差矩阵。

接下来,使用matplotlib库来绘制LDA和QDA的决策边界以及每个类别的协方差椭球。协方差椭球展示了每个类别的两倍标准差。在LDA中,所有类别的标准差是相同的,而在QDA中,每个类别都有自己的标准差。定义了两个函数:一个用于绘制协方差椭球,另一个用于展示分类器的结果,包括决策边界、正确分类和错误分类的样本点、每个类别的均值以及估计的协方差。

在比较LDA和QDA时,发现在前两个数据集上,两种方法的表现是相似的。这是因为在这些情况下,数据生成过程为两个类别提供了相同的协方差矩阵,因此QDA估计的两个协方差矩阵几乎相等,与LDA估计的协方差矩阵相当。在第一个数据集中,用于生成数据集的协方差矩阵是球面对称的,这导致判别边界与两个均值之间的垂直平分线对齐。然而,在第二个数据集中,判别边界不再通过两个均值的中间。

最后,在第三个数据集中,观察到LDA和QDA之间的真正区别。QDA拟合了两个协方差矩阵,并提供了非线性的判别边界,而LDA由于假设两个类别共享一个单一的协方差矩阵,因此拟合不足。通过这些比较,可以更深入地理解LDA和QDA在不同数据特征下的适用性和性能差异。

代码实现

以下是生成数据和绘制结果的Python代码实现。代码首先导入了必要的库,然后定义了生成数据的函数,接着创建了三个不同的数据集。之后,定义了绘制协方差椭球和分类结果的函数。最后,比较了LDA和QDA在三个数据集上的表现,并通过matplotlib绘制了结果。

import numpy as np
import matplotlib as mpl
from matplotlib import colors
from sklearn.inspection import DecisionBoundaryDisplay

def make_data(n_samples, n_features, cov_class_1, cov_class_2, seed=0):
    rng = np.random.RandomState(seed)
    X = np.concatenate([
        rng.randn(n_samples, n_features) @ cov_class_1,
        rng.randn(n_samples, n_features) @ cov_class_2 + np.array([1, 1]),
    ])
    y = np.concatenate([np.zeros(n_samples), np.ones(n_samples)])
    return X, y

def plot_ellipse(mean, cov, color, ax):
    v, w = np.linalg.eigh(cov)
    u = w[0] / np.linalg.norm(w[0])
    angle = np.arctan(u[1] / u[0])
    angle = 180 * angle / np.pi  # convert to degrees
    ell = mpl.patches.Ellipse(mean, 2 * v[0]**0.5, 2 * v[1]**0.5, angle=180 + angle, facecolor=color, edgecolor="black", linewidth=2)
    ell.set_clip_box(ax.bbox)
    ell.set_alpha(0.4)
    ax.add_artist(ell)

def plot_result(estimator, X, y, ax):
    cmap = colors.ListedColormap(["tab:red", "tab:blue"])
    DecisionBoundaryDisplay.from_estimator(estimator, X, response_method="predict_proba", plot_method="pcolormesh", ax=ax, cmap="RdBu", alpha=0.3)
    DecisionBoundaryDisplay.from_estimator(estimator, X, response_method="predict_proba", plot_method="contour", ax=ax, alpha=1.0, levels=[0.5])
    y_pred = estimator.predict(X)
    X_right, y_right = X[y == y_pred], y[y == y_pred]
    X_wrong, y_wrong = X[y != y_pred], y[y != y_pred]
    ax.scatter(X_right[:, 0], X_right[:, 1], c=y_right, s=20, cmap=cmap, alpha=0.5)
    ax.scatter(X_wrong[:, 0], X_wrong[:, 1], c=y_wrong, s=30, cmap=cmap, alpha=0.9, marker="x")
    ax.scatter(estimator.means_[:, 0], estimator.means_[:, 1], c="yellow", s=200, marker="*", edgecolor="black")
    if isinstance(estimator, LinearDiscriminantAnalysis):
        covariance = [estimator.covariance_] * 2
    else:
        covariance = estimator.covariance_
    plot_ellipse(estimator.means_[0], covariance[0], "tab:red", ax)
    plot_ellipse(estimator.means_[1], covariance[1], "tab:blue", ax)
    ax.set_box_aspect(1)
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set(xticks=[], yticks=[])

import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis

fig, axs = plt.subplots(nrows=3, ncols=2, sharex="row", sharey="row", figsize=(8, 12))
lda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True)
qda = QuadraticDiscriminantAnalysis(store_covariance=True)

for ax_row, X, y in zip(axs, (X_isotropic_covariance, X_shared_covariance, X_different_covariance), (y_isotropic_covariance, y_shared_covariance, y_different_covariance)):
    lda.fit(X, y)
    plot_result(lda, X, y, ax_row[0])
    qda.fit(X, y)
    plot_result(qda, X, y, ax_row[1])

axs[0, 0].set_title("Linear Discriminant Analysis")
axs[0, 0].set_ylabel("Data with fixed and spherical covariance")
axs[1, 0].set_ylabel("Data with fixed covariance")
axs[0, 1].set_title("Quadratic Discriminant Analysis")
axs[2, 0].set_ylabel("Data with varying covariances")

fig.suptitle("Linear Discriminant Analysis vs Quadratic Discriminant Analysis", y=0.94, fontsize=15)
plt.show()
        
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485