Source code for phate.cluster

from sklearn import cluster, exceptions, metrics
import warnings
import numpy as np
from .phate import PHATE


def silhouette_score(phate_op, n_clusters, random_state=None, **kwargs):
    """Silhouette score of a k-means clustering of the PHATE potential

    The operator is first clustered with `kmeans`; the resulting labels are
    then scored against `phate_op.diff_potential`.

    Parameters
    ----------
    phate_op : phate.PHATE
        Fitted PHATE operator
    n_clusters : int
        Number of clusters.
    random_state : int or None, optional (default: None)
        Random seed for k-means
    kwargs : additional keyword arguments forwarded to `kmeans`

    Returns
    -------
    score : float
        Value returned by `sklearn.metrics.silhouette_score`
    """
    # n_clusters and random_state are named parameters here, so they can never
    # collide with a key already present in kwargs.
    kmeans_options = dict(kwargs, n_clusters=n_clusters, random_state=random_state)
    labels = kmeans(phate_op, **kmeans_options)
    potential = phate_op.diff_potential
    return metrics.silhouette_score(potential, labels)
def kmeans(
    phate_op, n_clusters="auto", max_clusters=10, random_state=None, k=None, **kwargs
):
    """KMeans on the PHATE potential

    Clustering on the PHATE operator as introduced in Moon et al.
    This is similar to spectral clustering.

    Parameters
    ----------
    phate_op : phate.PHATE
        Fitted PHATE operator
    n_clusters : int or 'auto', optional (default: 'auto')
        Number of clusters.
        If 'auto', uses the Silhouette score to determine the optimal
        number of clusters
    max_clusters : int, optional (default: 10)
        Maximum number of clusters to test if using the Silhouette score.
        Every value from 2 up to and including `max_clusters` is tested.
    random_state : int or None, optional (default: None)
        Random seed for k-means
    k : deprecated for `n_clusters`
    kwargs : additional arguments for `sklearn.cluster.KMeans`

    Returns
    -------
    clusters : np.ndarray
        Integer array of cluster assignments

    Raises
    ------
    TypeError
        If `phate_op` is not a `PHATE` instance.
    sklearn.exceptions.NotFittedError
        If `phate_op.graph` is None, i.e. the operator has not been fitted.
    ValueError
        If `n_clusters` is 'auto' and `max_clusters` is less than 2, since
        there is then no candidate number of clusters to score.
    """
    if k is not None:
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this module.
        warnings.warn(
            "k is deprecated. Please use n_clusters in future.",
            FutureWarning,
            stacklevel=2,
        )
        n_clusters = k
    if not isinstance(phate_op, PHATE):
        raise TypeError(
            "Expected phate_op to be of type PHATE. Got {}".format(phate_op)
        )
    if phate_op.graph is None:
        raise exceptions.NotFittedError(
            "This PHATE instance is not fitted yet. Call "
            "'fit' with appropriate arguments before "
            "using this method."
        )

    # The isinstance guard avoids an elementwise comparison (and an ambiguous
    # truth value) should a caller pass an array-like n_clusters.
    if isinstance(n_clusters, str) and n_clusters == "auto":
        if max_clusters < 2:
            raise ValueError(
                "max_clusters must be at least 2 to select n_clusters "
                "automatically. Got {}".format(max_clusters)
            )
        # Upper bound is inclusive so that max_clusters itself is tested,
        # as the parameter documentation promises.
        candidates = np.arange(2, max_clusters + 1)
        scores = [
            silhouette_score(
                phate_op, int(candidate), random_state=random_state, **kwargs
            )
            for candidate in candidates
        ]
        n_clusters = int(candidates[np.argmax(scores)])

    model = cluster.KMeans(
        n_clusters=n_clusters, random_state=random_state, **kwargs
    )
    return model.fit_predict(phate_op.diff_potential)