speaker_utils.py

# Copyright (c) Alibaba, Inc. and its affiliates.
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
"""
import io
from typing import Union

import numpy as np
import soundfile as sf
import torch
import torchaudio.compliance.kaldi as Kaldi

from funasr.utils.modelscope_file import File

def check_audio_list(audio: list):
    """Validate a list of [start_sec, end_sec, np.ndarray] segments sampled at
    16 kHz and return the total effective duration in seconds."""
    audio_dur = 0
    for i in range(len(audio)):
        seg = audio[i]
        assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
        assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
        assert int(seg[1] * 16000) - int(seg[0] * 16000) == seg[2].shape[0], \
            'modelscope error: audio data in list is inconsistent with time length.'
        if i > 0:
            # segments must be non-overlapping and in chronological order
            assert seg[0] >= audio[i - 1][1], 'modelscope error: Wrong time stamps.'
        audio_dur += seg[1] - seg[0]
    return audio_dur
    # assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'

def sv_preprocess(inputs: Union[np.ndarray, list]):
    """Normalize a list of audio file paths and/or 1-D numpy arrays into
    float32 torch tensors of shape (T,)."""
    output = []
    for i in range(len(inputs)):
        if isinstance(inputs[i], str):
            file_bytes = File.read(inputs[i])
            # soundfile keeps the native sample rate and returns stereo as
            # (T, C); fs is not validated here, the pipeline assumes 16 kHz.
            data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
            if len(data.shape) == 2:
                data = data[:, 0]
            data = torch.from_numpy(data)
        elif isinstance(inputs[i], np.ndarray):
            assert len(inputs[i].shape) == 1, \
                'modelscope error: Input array should be 1-D with shape [T,].'
            data = inputs[i]
            if data.dtype in ['int16', 'int32', 'int64']:
                # rescale integer PCM to float32 in [-1, 1]
                data = (data / (1 << 15)).astype('float32')
            else:
                data = data.astype('float32')
            data = torch.from_numpy(data)
        else:
            raise ValueError(
                'modelscope error: The input type is restricted to audio file'
                ' paths and numpy arrays.')
        output.append(data)
    return output

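# Usage sketch (illustrative, not part of the original module): sv_preprocess
# accepts file paths and/or 1-D numpy arrays; integer PCM is rescaled to
# float32 in [-1, 1].
def _demo_sv_preprocess():
    pcm = (np.random.randn(16000) * 1000).astype('int16')  # 1 s of fake 16 kHz PCM
    out = sv_preprocess([pcm])
    assert out[0].dtype == torch.float32 and out[0].shape == (16000,)
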
def sv_chunk(vad_segments: list, fs=16000) -> list:
    """Split each VAD segment into 1.5 s chunks with a 0.75 s shift; the last
    chunk is right-aligned to the segment end and zero-padded if still short."""
    config = {
        'seg_dur': 1.5,
        'seg_shift': 0.75,
    }

    def seg_chunk(seg_data):
        seg_st = seg_data[0]
        data = seg_data[2]
        chunk_len = int(config['seg_dur'] * fs)
        chunk_shift = int(config['seg_shift'] * fs)
        last_chunk_ed = 0
        seg_res = []
        for chunk_st in range(0, data.shape[0], chunk_shift):
            chunk_ed = min(chunk_st + chunk_len, data.shape[0])
            if chunk_ed <= last_chunk_ed:
                break
            last_chunk_ed = chunk_ed
            chunk_st = max(0, chunk_ed - chunk_len)
            chunk_data = data[chunk_st:chunk_ed]
            if chunk_data.shape[0] < chunk_len:
                chunk_data = np.pad(chunk_data,
                                    (0, chunk_len - chunk_data.shape[0]),
                                    'constant')
            seg_res.append([
                chunk_st / fs + seg_st, chunk_ed / fs + seg_st, chunk_data
            ])
        return seg_res

    segs = []
    for s in vad_segments:
        segs.extend(seg_chunk(s))
    return segs

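# Worked example (illustrative): a 2 s VAD segment starting at t=3.0 s yields
# two 1.5 s chunks, the second right-aligned to the segment end, so the chunk
# boundaries are (3.00, 4.50) and (3.50, 5.00).
def _demo_sv_chunk():
    seg = [3.0, 5.0, np.zeros(2 * 16000, dtype='float32')]
    chunks = sv_chunk([seg])
    assert [(c[0], c[1]) for c in chunks] == [(3.0, 4.5), (3.5, 5.0)]
    assert all(c[2].shape[0] == int(1.5 * 16000) for c in chunks)
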
def extract_feature(audio):
    """Compute 80-dim mean-normalized fbank features for a list of equal-length
    waveform tensors and stack them into a (N, num_frames, 80) batch."""
    features = []
    for au in audio:
        feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80)
        feature = feature - feature.mean(dim=0, keepdim=True)
        features.append(feature.unsqueeze(0))
    features = torch.cat(features)
    return features

def postprocess(segments: list, vad_segments: list,
                labels: np.ndarray, embeddings: np.ndarray) -> list:
    """Turn per-chunk cluster labels into a smoothed speaker timeline of
    [start_sec, end_sec, spk] entries."""
    assert len(segments) == len(labels)
    labels = correct_labels(labels)
    distribute_res = []
    for i in range(len(segments)):
        distribute_res.append([segments[i][0], segments[i][1], labels[i]])
    # merge the same speakers chronologically
    distribute_res = merge_seque(distribute_res)

    # acquire speaker centers
    spk_embs = []
    for i in range(labels.max() + 1):
        spk_emb = embeddings[labels == i].mean(0)
        spk_embs.append(spk_emb)
    spk_embs = np.stack(spk_embs)

    def is_overlapped(t1, t2):
        if t1 > t2 + 1e-4:
            return True
        return False

    # distribute the overlap region: split it at the midpoint of the two segments
    for i in range(1, len(distribute_res)):
        if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
            p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
            distribute_res[i][0] = p
            distribute_res[i - 1][1] = p

    # smooth the result
    distribute_res = smooth(distribute_res)

    return distribute_res

def correct_labels(labels):
    labels_id = 0
    id2id = {}
    new_labels = []
    for i in labels:
        if i not in id2id:
            id2id[i] = labels_id
            labels_id += 1
        new_labels.append(id2id[i])
    return np.array(new_labels)

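# Example (illustrative): cluster IDs are remapped in order of first
# appearance, e.g. [2, 2, 0, 1] -> [0, 0, 1, 2], so speaker numbering follows
# the timeline rather than the clusterer's arbitrary IDs.
def _demo_correct_labels():
    assert correct_labels(np.array([2, 2, 0, 1])).tolist() == [0, 0, 1, 2]
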
def merge_seque(distribute_res):
    res = [distribute_res[0]]
    for i in range(1, len(distribute_res)):
        if distribute_res[i][2] != res[-1][2] or distribute_res[i][0] > res[-1][1]:
            res.append(distribute_res[i])
        else:
            res[-1][1] = distribute_res[i][1]
    return res

def smooth(res, mindur=1):
    # short segments are assigned to the nearest speaker
    for i in range(len(res)):
        res[i][0] = round(res[i][0], 2)
        res[i][1] = round(res[i][1], 2)
        if res[i][1] - res[i][0] < mindur:
            if i == 0:
                res[i][2] = res[i + 1][2]
            elif i == len(res) - 1:
                res[i][2] = res[i - 1][2]
            elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
                res[i][2] = res[i - 1][2]
            else:
                res[i][2] = res[i + 1][2]
    # merge the speakers
    res = merge_seque(res)
    return res

def distribute_spk(sentence_list, sd_time_list):
    """Assign each sentence the speaker whose diarization interval overlaps it
    the most. Sentence timestamps are in milliseconds while diarization times
    are in seconds, hence the * 1000 conversion."""
    sd_sentence_list = []
    for d in sentence_list:
        sentence_start = d['ts_list'][0][0]
        sentence_end = d['ts_list'][-1][1]
        sentence_spk = 0
        max_overlap = 0
        for sd_time in sd_time_list:
            spk_st, spk_ed, spk = sd_time
            spk_st = spk_st * 1000
            spk_ed = spk_ed * 1000
            overlap = max(
                min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
            if overlap > max_overlap:
                max_overlap = overlap
                sentence_spk = spk
        d['spk'] = sentence_spk
        sd_sentence_list.append(d)
    return sd_sentence_list
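
# End-to-end sketch (illustrative): how these helpers compose into a simple
# diarization flow. The embedding dimension (192) and the random "cluster"
# labels below are placeholders for a real speaker-embedding model and a
# clustering backend, which this module does not provide.
if __name__ == '__main__':
    wav = np.random.randn(6 * 16000).astype('float32')   # 6 s of fake 16 kHz audio
    vad_segments = [[0.0, 6.0, wav]]                     # pretend VAD found one segment
    check_audio_list(vad_segments)

    chunks = sv_chunk(vad_segments)                      # 1.5 s windows, 0.75 s shift
    feats = extract_feature([torch.from_numpy(c[2]) for c in chunks])
    print(feats.shape)                                   # (num_chunks, num_frames, 80)

    embeddings = np.random.randn(len(chunks), 192).astype('float32')
    labels = np.random.randint(0, 2, size=len(chunks))   # stand-in for clustering
    print(postprocess(chunks, vad_segments, labels, embeddings))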