高斯混合模型初始化方法比较

高斯混合模型(Gaussian Mixture Model, GMM)是一种概率模型,用于表示具有多个子群体的数据集,其中每个子群体的数据分布可以用高斯分布来描述。在实际应用中,选择合适的初始化方法对于模型的性能至关重要。本文将介绍四种常见的初始化方法:kmeans、random、random_from_data和k-means++,并比较它们在样本数据上的收敛速度和初始化时间。

样本数据生成

首先,生成了一些具有四个明显聚类特征的样本数据。这些数据将用于展示不同初始化方法的效果。具体来说,使用make_blobs函数生成了4000个样本点,每个聚类中心的方差为0.60。

初始化方法

在本例中,比较了四种初始化方法

  • kmeans:这是默认的初始化方法,使用k-means算法来初始化聚类中心。
  • random:随机选择数据点作为聚类中心。
  • random_from_data:从数据中随机选择聚类中心,但确保每个聚类中心都是唯一的。
  • k-means++:一种改进的k-means初始化方法,通过优化初始聚类中心的选择来加速收敛。

初始化效果比较

通过比较不同初始化方法的收敛速度和初始化时间,可以得出以下结论:

在本例中,使用random_from_datarandom方法初始化的模型需要更多的迭代次数才能收敛。而k-means++方法在初始化时间和收敛速度方面都表现良好,既快速初始化,又需要较少的迭代次数就能收敛。

代码实现

以下是使用Python和scikit-learn库实现上述比较的代码示例:

from sklearn.datasets._samples_generator import make_blobs from sklearn.mixture import GaussianMixture from sklearn.utils.extmath import row_norms import numpy as np import matplotlib.pyplot as plt import timeit # 生成样本数据 X, y_true = make_blobs(n_samples=4000, centers=4, cluster_std=0.60, random_state=0) X = X[:, ::-1] n_samples, n_components = 4000, 4 x_squared_norms = row_norms(X, squared=True) # 初始化方法 methods = ["kmeans", "random_from_data", "k-means++", "random"] colors = ["navy", "turquoise", "cornflowerblue", "darkorange"] times_init = {} relative_times = {} plt.figure(figsize=(4*len(methods)//2, 6)) plt.subplots_adjust(bottom=0.1, top=0.9, hspace=0.15, wspace=0.05, left=0.05, right=0.95) for n, method in enumerate(methods): r = np.random.RandomState(seed=1234) plt.subplot(2, len(methods)//2, n+1) start = timeit.default_timer() ini = get_initial_means(X, method, r) end = timeit.default_timer() init_time = end - start gmm = GaussianMixture(n_components=4, means_init=ini, tol=1e-9, max_iter=2000, random_state=r) gmm.fit(X) times_init[method] = init_time for i, color in enumerate(colors): data = X[gmm.predict(X) == i] plt.scatter(data[:,0], data[:,1], color=color, marker="x") plt.scatter(ini[:,0], ini[:,1], s=75, marker="D", c="orange", lw=1.5, edgecolors="black") relative_times[method] = times_init[method] / times_init[methods[0]] plt.xticks(()) plt.yticks(()) plt.title(method, loc="left", fontsize=12) plt.title("Iter %i | Init Time %.2f x" % (gmm.n_iter_, relative_times[method]), loc="right", fontsize=10) plt.suptitle("GMM iterations and relative time taken to initialize") plt.show()
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485