In data science, clustering is an unsupervised learning method for partitioning the samples of a dataset into clusters, such that samples within the same cluster are highly similar while samples in different clusters are not. This post introduces several common clustering algorithms and shows how they behave on two-dimensional datasets. Except for the last dataset, the algorithm parameters for each dataset have been tuned to produce good clustering results; note that some algorithms are more sensitive to parameter values than others.
The last dataset is an example of a "null" situation for clustering: the data is uniform, and there is no good way to cluster it. For this example, the structureless dataset reuses the same parameters as the dataset in the row above it, which represents a mismatch between parameter values and data structure. While these examples give some intuition about the algorithms, that intuition may not carry over to very high-dimensional data.
To make the behavior of the algorithms easier to observe, a series of moderately sized datasets is generated, large enough to show how the algorithms scale without making run times excessive. Several kinds of datasets are used: noisy concentric circles, noisy half-moons, Gaussian blobs with varied variances, and anisotropically distributed data.
Clustering parameters are set per dataset-algorithm pair to ensure good results. For the noisy circles dataset, for example, the damping factor, preference, and quantile parameters are adjusted; the other datasets receive similar per-dataset tuning.
The comparison uses a range of clustering algorithms: MiniBatchKMeans, AffinityPropagation, MeanShift, SpectralClustering, Ward, AgglomerativeClustering, DBSCAN, HDBSCAN, OPTICS, BIRCH, and GaussianMixture. Each behaves differently during clustering: MeanShift, for example, clusters by estimating a kernel bandwidth from the data, while DBSCAN grows clusters from dense neighborhoods whose radius is set by the eps parameter.
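As a minimal illustration of these two mechanisms (the toy array and parameter values here are made up for this sketch, not taken from the comparison script below):

import numpy as np
from sklearn.cluster import DBSCAN, estimate_bandwidth

X_demo = np.array([[0.0, 0.0], [0.1, 0.1], [0.2, 0.0], [5.0, 5.0], [5.1, 5.2]])
# MeanShift derives its kernel width from nearest-neighbor distances at a quantile
print(estimate_bandwidth(X_demo, quantile=0.5))
# DBSCAN links points closer than eps into clusters; expected labels: [0 0 0 1 1]
print(DBSCAN(eps=0.5, min_samples=2).fit(X_demo).labels_)

On real data, both the bandwidth quantile and eps typically need tuning per dataset, which is exactly what the per-dataset parameter dictionaries in the script below do.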
Once clustering is done, matplotlib is used to visualize the results. Each algorithm's output is drawn in its own subplot, annotated with the time the fit took. Comparing the plots across algorithms makes it easier to see each method's strengths and weaknesses and how it behaves on the different datasets.
import time
import warnings
from itertools import cycle, islice
import matplotlib.pyplot as plt
import numpy as np
from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler
# Generate datasets
n_samples = 500
seed = 30
noisy_circles = datasets.make_circles(n_samples=n_samples, factor=0.5, noise=0.05, random_state=seed)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05, random_state=seed)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=seed)
rng = np.random.RandomState(seed)
no_structure = rng.rand(n_samples, 2), None
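# The text also mentions two datasets not generated above: Gaussian blobs with
# varied variances and anisotropically distributed blobs. The sketch below
# follows the upstream scikit-learn comparison example; random_state=170 and
# the transformation matrix are assumptions taken from that example.
random_state = 170
X_aniso, y_aniso = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
X_aniso = np.dot(X_aniso, [[0.6, -0.6], [-0.4, 0.8]])  # linear skew -> anisotropic clusters
aniso = (X_aniso, y_aniso)
varied = datasets.make_blobs(
    n_samples=n_samples, cluster_std=[1.0, 2.5, 0.5], random_state=random_state
)  # blobs with varied variances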
# Set up cluster parameters
plt.figure(figsize=(9*2+3, 13))
plt.subplots_adjust(left=0.02, right=0.98, bottom=0.001, top=0.95, wspace=0.05, hspace=0.01)
plot_num = 1
default_base = {
"quantile": 0.3,
"eps": 0.3,
"damping": 0.9,
"preference": -200,
"n_neighbors": 3,
"n_clusters": 3,
"min_samples": 7,
"xi": 0.05,
"min_cluster_size": 0.1,
"allow_single_cluster": True,
"hdbscan_min_cluster_size": 15,
"hdbscan_min_samples": 3,
"random_state": 42,
}
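# Per-dataset overrides follow; any parameter not listed falls back to default_base.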
datasets = [
    (noisy_circles, {"damping": 0.77, "preference": -240, "quantile": 0.2,
                     "n_clusters": 2, "min_samples": 7, "xi": 0.08}),
    (noisy_moons, {"damping": 0.75, "preference": -220, "n_clusters": 2,
                   "min_samples": 7, "xi": 0.1}),
    (blobs, {"min_samples": 7, "xi": 0.1, "min_cluster_size": 0.2}),
    (no_structure, {}),
]
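# If the varied/aniso datasets sketched above are used, they need per-dataset
# parameters too; the values below are assumptions modeled on the upstream
# scikit-learn example. Slice assignment keeps no_structure in the last row,
# matching the description in the text.
datasets[-1:-1] = [
    (varied, {"eps": 0.18, "n_neighbors": 2, "min_samples": 7, "xi": 0.01,
              "min_cluster_size": 0.2}),
    (aniso, {"eps": 0.15, "n_neighbors": 2, "min_samples": 7, "xi": 0.1,
             "min_cluster_size": 0.2}),
]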
for i_dataset, (dataset, algo_params) in enumerate(datasets):
    # Update parameters with dataset-specific values
    params = default_base.copy()
    params.update(algo_params)
    X, y = dataset
    # Normalize the dataset for easier parameter selection
    X = StandardScaler().fit_transform(X)
    # Estimate bandwidth for MeanShift
    bandwidth = cluster.estimate_bandwidth(X, quantile=params["quantile"])
    # Connectivity matrix for structured Ward and average-linkage clustering
    connectivity = kneighbors_graph(X, n_neighbors=params["n_neighbors"], include_self=False)
    # Make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)
    # Create cluster objects
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(
        n_clusters=params["n_clusters"], random_state=params["random_state"]
    )
    ward = cluster.AgglomerativeClustering(
        n_clusters=params["n_clusters"], linkage="ward", connectivity=connectivity
    )
    spectral = cluster.SpectralClustering(
        n_clusters=params["n_clusters"], eigen_solver="arpack",
        affinity="nearest_neighbors", random_state=params["random_state"],
    )
    dbscan = cluster.DBSCAN(eps=params["eps"])
    hdbscan = cluster.HDBSCAN(
        min_samples=params["hdbscan_min_samples"],
        min_cluster_size=params["hdbscan_min_cluster_size"],
        allow_single_cluster=params["allow_single_cluster"],
    )
    optics = cluster.OPTICS(
        min_samples=params["min_samples"], xi=params["xi"],
        min_cluster_size=params["min_cluster_size"],
    )
    affinity_propagation = cluster.AffinityPropagation(
        damping=params["damping"], preference=params["preference"],
        random_state=params["random_state"],
    )
    average_linkage = cluster.AgglomerativeClustering(
        linkage="average", metric="cityblock",
        n_clusters=params["n_clusters"], connectivity=connectivity,
    )
    birch = cluster.Birch(n_clusters=params["n_clusters"])
    gmm = mixture.GaussianMixture(
        n_components=params["n_clusters"], covariance_type="full",
        random_state=params["random_state"],
    )
clustering_algorithms = (
("MiniBatch KMeans", two_means),
("Affinity Propagation", affinity_propagation),
("MeanShift", ms),
("Spectral Clustering", spectral),
("Ward", ward),
("Agglomerative Clustering", average_linkage),
("DBSCAN", dbscan),
("HDBSCAN", hdbscan),
("OPTICS", optics),
("BIRCH", birch),
("Gaussian Mixture", gmm),
)
for name, algorithm in clustering_algorithms:
t0 = time.time()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="the number of connected components of the connectivity matrix is [0-9]{1,2} > 1. Completing it to avoid stopping the tree early.", category=UserWarning,)
warnings.filterwarnings("ignore", message="Graph is not fully connected, spectral embedding may not work as expected.", category=UserWarning,)
algorithm.fit(X)
t1 = time.time()
        # Clusterers expose hard assignments via labels_; mixture models need predict
        if hasattr(algorithm, "labels_"):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)
plt.subplot(len(datasets), len(clustering_algorithms), plot_num)
if i_dataset == 0:
plt.title(name, size=18)
        colors = np.array(
            list(
                islice(
                    cycle([
                        "#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628",
                        "#984ea3", "#999999", "#e41a1c", "#dede00",
                    ]),
                    int(max(y_pred) + 1),
                )
            )
        )
        # Add black for outliers (labelled -1 by DBSCAN, HDBSCAN and OPTICS)
        colors = np.append(colors, ["#000000"])
plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])
plt.xlim(-2.5, 2.5)
plt.ylim(-2.5, 2.5)
plt.xticks(())
plt.yticks(())
        plt.text(
            0.99,
            0.01,
            ("%.2fs" % (t1 - t0)).lstrip("0"),
            transform=plt.gca().transAxes,
            size=15,
            horizontalalignment="right",
        )
plot_num += 1
plt.show()