# preprocessor.py
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from pathlib import Path
  4. from typing import Collection
  5. from typing import Dict
  6. from typing import Iterable
  7. from typing import List
  8. from typing import Union
  9. import numpy as np
  10. import scipy.signal
  11. import soundfile
  12. from typeguard import check_argument_types
  13. from typeguard import check_return_type
  14. from funasr.text.build_tokenizer import build_tokenizer
  15. from funasr.text.cleaner import TextCleaner
  16. from funasr.text.token_id_converter import TokenIDConverter
  17. class AbsPreprocessor(ABC):
  18. def __init__(self, train: bool):
  19. self.train = train
  20. @abstractmethod
  21. def __call__(
  22. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  23. ) -> Dict[str, np.ndarray]:
  24. raise NotImplementedError
  25. def framing(
  26. x,
  27. frame_length: int = 512,
  28. frame_shift: int = 256,
  29. centered: bool = True,
  30. padded: bool = True,
  31. ):
  32. if x.size == 0:
  33. raise ValueError("Input array size is zero")
  34. if frame_length < 1:
  35. raise ValueError("frame_length must be a positive integer")
  36. if frame_length > x.shape[-1]:
  37. raise ValueError("frame_length is greater than input length")
  38. if 0 >= frame_shift:
  39. raise ValueError("frame_shift must be greater than 0")
  40. if centered:
  41. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
  42. (frame_length // 2, frame_length // 2)
  43. ]
  44. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  45. if padded:
  46. # Pad to integer number of windowed segments
  47. # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep,
  48. # with integer nseg
  49. nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
  50. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
  51. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  52. # Created strided array of data segments
  53. if frame_length == 1 and frame_length == frame_shift:
  54. result = x[..., None]
  55. else:
  56. shape = x.shape[:-1] + (
  57. (x.shape[-1] - frame_length) // frame_shift + 1,
  58. frame_length,
  59. )
  60. strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
  61. result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
  62. return result
  63. def detect_non_silence(
  64. x: np.ndarray,
  65. threshold: float = 0.01,
  66. frame_length: int = 1024,
  67. frame_shift: int = 512,
  68. window: str = "boxcar",
  69. ) -> np.ndarray:
  70. """Power based voice activity detection.
  71. Args:
  72. x: (Channel, Time)
  73. >>> x = np.random.randn(1000)
  74. >>> detect = detect_non_silence(x)
  75. >>> assert x.shape == detect.shape
  76. >>> assert detect.dtype == np.bool
  77. """
  78. if x.shape[-1] < frame_length:
  79. return np.full(x.shape, fill_value=True, dtype=np.bool)
  80. if x.dtype.kind == "i":
  81. x = x.astype(np.float64)
  82. # framed_w: (C, T, F)
  83. framed_w = framing(
  84. x,
  85. frame_length=frame_length,
  86. frame_shift=frame_shift,
  87. centered=False,
  88. padded=True,
  89. )
  90. framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
  91. # power: (C, T)
  92. power = (framed_w**2).mean(axis=-1)
  93. # mean_power: (C, 1)
  94. mean_power = np.mean(power, axis=-1, keepdims=True)
  95. if np.all(mean_power == 0):
  96. return np.full(x.shape, fill_value=True, dtype=np.bool)
  97. # detect_frames: (C, T)
  98. detect_frames = power / mean_power > threshold
  99. # detects: (C, T, F)
  100. detects = np.broadcast_to(
  101. detect_frames[..., None], detect_frames.shape + (frame_shift,)
  102. )
  103. # detects: (C, TF)
  104. detects = detects.reshape(*detect_frames.shape[:-1], -1)
  105. # detects: (C, TF)
  106. return np.pad(
  107. detects,
  108. [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
  109. mode="edge",
  110. )
  111. class CommonPreprocessor(AbsPreprocessor):
  112. def __init__(
  113. self,
  114. train: bool,
  115. token_type: str = None,
  116. token_list: Union[Path, str, Iterable[str]] = None,
  117. bpemodel: Union[Path, str, Iterable[str]] = None,
  118. text_cleaner: Collection[str] = None,
  119. g2p_type: str = None,
  120. unk_symbol: str = "<unk>",
  121. space_symbol: str = "<space>",
  122. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  123. delimiter: str = None,
  124. rir_scp: str = None,
  125. rir_apply_prob: float = 1.0,
  126. noise_scp: str = None,
  127. noise_apply_prob: float = 1.0,
  128. noise_db_range: str = "3_10",
  129. speech_volume_normalize: float = None,
  130. speech_name: str = "speech",
  131. text_name: str = "text",
  132. split_with_space: bool = False,
  133. ):
  134. super().__init__(train)
  135. self.train = train
  136. self.speech_name = speech_name
  137. self.text_name = text_name
  138. self.speech_volume_normalize = speech_volume_normalize
  139. self.rir_apply_prob = rir_apply_prob
  140. self.noise_apply_prob = noise_apply_prob
  141. self.split_with_space = split_with_space
  142. if token_type is not None:
  143. if token_list is None:
  144. raise ValueError("token_list is required if token_type is not None")
  145. self.text_cleaner = TextCleaner(text_cleaner)
  146. self.tokenizer = build_tokenizer(
  147. token_type=token_type,
  148. bpemodel=bpemodel,
  149. delimiter=delimiter,
  150. space_symbol=space_symbol,
  151. non_linguistic_symbols=non_linguistic_symbols,
  152. g2p_type=g2p_type,
  153. )
  154. self.token_id_converter = TokenIDConverter(
  155. token_list=token_list,
  156. unk_symbol=unk_symbol,
  157. )
  158. else:
  159. self.text_cleaner = None
  160. self.tokenizer = None
  161. self.token_id_converter = None
  162. if train and rir_scp is not None:
  163. self.rirs = []
  164. with open(rir_scp, "r", encoding="utf-8") as f:
  165. for line in f:
  166. sps = line.strip().split(None, 1)
  167. if len(sps) == 1:
  168. self.rirs.append(sps[0])
  169. else:
  170. self.rirs.append(sps[1])
  171. else:
  172. self.rirs = None
  173. if train and noise_scp is not None:
  174. self.noises = []
  175. with open(noise_scp, "r", encoding="utf-8") as f:
  176. for line in f:
  177. sps = line.strip().split(None, 1)
  178. if len(sps) == 1:
  179. self.noises.append(sps[0])
  180. else:
  181. self.noises.append(sps[1])
  182. sps = noise_db_range.split("_")
  183. if len(sps) == 1:
  184. self.noise_db_low, self.noise_db_high = float(sps[0])
  185. elif len(sps) == 2:
  186. self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
  187. else:
  188. raise ValueError(
  189. "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
  190. )
  191. else:
  192. self.noises = None
  193. def _speech_process(
  194. self, data: Dict[str, Union[str, np.ndarray]]
  195. ) -> Dict[str, Union[str, np.ndarray]]:
  196. assert check_argument_types()
  197. if self.speech_name in data:
  198. if self.train and (self.rirs is not None or self.noises is not None):
  199. speech = data[self.speech_name]
  200. nsamples = len(speech)
  201. # speech: (Nmic, Time)
  202. if speech.ndim == 1:
  203. speech = speech[None, :]
  204. else:
  205. speech = speech.T
  206. # Calc power on non shlence region
  207. power = (speech[detect_non_silence(speech)] ** 2).mean()
  208. # 1. Convolve RIR
  209. if self.rirs is not None and self.rir_apply_prob >= np.random.random():
  210. rir_path = np.random.choice(self.rirs)
  211. if rir_path is not None:
  212. rir, _ = soundfile.read(
  213. rir_path, dtype=np.float64, always_2d=True
  214. )
  215. # rir: (Nmic, Time)
  216. rir = rir.T
  217. # speech: (Nmic, Time)
  218. # Note that this operation doesn't change the signal length
  219. speech = scipy.signal.convolve(speech, rir, mode="full")[
  220. :, : speech.shape[1]
  221. ]
  222. # Reverse mean power to the original power
  223. power2 = (speech[detect_non_silence(speech)] ** 2).mean()
  224. speech = np.sqrt(power / max(power2, 1e-10)) * speech
  225. # 2. Add Noise
  226. if (
  227. self.noises is not None
  228. and self.noise_apply_prob >= np.random.random()
  229. ):
  230. noise_path = np.random.choice(self.noises)
  231. if noise_path is not None:
  232. noise_db = np.random.uniform(
  233. self.noise_db_low, self.noise_db_high
  234. )
  235. with soundfile.SoundFile(noise_path) as f:
  236. if f.frames == nsamples:
  237. noise = f.read(dtype=np.float64, always_2d=True)
  238. elif f.frames < nsamples:
  239. offset = np.random.randint(0, nsamples - f.frames)
  240. # noise: (Time, Nmic)
  241. noise = f.read(dtype=np.float64, always_2d=True)
  242. # Repeat noise
  243. noise = np.pad(
  244. noise,
  245. [(offset, nsamples - f.frames - offset), (0, 0)],
  246. mode="wrap",
  247. )
  248. else:
  249. offset = np.random.randint(0, f.frames - nsamples)
  250. f.seek(offset)
  251. # noise: (Time, Nmic)
  252. noise = f.read(
  253. nsamples, dtype=np.float64, always_2d=True
  254. )
  255. if len(noise) != nsamples:
  256. raise RuntimeError(f"Something wrong: {noise_path}")
  257. # noise: (Nmic, Time)
  258. noise = noise.T
  259. noise_power = (noise**2).mean()
  260. scale = (
  261. 10 ** (-noise_db / 20)
  262. * np.sqrt(power)
  263. / np.sqrt(max(noise_power, 1e-10))
  264. )
  265. speech = speech + scale * noise
  266. speech = speech.T
  267. ma = np.max(np.abs(speech))
  268. if ma > 1.0:
  269. speech /= ma
  270. data[self.speech_name] = speech
  271. if self.speech_volume_normalize is not None:
  272. speech = data[self.speech_name]
  273. ma = np.max(np.abs(speech))
  274. data[self.speech_name] = speech * self.speech_volume_normalize / ma
  275. assert check_return_type(data)
  276. return data
  277. def _text_process(
  278. self, data: Dict[str, Union[str, np.ndarray]]
  279. ) -> Dict[str, np.ndarray]:
  280. if self.text_name in data and self.tokenizer is not None:
  281. text = data[self.text_name]
  282. text = self.text_cleaner(text)
  283. if self.split_with_space:
  284. tokens = text.strip().split(" ")
  285. else:
  286. tokens = self.tokenizer.text2tokens(text)
  287. text_ints = self.token_id_converter.tokens2ids(tokens)
  288. data[self.text_name] = np.array(text_ints, dtype=np.int64)
  289. assert check_return_type(data)
  290. return data
  291. def __call__(
  292. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  293. ) -> Dict[str, np.ndarray]:
  294. assert check_argument_types()
  295. data = self._speech_process(data)
  296. data = self._text_process(data)
  297. return data
  298. class CommonPreprocessor_multi(AbsPreprocessor):
  299. def __init__(
  300. self,
  301. train: bool,
  302. token_type: str = None,
  303. token_list: Union[Path, str, Iterable[str]] = None,
  304. bpemodel: Union[Path, str, Iterable[str]] = None,
  305. text_cleaner: Collection[str] = None,
  306. g2p_type: str = None,
  307. unk_symbol: str = "<unk>",
  308. space_symbol: str = "<space>",
  309. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  310. delimiter: str = None,
  311. speech_name: str = "speech",
  312. text_name: List[str] = ["text"],
  313. ):
  314. super().__init__(train)
  315. self.train = train
  316. self.speech_name = speech_name
  317. self.text_name = text_name
  318. if token_type is not None:
  319. if token_list is None:
  320. raise ValueError("token_list is required if token_type is not None")
  321. self.text_cleaner = TextCleaner(text_cleaner)
  322. self.tokenizer = build_tokenizer(
  323. token_type=token_type,
  324. bpemodel=bpemodel,
  325. delimiter=delimiter,
  326. space_symbol=space_symbol,
  327. non_linguistic_symbols=non_linguistic_symbols,
  328. g2p_type=g2p_type,
  329. )
  330. self.token_id_converter = TokenIDConverter(
  331. token_list=token_list,
  332. unk_symbol=unk_symbol,
  333. )
  334. else:
  335. self.text_cleaner = None
  336. self.tokenizer = None
  337. self.token_id_converter = None
  338. def _text_process(
  339. self, data: Dict[str, Union[str, np.ndarray]]
  340. ) -> Dict[str, np.ndarray]:
  341. for text_n in self.text_name:
  342. if text_n in data and self.tokenizer is not None:
  343. text = data[text_n]
  344. text = self.text_cleaner(text)
  345. tokens = self.tokenizer.text2tokens(text)
  346. text_ints = self.token_id_converter.tokens2ids(tokens)
  347. data[text_n] = np.array(text_ints, dtype=np.int64)
  348. assert check_return_type(data)
  349. return data
  350. def __call__(
  351. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  352. ) -> Dict[str, np.ndarray]:
  353. assert check_argument_types()
  354. if self.speech_name in data:
  355. # Nothing now: candidates:
  356. # - STFT
  357. # - Fbank
  358. # - CMVN
  359. # - Data augmentation
  360. pass
  361. data = self._text_process(data)
  362. return data
  363. class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
  364. def __init__(
  365. self,
  366. train: bool,
  367. token_type: List[str] = [None],
  368. token_list: List[Union[Path, str, Iterable[str]]] = [None],
  369. bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
  370. text_cleaner: Collection[str] = None,
  371. g2p_type: str = None,
  372. unk_symbol: str = "<unk>",
  373. space_symbol: str = "<space>",
  374. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  375. delimiter: str = None,
  376. rir_scp: str = None,
  377. rir_apply_prob: float = 1.0,
  378. noise_scp: str = None,
  379. noise_apply_prob: float = 1.0,
  380. noise_db_range: str = "3_10",
  381. speech_volume_normalize: float = None,
  382. speech_name: str = "speech",
  383. text_name: List[str] = ["text"],
  384. ):
  385. # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
  386. super().__init__(
  387. train=train,
  388. token_type=token_type[0],
  389. token_list=token_list[0],
  390. bpemodel=bpemodel[0],
  391. text_cleaner=text_cleaner,
  392. g2p_type=g2p_type,
  393. unk_symbol=unk_symbol,
  394. space_symbol=space_symbol,
  395. non_linguistic_symbols=non_linguistic_symbols,
  396. delimiter=delimiter,
  397. speech_name=speech_name,
  398. text_name=text_name[0],
  399. rir_scp=rir_scp,
  400. rir_apply_prob=rir_apply_prob,
  401. noise_scp=noise_scp,
  402. noise_apply_prob=noise_apply_prob,
  403. noise_db_range=noise_db_range,
  404. speech_volume_normalize=speech_volume_normalize,
  405. )
  406. assert (
  407. len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
  408. ), "token_type, token_list, bpemodel, or processing text_name mismatched"
  409. self.num_tokenizer = len(token_type)
  410. self.tokenizer = []
  411. self.token_id_converter = []
  412. for i in range(self.num_tokenizer):
  413. if token_type[i] is not None:
  414. if token_list[i] is None:
  415. raise ValueError("token_list is required if token_type is not None")
  416. self.tokenizer.append(
  417. build_tokenizer(
  418. token_type=token_type[i],
  419. bpemodel=bpemodel[i],
  420. delimiter=delimiter,
  421. space_symbol=space_symbol,
  422. non_linguistic_symbols=non_linguistic_symbols,
  423. g2p_type=g2p_type,
  424. )
  425. )
  426. self.token_id_converter.append(
  427. TokenIDConverter(
  428. token_list=token_list[i],
  429. unk_symbol=unk_symbol,
  430. )
  431. )
  432. else:
  433. self.tokenizer.append(None)
  434. self.token_id_converter.append(None)
  435. self.text_cleaner = TextCleaner(text_cleaner)
  436. self.text_name = text_name # override the text_name from CommonPreprocessor
  437. def _text_process(
  438. self, data: Dict[str, Union[str, np.ndarray]]
  439. ) -> Dict[str, np.ndarray]:
  440. for i in range(self.num_tokenizer):
  441. text_name = self.text_name[i]
  442. if text_name in data and self.tokenizer[i] is not None:
  443. text = data[text_name]
  444. text = self.text_cleaner(text)
  445. tokens = self.tokenizer[i].text2tokens(text)
  446. text_ints = self.token_id_converter[i].tokens2ids(tokens)
  447. data[text_name] = np.array(text_ints, dtype=np.int64)
  448. assert check_return_type(data)
  449. return data