cluster_backend.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Union
  3. import hdbscan
  4. import numpy as np
  5. import scipy
  6. import sklearn
  7. import umap
  8. from sklearn.cluster._kmeans import k_means
  9. from torch import nn
  10. class SpectralCluster:
  11. r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix.
  12. This implementation is adapted from https://github.com/speechbrain/speechbrain.
  13. """
  14. def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
  15. self.min_num_spks = min_num_spks
  16. self.max_num_spks = max_num_spks
  17. self.pval = pval
  18. def __call__(self, X, oracle_num=None):
  19. # Similarity matrix computation
  20. sim_mat = self.get_sim_mat(X)
  21. # Refining similarity matrix with pval
  22. prunned_sim_mat = self.p_pruning(sim_mat)
  23. # Symmetrization
  24. sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
  25. # Laplacian calculation
  26. laplacian = self.get_laplacian(sym_prund_sim_mat)
  27. # Get Spectral Embeddings
  28. emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
  29. # Perform clustering
  30. labels = self.cluster_embs(emb, num_of_spk)
  31. return labels
  32. def get_sim_mat(self, X):
  33. # Cosine similarities
  34. M = sklearn.metrics.pairwise.cosine_similarity(X, X)
  35. return M
  36. def p_pruning(self, A):
  37. if A.shape[0] * self.pval < 6:
  38. pval = 6. / A.shape[0]
  39. else:
  40. pval = self.pval
  41. n_elems = int((1 - pval) * A.shape[0])
  42. # For each row in a affinity matrix
  43. for i in range(A.shape[0]):
  44. low_indexes = np.argsort(A[i, :])
  45. low_indexes = low_indexes[0:n_elems]
  46. # Replace smaller similarity values by 0s
  47. A[i, low_indexes] = 0
  48. return A
  49. def get_laplacian(self, M):
  50. M[np.diag_indices(M.shape[0])] = 0
  51. D = np.sum(np.abs(M), axis=1)
  52. D = np.diag(D)
  53. L = D - M
  54. return L
  55. def get_spec_embs(self, L, k_oracle=None):
  56. lambdas, eig_vecs = scipy.linalg.eigh(L)
  57. if k_oracle is not None:
  58. num_of_spk = k_oracle
  59. else:
  60. lambda_gap_list = self.getEigenGaps(
  61. lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
  62. num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
  63. emb = eig_vecs[:, :num_of_spk]
  64. return emb, num_of_spk
  65. def cluster_embs(self, emb, k):
  66. _, labels, _ = k_means(emb, k)
  67. return labels
  68. def getEigenGaps(self, eig_vals):
  69. eig_vals_gap_list = []
  70. for i in range(len(eig_vals) - 1):
  71. gap = float(eig_vals[i + 1]) - float(eig_vals[i])
  72. eig_vals_gap_list.append(gap)
  73. return eig_vals_gap_list
  74. class UmapHdbscan:
  75. r"""
  76. Reference:
  77. - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
  78. Emphasis On Topological Structure. ICASSP2022
  79. """
  80. def __init__(self,
  81. n_neighbors=20,
  82. n_components=60,
  83. min_samples=10,
  84. min_cluster_size=10,
  85. metric='cosine'):
  86. self.n_neighbors = n_neighbors
  87. self.n_components = n_components
  88. self.min_samples = min_samples
  89. self.min_cluster_size = min_cluster_size
  90. self.metric = metric
  91. def __call__(self, X):
  92. umap_X = umap.UMAP(
  93. n_neighbors=self.n_neighbors,
  94. min_dist=0.0,
  95. n_components=min(self.n_components, X.shape[0] - 2),
  96. metric=self.metric,
  97. ).fit_transform(X)
  98. labels = hdbscan.HDBSCAN(
  99. min_samples=self.min_samples,
  100. min_cluster_size=self.min_cluster_size,
  101. allow_single_cluster=True).fit_predict(umap_X)
  102. return labels
  103. class ClusterBackend(nn.Module):
  104. r"""Perfom clustering for input embeddings and output the labels.
  105. Args:
  106. model_dir: A model dir.
  107. model_config: The model config.
  108. """
  109. def __init__(self):
  110. super().__init__()
  111. self.model_config = {'merge_thr':0.78}
  112. # self.other_config = kwargs
  113. self.spectral_cluster = SpectralCluster()
  114. self.umap_hdbscan_cluster = UmapHdbscan()
  115. def forward(self, X, **params):
  116. # clustering and return the labels
  117. k = params['oracle_num'] if 'oracle_num' in params else None
  118. assert len(
  119. X.shape
  120. ) == 2, 'modelscope error: the shape of input should be [N, C]'
  121. if X.shape[0] < 20:
  122. return np.zeros(X.shape[0], dtype='int')
  123. if X.shape[0] < 2048 or k is not None:
  124. labels = self.spectral_cluster(X, k)
  125. else:
  126. labels = self.umap_hdbscan_cluster(X)
  127. if k is None and 'merge_thr' in self.model_config:
  128. labels = self.merge_by_cos(labels, X,
  129. self.model_config['merge_thr'])
  130. return labels
  131. def merge_by_cos(self, labels, embs, cos_thr):
  132. # merge the similar speakers by cosine similarity
  133. assert cos_thr > 0 and cos_thr <= 1
  134. while True:
  135. spk_num = labels.max() + 1
  136. if spk_num == 1:
  137. break
  138. spk_center = []
  139. for i in range(spk_num):
  140. spk_emb = embs[labels == i].mean(0)
  141. spk_center.append(spk_emb)
  142. assert len(spk_center) > 0
  143. spk_center = np.stack(spk_center, axis=0)
  144. norm_spk_center = spk_center / np.linalg.norm(
  145. spk_center, axis=1, keepdims=True)
  146. affinity = np.matmul(norm_spk_center, norm_spk_center.T)
  147. affinity = np.triu(affinity, 1)
  148. spks = np.unravel_index(np.argmax(affinity), affinity.shape)
  149. if affinity[spks] < cos_thr:
  150. break
  151. for i in range(len(labels)):
  152. if labels[i] == spks[1]:
  153. labels[i] = spks[0]
  154. elif labels[i] > spks[1]:
  155. labels[i] -= 1
  156. return labels