utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
  6. import io
  7. import os
  8. import torch
  9. import requests
  10. import tempfile
  11. import contextlib
  12. import numpy as np
  13. import librosa as sf
  14. from typing import Union
  15. from pathlib import Path
  16. from typing import Generator, Union
  17. from abc import ABCMeta, abstractmethod
  18. import torchaudio.compliance.kaldi as Kaldi
  19. from funasr.models.transformer.utils.nets_utils import pad_list
  20. def check_audio_list(audio: list):
  21. audio_dur = 0
  22. for i in range(len(audio)):
  23. seg = audio[i]
  24. assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
  25. assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
  26. assert int(seg[1] * 16000) - int(
  27. seg[0] * 16000
  28. ) == seg[2].shape[
  29. 0], 'modelscope error: audio data in list is inconsistent with time length.'
  30. if i > 0:
  31. assert seg[0] >= audio[
  32. i - 1][1], 'modelscope error: Wrong time stamps.'
  33. audio_dur += seg[1] - seg[0]
  34. return audio_dur
  35. # assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
  36. def sv_preprocess(inputs: Union[np.ndarray, list]):
  37. output = []
  38. for i in range(len(inputs)):
  39. if isinstance(inputs[i], str):
  40. file_bytes = File.read(inputs[i])
  41. data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
  42. if len(data.shape) == 2:
  43. data = data[:, 0]
  44. data = torch.from_numpy(data).unsqueeze(0)
  45. data = data.squeeze(0)
  46. elif isinstance(inputs[i], np.ndarray):
  47. assert len(
  48. inputs[i].shape
  49. ) == 1, 'modelscope error: Input array should be [N, T]'
  50. data = inputs[i]
  51. if data.dtype in ['int16', 'int32', 'int64']:
  52. data = (data / (1 << 15)).astype('float32')
  53. else:
  54. data = data.astype('float32')
  55. data = torch.from_numpy(data)
  56. else:
  57. raise ValueError(
  58. 'modelscope error: The input type is restricted to audio address and nump array.'
  59. )
  60. output.append(data)
  61. return output
  62. def sv_chunk(vad_segments: list, fs = 16000) -> list:
  63. config = {
  64. 'seg_dur': 1.5,
  65. 'seg_shift': 0.75,
  66. }
  67. def seg_chunk(seg_data):
  68. seg_st = seg_data[0]
  69. data = seg_data[2]
  70. chunk_len = int(config['seg_dur'] * fs)
  71. chunk_shift = int(config['seg_shift'] * fs)
  72. last_chunk_ed = 0
  73. seg_res = []
  74. for chunk_st in range(0, data.shape[0], chunk_shift):
  75. chunk_ed = min(chunk_st + chunk_len, data.shape[0])
  76. if chunk_ed <= last_chunk_ed:
  77. break
  78. last_chunk_ed = chunk_ed
  79. chunk_st = max(0, chunk_ed - chunk_len)
  80. chunk_data = data[chunk_st:chunk_ed]
  81. if chunk_data.shape[0] < chunk_len:
  82. chunk_data = np.pad(chunk_data,
  83. (0, chunk_len - chunk_data.shape[0]),
  84. 'constant')
  85. seg_res.append([
  86. chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
  87. chunk_data
  88. ])
  89. return seg_res
  90. segs = []
  91. for i, s in enumerate(vad_segments):
  92. segs.extend(seg_chunk(s))
  93. return segs
  94. def extract_feature(audio):
  95. features = []
  96. feature_times = []
  97. feature_lengths = []
  98. for au in audio:
  99. feature = Kaldi.fbank(
  100. au.unsqueeze(0), num_mel_bins=80)
  101. feature = feature - feature.mean(dim=0, keepdim=True)
  102. features.append(feature)
  103. feature_times.append(au.shape[0])
  104. feature_lengths.append(feature.shape[0])
  105. # padding for batch inference
  106. features_padded = pad_list(features, pad_value=0)
  107. # features = torch.cat(features)
  108. return features_padded, feature_lengths, feature_times
  109. def postprocess(segments: list, vad_segments: list,
  110. labels: np.ndarray, embeddings: np.ndarray) -> list:
  111. assert len(segments) == len(labels)
  112. labels = correct_labels(labels)
  113. distribute_res = []
  114. for i in range(len(segments)):
  115. distribute_res.append([segments[i][0], segments[i][1], labels[i]])
  116. # merge the same speakers chronologically
  117. distribute_res = merge_seque(distribute_res)
  118. # accquire speaker center
  119. spk_embs = []
  120. for i in range(labels.max() + 1):
  121. spk_emb = embeddings[labels == i].mean(0)
  122. spk_embs.append(spk_emb)
  123. spk_embs = np.stack(spk_embs)
  124. def is_overlapped(t1, t2):
  125. if t1 > t2 + 1e-4:
  126. return True
  127. return False
  128. # distribute the overlap region
  129. for i in range(1, len(distribute_res)):
  130. if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
  131. p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
  132. distribute_res[i][0] = p
  133. distribute_res[i - 1][1] = p
  134. # smooth the result
  135. distribute_res = smooth(distribute_res)
  136. return distribute_res
  137. def correct_labels(labels):
  138. labels_id = 0
  139. id2id = {}
  140. new_labels = []
  141. for i in labels:
  142. if i not in id2id:
  143. id2id[i] = labels_id
  144. labels_id += 1
  145. new_labels.append(id2id[i])
  146. return np.array(new_labels)
  147. def merge_seque(distribute_res):
  148. res = [distribute_res[0]]
  149. for i in range(1, len(distribute_res)):
  150. if distribute_res[i][2] != res[-1][2] or distribute_res[i][
  151. 0] > res[-1][1]:
  152. res.append(distribute_res[i])
  153. else:
  154. res[-1][1] = distribute_res[i][1]
  155. return res
  156. def smooth(res, mindur=1):
  157. # short segments are assigned to nearest speakers.
  158. for i in range(len(res)):
  159. res[i][0] = round(res[i][0], 2)
  160. res[i][1] = round(res[i][1], 2)
  161. if res[i][1] - res[i][0] < mindur:
  162. if i == 0:
  163. res[i][2] = res[i + 1][2]
  164. elif i == len(res) - 1:
  165. res[i][2] = res[i - 1][2]
  166. elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
  167. res[i][2] = res[i - 1][2]
  168. else:
  169. res[i][2] = res[i + 1][2]
  170. # merge the speakers
  171. res = merge_seque(res)
  172. return res
  173. def distribute_spk(sentence_list, sd_time_list):
  174. sd_sentence_list = []
  175. for d in sentence_list:
  176. sentence_start = d['start']
  177. sentence_end = d['end']
  178. sentence_spk = 0
  179. max_overlap = 0
  180. for sd_time in sd_time_list:
  181. spk_st, spk_ed, spk = sd_time
  182. spk_st = spk_st*1000
  183. spk_ed = spk_ed*1000
  184. overlap = max(
  185. min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
  186. if overlap > max_overlap:
  187. max_overlap = overlap
  188. sentence_spk = spk
  189. d['spk'] = int(sentence_spk)
  190. sd_sentence_list.append(d)
  191. return sd_sentence_list
  192. class Storage(metaclass=ABCMeta):
  193. """Abstract class of storage.
  194. All backends need to implement two apis: ``read()`` and ``read_text()``.
  195. ``read()`` reads the file as a byte stream and ``read_text()`` reads
  196. the file as texts.
  197. """
  198. @abstractmethod
  199. def read(self, filepath: str):
  200. pass
  201. @abstractmethod
  202. def read_text(self, filepath: str):
  203. pass
  204. @abstractmethod
  205. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
  206. pass
  207. @abstractmethod
  208. def write_text(self,
  209. obj: str,
  210. filepath: Union[str, Path],
  211. encoding: str = 'utf-8') -> None:
  212. pass
  213. class LocalStorage(Storage):
  214. """Local hard disk storage"""
  215. def read(self, filepath: Union[str, Path]) -> bytes:
  216. """Read data from a given ``filepath`` with 'rb' mode.
  217. Args:
  218. filepath (str or Path): Path to read data.
  219. Returns:
  220. bytes: Expected bytes object.
  221. """
  222. with open(filepath, 'rb') as f:
  223. content = f.read()
  224. return content
  225. def read_text(self,
  226. filepath: Union[str, Path],
  227. encoding: str = 'utf-8') -> str:
  228. """Read data from a given ``filepath`` with 'r' mode.
  229. Args:
  230. filepath (str or Path): Path to read data.
  231. encoding (str): The encoding format used to open the ``filepath``.
  232. Default: 'utf-8'.
  233. Returns:
  234. str: Expected text reading from ``filepath``.
  235. """
  236. with open(filepath, 'r', encoding=encoding) as f:
  237. value_buf = f.read()
  238. return value_buf
  239. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
  240. """Write data to a given ``filepath`` with 'wb' mode.
  241. Note:
  242. ``write`` will create a directory if the directory of ``filepath``
  243. does not exist.
  244. Args:
  245. obj (bytes): Data to be written.
  246. filepath (str or Path): Path to write data.
  247. """
  248. dirname = os.path.dirname(filepath)
  249. if dirname and not os.path.exists(dirname):
  250. os.makedirs(dirname, exist_ok=True)
  251. with open(filepath, 'wb') as f:
  252. f.write(obj)
  253. def write_text(self,
  254. obj: str,
  255. filepath: Union[str, Path],
  256. encoding: str = 'utf-8') -> None:
  257. """Write data to a given ``filepath`` with 'w' mode.
  258. Note:
  259. ``write_text`` will create a directory if the directory of
  260. ``filepath`` does not exist.
  261. Args:
  262. obj (str): Data to be written.
  263. filepath (str or Path): Path to write data.
  264. encoding (str): The encoding format used to open the ``filepath``.
  265. Default: 'utf-8'.
  266. """
  267. dirname = os.path.dirname(filepath)
  268. if dirname and not os.path.exists(dirname):
  269. os.makedirs(dirname, exist_ok=True)
  270. with open(filepath, 'w', encoding=encoding) as f:
  271. f.write(obj)
  272. @contextlib.contextmanager
  273. def as_local_path(
  274. self,
  275. filepath: Union[str,
  276. Path]) -> Generator[Union[str, Path], None, None]:
  277. """Only for unified API and do nothing."""
  278. yield filepath
  279. class HTTPStorage(Storage):
  280. """HTTP and HTTPS storage."""
  281. def read(self, url):
  282. # TODO @wenmeng.zwm add progress bar if file is too large
  283. r = requests.get(url)
  284. r.raise_for_status()
  285. return r.content
  286. def read_text(self, url):
  287. r = requests.get(url)
  288. r.raise_for_status()
  289. return r.text
  290. @contextlib.contextmanager
  291. def as_local_path(
  292. self, filepath: str) -> Generator[Union[str, Path], None, None]:
  293. """Download a file from ``filepath``.
  294. ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
  295. can be called with ``with`` statement, and when exists from the
  296. ``with`` statement, the temporary path will be released.
  297. Args:
  298. filepath (str): Download a file from ``filepath``.
  299. Examples:
  300. >>> storage = HTTPStorage()
  301. >>> # After existing from the ``with`` clause,
  302. >>> # the path will be removed
  303. >>> with storage.get_local_path('http://path/to/file') as path:
  304. ... # do something here
  305. """
  306. try:
  307. f = tempfile.NamedTemporaryFile(delete=False)
  308. f.write(self.read(filepath))
  309. f.close()
  310. yield f.name
  311. finally:
  312. os.remove(f.name)
  313. def write(self, obj: bytes, url: Union[str, Path]) -> None:
  314. raise NotImplementedError('write is not supported by HTTP Storage')
  315. def write_text(self,
  316. obj: str,
  317. url: Union[str, Path],
  318. encoding: str = 'utf-8') -> None:
  319. raise NotImplementedError(
  320. 'write_text is not supported by HTTP Storage')
  321. class OSSStorage(Storage):
  322. """OSS storage."""
  323. def __init__(self, oss_config_file=None):
  324. # read from config file or env var
  325. raise NotImplementedError(
  326. 'OSSStorage.__init__ to be implemented in the future')
  327. def read(self, filepath):
  328. raise NotImplementedError(
  329. 'OSSStorage.read to be implemented in the future')
  330. def read_text(self, filepath, encoding='utf-8'):
  331. raise NotImplementedError(
  332. 'OSSStorage.read_text to be implemented in the future')
  333. @contextlib.contextmanager
  334. def as_local_path(
  335. self, filepath: str) -> Generator[Union[str, Path], None, None]:
  336. """Download a file from ``filepath``.
  337. ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
  338. can be called with ``with`` statement, and when exists from the
  339. ``with`` statement, the temporary path will be released.
  340. Args:
  341. filepath (str): Download a file from ``filepath``.
  342. Examples:
  343. >>> storage = OSSStorage()
  344. >>> # After existing from the ``with`` clause,
  345. >>> # the path will be removed
  346. >>> with storage.get_local_path('http://path/to/file') as path:
  347. ... # do something here
  348. """
  349. try:
  350. f = tempfile.NamedTemporaryFile(delete=False)
  351. f.write(self.read(filepath))
  352. f.close()
  353. yield f.name
  354. finally:
  355. os.remove(f.name)
  356. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
  357. raise NotImplementedError(
  358. 'OSSStorage.write to be implemented in the future')
  359. def write_text(self,
  360. obj: str,
  361. filepath: Union[str, Path],
  362. encoding: str = 'utf-8') -> None:
  363. raise NotImplementedError(
  364. 'OSSStorage.write_text to be implemented in the future')
# Process-wide cache of instantiated storage backends, keyed by URI scheme
# (populated lazily by File._get_storage).
G_STORAGES = {}
  366. class File(object):
  367. _prefix_to_storage: dict = {
  368. 'oss': OSSStorage,
  369. 'http': HTTPStorage,
  370. 'https': HTTPStorage,
  371. 'local': LocalStorage,
  372. }
  373. @staticmethod
  374. def _get_storage(uri):
  375. assert isinstance(uri,
  376. str), f'uri should be str type, but got {type(uri)}'
  377. if '://' not in uri:
  378. # local path
  379. storage_type = 'local'
  380. else:
  381. prefix, _ = uri.split('://')
  382. storage_type = prefix
  383. assert storage_type in File._prefix_to_storage, \
  384. f'Unsupported uri {uri}, valid prefixs: '\
  385. f'{list(File._prefix_to_storage.keys())}'
  386. if storage_type not in G_STORAGES:
  387. G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
  388. return G_STORAGES[storage_type]
  389. @staticmethod
  390. def read(uri: str) -> bytes:
  391. """Read data from a given ``filepath`` with 'rb' mode.
  392. Args:
  393. filepath (str or Path): Path to read data.
  394. Returns:
  395. bytes: Expected bytes object.
  396. """
  397. storage = File._get_storage(uri)
  398. return storage.read(uri)
  399. @staticmethod
  400. def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
  401. """Read data from a given ``filepath`` with 'r' mode.
  402. Args:
  403. filepath (str or Path): Path to read data.
  404. encoding (str): The encoding format used to open the ``filepath``.
  405. Default: 'utf-8'.
  406. Returns:
  407. str: Expected text reading from ``filepath``.
  408. """
  409. storage = File._get_storage(uri)
  410. return storage.read_text(uri)
  411. @staticmethod
  412. def write(obj: bytes, uri: Union[str, Path]) -> None:
  413. """Write data to a given ``filepath`` with 'wb' mode.
  414. Note:
  415. ``write`` will create a directory if the directory of ``filepath``
  416. does not exist.
  417. Args:
  418. obj (bytes): Data to be written.
  419. filepath (str or Path): Path to write data.
  420. """
  421. storage = File._get_storage(uri)
  422. return storage.write(obj, uri)
  423. @staticmethod
  424. def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
  425. """Write data to a given ``filepath`` with 'w' mode.
  426. Note:
  427. ``write_text`` will create a directory if the directory of
  428. ``filepath`` does not exist.
  429. Args:
  430. obj (str): Data to be written.
  431. filepath (str or Path): Path to write data.
  432. encoding (str): The encoding format used to open the ``filepath``.
  433. Default: 'utf-8'.
  434. """
  435. storage = File._get_storage(uri)
  436. return storage.write_text(obj, uri)
  437. @contextlib.contextmanager
  438. def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
  439. """Only for unified API and do nothing."""
  440. storage = File._get_storage(uri)
  441. with storage.as_local_path(uri) as local_path:
  442. yield local_path