# cluster_backend.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
  6. import scipy
  7. import torch
  8. import sklearn
  9. import hdbscan
  10. import numpy as np
  11. from sklearn.cluster._kmeans import k_means
  12. class SpectralCluster:
  13. r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix.
  14. This implementation is adapted from https://github.com/speechbrain/speechbrain.
  15. """
  16. def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
  17. self.min_num_spks = min_num_spks
  18. self.max_num_spks = max_num_spks
  19. self.pval = pval
  20. def __call__(self, X, oracle_num=None):
  21. # Similarity matrix computation
  22. sim_mat = self.get_sim_mat(X)
  23. # Refining similarity matrix with pval
  24. prunned_sim_mat = self.p_pruning(sim_mat)
  25. # Symmetrization
  26. sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
  27. # Laplacian calculation
  28. laplacian = self.get_laplacian(sym_prund_sim_mat)
  29. # Get Spectral Embeddings
  30. emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
  31. # Perform clustering
  32. labels = self.cluster_embs(emb, num_of_spk)
  33. return labels
  34. def get_sim_mat(self, X):
  35. # Cosine similarities
  36. M = sklearn.metrics.pairwise.cosine_similarity(X, X)
  37. return M
  38. def p_pruning(self, A):
  39. if A.shape[0] * self.pval < 6:
  40. pval = 6. / A.shape[0]
  41. else:
  42. pval = self.pval
  43. n_elems = int((1 - pval) * A.shape[0])
  44. # For each row in a affinity matrix
  45. for i in range(A.shape[0]):
  46. low_indexes = np.argsort(A[i, :])
  47. low_indexes = low_indexes[0:n_elems]
  48. # Replace smaller similarity values by 0s
  49. A[i, low_indexes] = 0
  50. return A
  51. def get_laplacian(self, M):
  52. M[np.diag_indices(M.shape[0])] = 0
  53. D = np.sum(np.abs(M), axis=1)
  54. D = np.diag(D)
  55. L = D - M
  56. return L
  57. def get_spec_embs(self, L, k_oracle=None):
  58. lambdas, eig_vecs = scipy.linalg.eigh(L)
  59. if k_oracle is not None:
  60. num_of_spk = k_oracle
  61. else:
  62. lambda_gap_list = self.getEigenGaps(
  63. lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
  64. num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
  65. emb = eig_vecs[:, :num_of_spk]
  66. return emb, num_of_spk
  67. def cluster_embs(self, emb, k):
  68. _, labels, _ = k_means(emb, k)
  69. return labels
  70. def getEigenGaps(self, eig_vals):
  71. eig_vals_gap_list = []
  72. for i in range(len(eig_vals) - 1):
  73. gap = float(eig_vals[i + 1]) - float(eig_vals[i])
  74. eig_vals_gap_list.append(gap)
  75. return eig_vals_gap_list
  76. class UmapHdbscan:
  77. r"""
  78. Reference:
  79. - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
  80. Emphasis On Topological Structure. ICASSP2022
  81. """
  82. def __init__(self,
  83. n_neighbors=20,
  84. n_components=60,
  85. min_samples=10,
  86. min_cluster_size=10,
  87. metric='cosine'):
  88. self.n_neighbors = n_neighbors
  89. self.n_components = n_components
  90. self.min_samples = min_samples
  91. self.min_cluster_size = min_cluster_size
  92. self.metric = metric
  93. def __call__(self, X):
  94. import umap.umap_ as umap
  95. umap_X = umap.UMAP(
  96. n_neighbors=self.n_neighbors,
  97. min_dist=0.0,
  98. n_components=min(self.n_components, X.shape[0] - 2),
  99. metric=self.metric,
  100. ).fit_transform(X)
  101. labels = hdbscan.HDBSCAN(
  102. min_samples=self.min_samples,
  103. min_cluster_size=self.min_cluster_size,
  104. allow_single_cluster=True).fit_predict(umap_X)
  105. return labels
  106. class ClusterBackend(torch.nn.Module):
  107. r"""Perfom clustering for input embeddings and output the labels.
  108. Args:
  109. model_dir: A model dir.
  110. model_config: The model config.
  111. """
  112. def __init__(self):
  113. super().__init__()
  114. self.model_config = {'merge_thr':0.78}
  115. # self.other_config = kwargs
  116. self.spectral_cluster = SpectralCluster()
  117. self.umap_hdbscan_cluster = UmapHdbscan()
  118. def forward(self, X, **params):
  119. # clustering and return the labels
  120. k = params['oracle_num'] if 'oracle_num' in params else None
  121. assert len(
  122. X.shape
  123. ) == 2, 'modelscope error: the shape of input should be [N, C]'
  124. if X.shape[0] < 20:
  125. return np.zeros(X.shape[0], dtype='int')
  126. if X.shape[0] < 2048 or k is not None:
  127. # unexpected corner case
  128. labels = self.spectral_cluster(X, k)
  129. else:
  130. labels = self.umap_hdbscan_cluster(X)
  131. if k is None and 'merge_thr' in self.model_config:
  132. labels = self.merge_by_cos(labels, X,
  133. self.model_config['merge_thr'])
  134. return labels
  135. def merge_by_cos(self, labels, embs, cos_thr):
  136. # merge the similar speakers by cosine similarity
  137. assert cos_thr > 0 and cos_thr <= 1
  138. while True:
  139. spk_num = labels.max() + 1
  140. if spk_num == 1:
  141. break
  142. spk_center = []
  143. for i in range(spk_num):
  144. spk_emb = embs[labels == i].mean(0)
  145. spk_center.append(spk_emb)
  146. assert len(spk_center) > 0
  147. spk_center = np.stack(spk_center, axis=0)
  148. norm_spk_center = spk_center / np.linalg.norm(
  149. spk_center, axis=1, keepdims=True)
  150. affinity = np.matmul(norm_spk_center, norm_spk_center.T)
  151. affinity = np.triu(affinity, 1)
  152. spks = np.unravel_index(np.argmax(affinity), affinity.shape)
  153. if affinity[spks] < cos_thr:
  154. break
  155. for i in range(len(labels)):
  156. if labels[i] == spks[1]:
  157. labels[i] = spks[0]
  158. elif labels[i] > spks[1]:
  159. labels[i] -= 1
  160. return labels