Comparison of Data Clustering Algorithms

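In data science, clustering is an unsupervised learning method. It partitions the samples of a dataset into clusters so that samples in the same cluster are highly similar, while samples in different clusters are dissimilar. This article introduces several common clustering algorithms and shows how they behave on two-dimensional datasets. For every dataset except the last one, the algorithm parameters have been tuned to produce good clustering results. Note that some algorithms are more sensitive to their parameter values than others.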
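The goal stated above (high similarity within clusters, low similarity between them) can be quantified with the silhouette coefficient. The following is a minimal sketch, not part of the original comparison. It assumes three hand-placed Gaussian blobs and uses k-means as the clusterer; both choices are illustrative.

```python
# Minimal sketch: score "tight within clusters, separated between clusters"
# with the silhouette coefficient. Blob centers, spread and k values are
# illustrative assumptions, not taken from the article's datasets.
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Three well-separated blobs (assumed toy data)
X, _ = make_blobs(
    n_samples=300,
    centers=[[-5, 0], [0, 5], [5, 0]],
    cluster_std=0.6,
    random_state=0,
)

scores = {}
for k in (2, 3, 4, 5):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)
    # Silhouette lies in [-1, 1]; higher means points sit closer to their own
    # cluster than to the nearest other cluster
    scores[k] = silhouette_score(X, labels)

for k, s in scores.items():
    print(f"k={k}: silhouette={s:.3f}")

best_k = max(scores, key=scores.get)
print("best k:", best_k)
```

On this convex toy data, the score peaks at the true number of blobs. On non-convex shapes such as circles or moons, silhouette tends to favor compact partitions, so it is only a rough guide.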

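The last dataset is an example of a "null" situation for clustering: the data are homogeneous, and there is no good way to cluster them. The structureless dataset is not tuned. It uses the default parameters, which largely coincide with those of the dataset in the row above. This deliberately creates a mismatch between the parameter values and the structure of the data. These examples give some intuition about the algorithms, but that intuition might not carry over to very high-dimensional data.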
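The parameter mismatch on the structureless dataset can be illustrated with a small sketch: an algorithm configured for three clusters still returns three clusters on uniform noise. The sketch assumes k-means as the algorithm and the silhouette score as a stand-in quality measure. The blob centers are illustrative, not the article's blob dataset.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

seed = 30
rng = np.random.RandomState(seed)
uniform = rng.rand(500, 2)  # homogeneous data, like the "no structure" row

# Assumed structured counterpart: three separated blobs
blobs, _ = make_blobs(
    n_samples=500,
    centers=[[-5, 0], [0, 5], [5, 0]],
    cluster_std=0.6,
    random_state=seed,
)


def fit_kmeans(X, k=3):
    # The same parameters are reused on both datasets, mirroring an untuned row
    return KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)


labels_uniform = fit_kmeans(uniform)
labels_blobs = fit_kmeans(blobs)

# k-means always produces exactly k clusters, even when there is no structure
n_found = len(np.unique(labels_uniform))
s_uniform = silhouette_score(uniform, labels_uniform)
s_blobs = silhouette_score(blobs, labels_blobs)
print(n_found, round(s_uniform, 3), round(s_blobs, 3))
```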

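To illustrate how the clustering algorithms behave, a set of datasets is generated. The datasets are large enough to show how the algorithms scale, but small enough to avoid long running times. They include noisy concentric circles, noisy half-moons, Gaussian blobs with different variances, and anisotropically distributed data.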
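The anisotropic and varied-variance datasets mentioned above can both be built from make_blobs. This sketch follows the construction used in scikit-learn's own clustering comparison example. The seed (170), the 2×2 linear transformation and the per-cluster standard deviations (1.0, 2.5, 0.5) should be treated as assumptions here.

```python
import numpy as np
from sklearn import datasets

n_samples = 500
random_state = 170  # assumed seed

# Anisotropic data: isotropic blobs stretched and sheared by a linear map
X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)

# Blobs whose clusters have different standard deviations
X_varied, y_varied = datasets.make_blobs(
    n_samples=n_samples, cluster_std=[1.0, 2.5, 0.5], random_state=random_state
)

# Within one cluster, the two features become strongly correlated after the
# transformation, i.e. the cluster is elongated rather than round
mask = y == 0
corr = np.corrcoef(X_aniso[mask, 0], X_aniso[mask, 1])[0, 1]

# Empirical spread of each varied-variance cluster (mean of per-feature std)
spreads = [X_varied[y_varied == k].std(axis=0).mean() for k in range(3)]
print(f"within-cluster correlation after transform: {corr:.2f}")
print("cluster spreads:", [round(s, 2) for s in spreads])
```

Both shapes break the "round, equal-size cluster" assumption built into k-means-style methods.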

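When setting the clustering parameters, specific values are chosen for each dataset-algorithm pair to ensure good clustering results. For example, for the noisy circles dataset, the following are adjusted: the damping factor and the preference of Affinity Propagation, the quantile used to estimate the MeanShift bandwidth, and the number of clusters. The other datasets receive similar adjustments. Any parameter that is not overridden for a dataset keeps its default value.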
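The per-dataset tuning described above amounts to starting from shared defaults and overriding a few keys for each dataset. Below is a minimal sketch of that merge. It uses a shortened subset of the article's default values, and resolve_params is a hypothetical helper name.

```python
# Shared defaults (a shortened subset of the article's default values)
default_base = {
    "quantile": 0.3,
    "eps": 0.3,
    "damping": 0.9,
    "preference": -200,
    "n_clusters": 3,
}

# Per-dataset overrides; an empty dict means "use the defaults unchanged"
overrides = {
    "noisy_circles": {"damping": 0.77, "preference": -240, "quantile": 0.2, "n_clusters": 2},
    "no_structure": {},
}


def resolve_params(defaults, dataset_overrides):
    """Hypothetical helper: copy the defaults, then apply dataset-specific values."""
    params = defaults.copy()  # copy so the shared defaults are never mutated
    params.update(dataset_overrides)
    return params


circles_params = resolve_params(default_base, overrides["noisy_circles"])
uniform_params = resolve_params(default_base, overrides["no_structure"])
print(circles_params)
print(uniform_params)
```

Copying before updating matters: calling update on default_base directly would leak one dataset's tuning into every later dataset.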

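The comparison covers the following algorithms:

- MiniBatchKMeans
- AffinityPropagation
- MeanShift
- SpectralClustering
- Ward (AgglomerativeClustering with Ward linkage)
- AgglomerativeClustering (average linkage)
- DBSCAN
- HDBSCAN
- OPTICS
- BIRCH
- GaussianMixture

Each behaves differently. For example, MeanShift uses a kernel bandwidth estimated from the data with estimate_bandwidth and the quantile parameter. DBSCAN uses eps, the maximum distance between two samples for one to count as a neighbor of the other. eps does not set cluster size directly; it determines which points are connected into clusters and which are labeled as noise.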
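The two parameters mentioned above can be inspected directly. This sketch assumes the standardized noisy-moons data and a few illustrative quantile and eps values. estimate_bandwidth averages each sample's distance to its k-th nearest neighbor, with k = int(n_samples × quantile), so a larger quantile gives a larger MeanShift bandwidth. For DBSCAN, eps is a neighborhood radius, so a very large eps merges every point into one cluster.

```python
import numpy as np
from sklearn import cluster, datasets
from sklearn.preprocessing import StandardScaler

X, _ = datasets.make_moons(n_samples=500, noise=0.05, random_state=30)
X = StandardScaler().fit_transform(X)

# MeanShift: bandwidth = mean distance to the k-th nearest neighbor,
# where k = int(n_samples * quantile)
bandwidths = {q: cluster.estimate_bandwidth(X, quantile=q) for q in (0.1, 0.2, 0.3)}
ms = cluster.MeanShift(bandwidth=bandwidths[0.2], bin_seeding=True).fit(X)
print("bandwidths:", {q: round(b, 3) for q, b in bandwidths.items()})
print("MeanShift clusters:", len(ms.cluster_centers_))

# DBSCAN: eps is the neighborhood radius; it shapes clusters only indirectly
results = {}
for eps in (0.05, 0.3, 3.0):
    labels = cluster.DBSCAN(eps=eps).fit(X).labels_
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)  # -1 marks noise
    n_noise = int(np.sum(labels == -1))
    results[eps] = (n_clusters, n_noise)
    print(f"eps={eps}: {n_clusters} clusters, {n_noise} noise points")
```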

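After clustering, the results are visualized with matplotlib. Each algorithm gets one subplot per dataset, with its fit time shown in the lower-right corner. Points labeled as noise (label -1, produced by DBSCAN, HDBSCAN and OPTICS) are drawn in black. Comparing the panels shows each algorithm's strengths and weaknesses and how it behaves on different kinds of data.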
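The coloring and timing labels rely on two small tricks, isolated here as a sketch without any plotting. First, black is appended after the cluster colors, so that indexing with label -1 (noise) picks the last entry. Second, the elapsed time is formatted with its leading zero stripped. The function names label_colors and format_elapsed and the sample labels are hypothetical.

```python
from itertools import cycle, islice

import numpy as np

PALETTE = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628",
           "#984ea3", "#999999", "#e41a1c", "#dede00"]


def label_colors(y_pred):
    """One color per cluster label; label -1 (noise) indexes the appended black."""
    n_labels = int(max(y_pred) + 1)
    # cycle() repeats the palette if there are more clusters than colors
    colors = np.array(list(islice(cycle(PALETTE), n_labels)))
    colors = np.append(colors, ["#000000"])  # last slot, reached by index -1
    return colors[y_pred]


def format_elapsed(seconds):
    """Time label as in the plots: two decimals, leading zero stripped."""
    return ("%.2fs" % seconds).lstrip("0")


y_pred = np.array([0, 1, 2, -1, 1, 0, -1])  # hypothetical labels with noise
point_colors = label_colors(y_pred)
print(list(point_colors))
print(format_elapsed(0.05), format_elapsed(1.234))
```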

Code example

```python
import time
import warnings
from itertools import cycle, islice

import matplotlib.pyplot as plt
import numpy as np
from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler

# Generate the datasets: large enough to show scalability, small enough
# to keep running times short
n_samples = 500
seed = 30
noisy_circles = datasets.make_circles(
    n_samples=n_samples, factor=0.5, noise=0.05, random_state=seed
)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05, random_state=seed)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=seed)
rng = np.random.RandomState(seed)
no_structure = rng.rand(n_samples, 2), None

# Anisotropically distributed data
random_state = 170
X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
aniso = (np.dot(X, transformation), y)

# Blobs with varied variances
varied = datasets.make_blobs(
    n_samples=n_samples, cluster_std=[1.0, 2.5, 0.5], random_state=random_state
)

# Set up the figure and the clustering parameters
plt.figure(figsize=(9 * 2 + 3, 13))
plt.subplots_adjust(
    left=0.02, right=0.98, bottom=0.001, top=0.95, wspace=0.05, hspace=0.01
)
plot_num = 1

default_base = {
    "quantile": 0.3,
    "eps": 0.3,
    "damping": 0.9,
    "preference": -200,
    "n_neighbors": 3,
    "n_clusters": 3,
    "min_samples": 7,
    "xi": 0.05,
    "min_cluster_size": 0.1,
    "allow_single_cluster": True,
    "hdbscan_min_cluster_size": 15,
    "hdbscan_min_samples": 3,
    "random_state": 42,
}

# (dataset, dataset-specific overrides); named dataset_list so that it does
# not shadow the sklearn.datasets module
dataset_list = [
    (noisy_circles, {"damping": 0.77, "preference": -240, "quantile": 0.2,
                     "n_clusters": 2, "min_samples": 7, "xi": 0.08}),
    (noisy_moons, {"damping": 0.75, "preference": -220, "n_clusters": 2,
                   "min_samples": 7, "xi": 0.1}),
    (varied, {"eps": 0.18, "n_neighbors": 2, "min_samples": 7, "xi": 0.01,
              "min_cluster_size": 0.2}),
    (aniso, {"eps": 0.15, "n_neighbors": 2, "min_samples": 7, "xi": 0.1,
             "min_cluster_size": 0.2}),
    (blobs, {"min_samples": 7, "xi": 0.1, "min_cluster_size": 0.2}),
    (no_structure, {}),
]

for i_dataset, (dataset, algo_params) in enumerate(dataset_list):
    # Update the defaults with the dataset-specific values
    params = default_base.copy()
    params.update(algo_params)

    X, y = dataset
    # Normalize the dataset for easier parameter selection
    X = StandardScaler().fit_transform(X)

    # Estimate the bandwidth for MeanShift
    bandwidth = cluster.estimate_bandwidth(X, quantile=params["quantile"])

    # Connectivity matrix for structured Ward, made symmetric
    connectivity = kneighbors_graph(
        X, n_neighbors=params["n_neighbors"], include_self=False
    )
    connectivity = 0.5 * (connectivity + connectivity.T)

    # Create the clustering objects
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(
        n_clusters=params["n_clusters"], random_state=params["random_state"]
    )
    ward = cluster.AgglomerativeClustering(
        n_clusters=params["n_clusters"], linkage="ward", connectivity=connectivity
    )
    spectral = cluster.SpectralClustering(
        n_clusters=params["n_clusters"], eigen_solver="arpack",
        affinity="nearest_neighbors", random_state=params["random_state"],
    )
    dbscan = cluster.DBSCAN(eps=params["eps"])
    hdbscan = cluster.HDBSCAN(
        min_samples=params["hdbscan_min_samples"],
        min_cluster_size=params["hdbscan_min_cluster_size"],
        allow_single_cluster=params["allow_single_cluster"],
    )
    optics = cluster.OPTICS(
        min_samples=params["min_samples"], xi=params["xi"],
        min_cluster_size=params["min_cluster_size"],
    )
    affinity_propagation = cluster.AffinityPropagation(
        damping=params["damping"], preference=params["preference"],
        random_state=params["random_state"],
    )
    average_linkage = cluster.AgglomerativeClustering(
        linkage="average", metric="cityblock",
        n_clusters=params["n_clusters"], connectivity=connectivity,
    )
    birch = cluster.Birch(n_clusters=params["n_clusters"])
    gmm = mixture.GaussianMixture(
        n_components=params["n_clusters"], covariance_type="full",
        random_state=params["random_state"],
    )

    clustering_algorithms = (
        ("MiniBatch KMeans", two_means),
        ("Affinity Propagation", affinity_propagation),
        ("MeanShift", ms),
        ("Spectral Clustering", spectral),
        ("Ward", ward),
        ("Agglomerative Clustering", average_linkage),
        ("DBSCAN", dbscan),
        ("HDBSCAN", hdbscan),
        ("OPTICS", optics),
        ("BIRCH", birch),
        ("Gaussian Mixture", gmm),
    )

    for name, algorithm in clustering_algorithms:
        t0 = time.time()
        # Silence the warnings caused by a disconnected kneighbors graph
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the connectivity"
                " matrix is [0-9]{1,2} > 1. Completing it to avoid stopping"
                " the tree early.",
                category=UserWarning,
            )
            warnings.filterwarnings(
                "ignore",
                message="Graph is not fully connected, spectral embedding"
                " may not work as expected.",
                category=UserWarning,
            )
            algorithm.fit(X)
        t1 = time.time()

        # Algorithms without labels_ (e.g. GaussianMixture) predict instead
        if hasattr(algorithm, "labels_"):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        plt.subplot(len(dataset_list), len(clustering_algorithms), plot_num)
        if i_dataset == 0:
            plt.title(name, size=18)

        colors = np.array(list(islice(
            cycle(["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628",
                   "#984ea3", "#999999", "#e41a1c", "#dede00"]),
            int(max(y_pred) + 1),
        )))
        # Black for noise points (label -1)
        colors = np.append(colors, ["#000000"])
        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])

        plt.xlim(-2.5, 2.5)
        plt.ylim(-2.5, 2.5)
        plt.xticks(())
        plt.yticks(())
        # Fit time in the lower-right corner of the subplot
        plt.text(0.99, 0.01, ("%.2fs" % (t1 - t0)).lstrip("0"),
                 transform=plt.gca().transAxes, size=15,
                 horizontalalignment="right")
        plot_num += 1

plt.show()
```