# preprocessor.py
  1. import re
  2. from abc import ABC
  3. from abc import abstractmethod
  4. from pathlib import Path
  5. from typing import Collection
  6. from typing import Dict
  7. from typing import Iterable
  8. from typing import List
  9. from typing import Union
  10. import numpy as np
  11. import scipy.signal
  12. import soundfile
  13. from typeguard import check_argument_types
  14. from typeguard import check_return_type
  15. from funasr.text.build_tokenizer import build_tokenizer
  16. from funasr.text.cleaner import TextCleaner
  17. from funasr.text.token_id_converter import TokenIDConverter
  18. class AbsPreprocessor(ABC):
  19. def __init__(self, train: bool):
  20. self.train = train
  21. @abstractmethod
  22. def __call__(
  23. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  24. ) -> Dict[str, np.ndarray]:
  25. raise NotImplementedError
  26. def forward_segment(text, dic):
  27. word_list = []
  28. i = 0
  29. while i < len(text):
  30. longest_word = text[i]
  31. for j in range(i + 1, len(text) + 1):
  32. word = text[i:j]
  33. if word in dic:
  34. if len(word) > len(longest_word):
  35. longest_word = word
  36. word_list.append(longest_word)
  37. i += len(longest_word)
  38. return word_list
  39. def seg_tokenize(txt, seg_dict):
  40. out_txt = ""
  41. for word in txt:
  42. word = word.lower()
  43. if word in seg_dict:
  44. out_txt += seg_dict[word] + " "
  45. else:
  46. out_txt += "<unk>" + " "
  47. return out_txt.strip().split()
  48. def seg_tokenize_wo_pattern(txt, seg_dict):
  49. out_txt = ""
  50. for word in txt:
  51. if word in seg_dict:
  52. out_txt += seg_dict[word] + " "
  53. else:
  54. out_txt += "<unk>" + " "
  55. return out_txt.strip().split()
  56. def framing(
  57. x,
  58. frame_length: int = 512,
  59. frame_shift: int = 256,
  60. centered: bool = True,
  61. padded: bool = True,
  62. ):
  63. if x.size == 0:
  64. raise ValueError("Input array size is zero")
  65. if frame_length < 1:
  66. raise ValueError("frame_length must be a positive integer")
  67. if frame_length > x.shape[-1]:
  68. raise ValueError("frame_length is greater than input length")
  69. if 0 >= frame_shift:
  70. raise ValueError("frame_shift must be greater than 0")
  71. if centered:
  72. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
  73. (frame_length // 2, frame_length // 2)
  74. ]
  75. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  76. if padded:
  77. # Pad to integer number of windowed segments
  78. # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep,
  79. # with integer nseg
  80. nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
  81. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
  82. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  83. # Created strided array of data segments
  84. if frame_length == 1 and frame_length == frame_shift:
  85. result = x[..., None]
  86. else:
  87. shape = x.shape[:-1] + (
  88. (x.shape[-1] - frame_length) // frame_shift + 1,
  89. frame_length,
  90. )
  91. strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
  92. result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
  93. return result
  94. def detect_non_silence(
  95. x: np.ndarray,
  96. threshold: float = 0.01,
  97. frame_length: int = 1024,
  98. frame_shift: int = 512,
  99. window: str = "boxcar",
  100. ) -> np.ndarray:
  101. """Power based voice activity detection.
  102. Args:
  103. x: (Channel, Time)
  104. >>> x = np.random.randn(1000)
  105. >>> detect = detect_non_silence(x)
  106. >>> assert x.shape == detect.shape
  107. >>> assert detect.dtype == np.bool
  108. """
  109. if x.shape[-1] < frame_length:
  110. return np.full(x.shape, fill_value=True, dtype=np.bool)
  111. if x.dtype.kind == "i":
  112. x = x.astype(np.float64)
  113. # framed_w: (C, T, F)
  114. framed_w = framing(
  115. x,
  116. frame_length=frame_length,
  117. frame_shift=frame_shift,
  118. centered=False,
  119. padded=True,
  120. )
  121. framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
  122. # power: (C, T)
  123. power = (framed_w ** 2).mean(axis=-1)
  124. # mean_power: (C, 1)
  125. mean_power = np.mean(power, axis=-1, keepdims=True)
  126. if np.all(mean_power == 0):
  127. return np.full(x.shape, fill_value=True, dtype=np.bool)
  128. # detect_frames: (C, T)
  129. detect_frames = power / mean_power > threshold
  130. # detects: (C, T, F)
  131. detects = np.broadcast_to(
  132. detect_frames[..., None], detect_frames.shape + (frame_shift,)
  133. )
  134. # detects: (C, TF)
  135. detects = detects.reshape(*detect_frames.shape[:-1], -1)
  136. # detects: (C, TF)
  137. return np.pad(
  138. detects,
  139. [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
  140. mode="edge",
  141. )
  142. class CommonPreprocessor(AbsPreprocessor):
  143. def __init__(
  144. self,
  145. train: bool,
  146. token_type: str = None,
  147. token_list: Union[Path, str, Iterable[str]] = None,
  148. bpemodel: Union[Path, str, Iterable[str]] = None,
  149. text_cleaner: Collection[str] = None,
  150. g2p_type: str = None,
  151. unk_symbol: str = "<unk>",
  152. space_symbol: str = "<space>",
  153. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  154. delimiter: str = None,
  155. rir_scp: str = None,
  156. rir_apply_prob: float = 1.0,
  157. noise_scp: str = None,
  158. noise_apply_prob: float = 1.0,
  159. noise_db_range: str = "3_10",
  160. speech_volume_normalize: float = None,
  161. speech_name: str = "speech",
  162. text_name: str = "text",
  163. split_with_space: bool = False,
  164. seg_dict_file: str = None,
  165. ):
  166. super().__init__(train)
  167. self.train = train
  168. self.speech_name = speech_name
  169. self.text_name = text_name
  170. self.speech_volume_normalize = speech_volume_normalize
  171. self.rir_apply_prob = rir_apply_prob
  172. self.noise_apply_prob = noise_apply_prob
  173. self.split_with_space = split_with_space
  174. self.seg_dict = None
  175. if seg_dict_file is not None:
  176. self.seg_dict = {}
  177. with open(seg_dict_file) as f:
  178. lines = f.readlines()
  179. for line in lines:
  180. s = line.strip().split()
  181. key = s[0]
  182. value = s[1:]
  183. self.seg_dict[key] = " ".join(value)
  184. if token_type is not None:
  185. if token_list is None:
  186. raise ValueError("token_list is required if token_type is not None")
  187. self.text_cleaner = TextCleaner(text_cleaner)
  188. self.tokenizer = build_tokenizer(
  189. token_type=token_type,
  190. bpemodel=bpemodel,
  191. delimiter=delimiter,
  192. space_symbol=space_symbol,
  193. non_linguistic_symbols=non_linguistic_symbols,
  194. g2p_type=g2p_type,
  195. )
  196. self.token_id_converter = TokenIDConverter(
  197. token_list=token_list,
  198. unk_symbol=unk_symbol,
  199. )
  200. else:
  201. self.text_cleaner = None
  202. self.tokenizer = None
  203. self.token_id_converter = None
  204. if train and rir_scp is not None:
  205. self.rirs = []
  206. with open(rir_scp, "r", encoding="utf-8") as f:
  207. for line in f:
  208. sps = line.strip().split(None, 1)
  209. if len(sps) == 1:
  210. self.rirs.append(sps[0])
  211. else:
  212. self.rirs.append(sps[1])
  213. else:
  214. self.rirs = None
  215. if train and noise_scp is not None:
  216. self.noises = []
  217. with open(noise_scp, "r", encoding="utf-8") as f:
  218. for line in f:
  219. sps = line.strip().split(None, 1)
  220. if len(sps) == 1:
  221. self.noises.append(sps[0])
  222. else:
  223. self.noises.append(sps[1])
  224. sps = noise_db_range.split("_")
  225. if len(sps) == 1:
  226. self.noise_db_low, self.noise_db_high = float(sps[0])
  227. elif len(sps) == 2:
  228. self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
  229. else:
  230. raise ValueError(
  231. "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
  232. )
  233. else:
  234. self.noises = None
  235. def _speech_process(
  236. self, data: Dict[str, Union[str, np.ndarray]]
  237. ) -> Dict[str, Union[str, np.ndarray]]:
  238. assert check_argument_types()
  239. if self.speech_name in data:
  240. if self.train and (self.rirs is not None or self.noises is not None):
  241. speech = data[self.speech_name]
  242. nsamples = len(speech)
  243. # speech: (Nmic, Time)
  244. if speech.ndim == 1:
  245. speech = speech[None, :]
  246. else:
  247. speech = speech.T
  248. # Calc power on non shlence region
  249. power = (speech[detect_non_silence(speech)] ** 2).mean()
  250. # 1. Convolve RIR
  251. if self.rirs is not None and self.rir_apply_prob >= np.random.random():
  252. rir_path = np.random.choice(self.rirs)
  253. if rir_path is not None:
  254. rir, _ = soundfile.read(
  255. rir_path, dtype=np.float64, always_2d=True
  256. )
  257. # rir: (Nmic, Time)
  258. rir = rir.T
  259. # speech: (Nmic, Time)
  260. # Note that this operation doesn't change the signal length
  261. speech = scipy.signal.convolve(speech, rir, mode="full")[
  262. :, : speech.shape[1]
  263. ]
  264. # Reverse mean power to the original power
  265. power2 = (speech[detect_non_silence(speech)] ** 2).mean()
  266. speech = np.sqrt(power / max(power2, 1e-10)) * speech
  267. # 2. Add Noise
  268. if (
  269. self.noises is not None
  270. and self.noise_apply_prob >= np.random.random()
  271. ):
  272. noise_path = np.random.choice(self.noises)
  273. if noise_path is not None:
  274. noise_db = np.random.uniform(
  275. self.noise_db_low, self.noise_db_high
  276. )
  277. with soundfile.SoundFile(noise_path) as f:
  278. if f.frames == nsamples:
  279. noise = f.read(dtype=np.float64, always_2d=True)
  280. elif f.frames < nsamples:
  281. offset = np.random.randint(0, nsamples - f.frames)
  282. # noise: (Time, Nmic)
  283. noise = f.read(dtype=np.float64, always_2d=True)
  284. # Repeat noise
  285. noise = np.pad(
  286. noise,
  287. [(offset, nsamples - f.frames - offset), (0, 0)],
  288. mode="wrap",
  289. )
  290. else:
  291. offset = np.random.randint(0, f.frames - nsamples)
  292. f.seek(offset)
  293. # noise: (Time, Nmic)
  294. noise = f.read(
  295. nsamples, dtype=np.float64, always_2d=True
  296. )
  297. if len(noise) != nsamples:
  298. raise RuntimeError(f"Something wrong: {noise_path}")
  299. # noise: (Nmic, Time)
  300. noise = noise.T
  301. noise_power = (noise ** 2).mean()
  302. scale = (
  303. 10 ** (-noise_db / 20)
  304. * np.sqrt(power)
  305. / np.sqrt(max(noise_power, 1e-10))
  306. )
  307. speech = speech + scale * noise
  308. speech = speech.T
  309. ma = np.max(np.abs(speech))
  310. if ma > 1.0:
  311. speech /= ma
  312. data[self.speech_name] = speech
  313. if self.speech_volume_normalize is not None:
  314. speech = data[self.speech_name]
  315. ma = np.max(np.abs(speech))
  316. data[self.speech_name] = speech * self.speech_volume_normalize / ma
  317. assert check_return_type(data)
  318. return data
  319. def _text_process(
  320. self, data: Dict[str, Union[str, np.ndarray]]
  321. ) -> Dict[str, np.ndarray]:
  322. if self.text_name in data and self.tokenizer is not None:
  323. text = data[self.text_name]
  324. text = self.text_cleaner(text)
  325. if self.split_with_space:
  326. tokens = text.strip().split(" ")
  327. if self.seg_dict is not None:
  328. tokens = seg_tokenize(tokens, self.seg_dict)
  329. else:
  330. tokens = self.tokenizer.text2tokens(text)
  331. text_ints = self.token_id_converter.tokens2ids(tokens)
  332. data[self.text_name] = np.array(text_ints, dtype=np.int64)
  333. assert check_return_type(data)
  334. return data
  335. def __call__(
  336. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  337. ) -> Dict[str, np.ndarray]:
  338. assert check_argument_types()
  339. data = self._speech_process(data)
  340. data = self._text_process(data)
  341. return data
# FIXME: LMPreprocessor below duplicates CommonPreprocessor._text_process except
# that it uses the case-sensitive seg_tokenize_wo_pattern; consider a flag instead.
  343. class LMPreprocessor(CommonPreprocessor):
  344. def __init__(
  345. self,
  346. train: bool,
  347. token_type: str = None,
  348. token_list: Union[Path, str, Iterable[str]] = None,
  349. bpemodel: Union[Path, str, Iterable[str]] = None,
  350. text_cleaner: Collection[str] = None,
  351. g2p_type: str = None,
  352. unk_symbol: str = "<unk>",
  353. space_symbol: str = "<space>",
  354. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  355. delimiter: str = None,
  356. rir_scp: str = None,
  357. rir_apply_prob: float = 1.0,
  358. noise_scp: str = None,
  359. noise_apply_prob: float = 1.0,
  360. noise_db_range: str = "3_10",
  361. speech_volume_normalize: float = None,
  362. speech_name: str = "speech",
  363. text_name: str = "text",
  364. split_with_space: bool = False,
  365. seg_dict_file: str = None,
  366. ):
  367. super().__init__(train,
  368. token_type,
  369. token_list,
  370. bpemodel,
  371. text_cleaner,
  372. g2p_type,
  373. unk_symbol,
  374. space_symbol,
  375. non_linguistic_symbols,
  376. delimiter,
  377. rir_scp,
  378. rir_apply_prob,
  379. noise_scp,
  380. noise_apply_prob,
  381. noise_db_range,
  382. speech_volume_normalize,
  383. speech_name,
  384. text_name,
  385. split_with_space,
  386. seg_dict_file,
  387. )
  388. def _text_process(
  389. self, data: Dict[str, Union[str, np.ndarray]]
  390. ) -> Dict[str, np.ndarray]:
  391. if self.text_name in data and self.tokenizer is not None:
  392. text = data[self.text_name]
  393. text = self.text_cleaner(text)
  394. if self.split_with_space:
  395. tokens = text.strip().split(" ")
  396. if self.seg_dict is not None:
  397. tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
  398. else:
  399. tokens = self.tokenizer.text2tokens(text)
  400. text_ints = self.token_id_converter.tokens2ids(tokens)
  401. data[self.text_name] = np.array(text_ints, dtype=np.int64)
  402. assert check_return_type(data)
  403. return data
  404. class CommonPreprocessor_multi(AbsPreprocessor):
  405. def __init__(
  406. self,
  407. train: bool,
  408. token_type: str = None,
  409. token_list: Union[Path, str, Iterable[str]] = None,
  410. bpemodel: Union[Path, str, Iterable[str]] = None,
  411. text_cleaner: Collection[str] = None,
  412. g2p_type: str = None,
  413. unk_symbol: str = "<unk>",
  414. space_symbol: str = "<space>",
  415. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  416. delimiter: str = None,
  417. speech_name: str = "speech",
  418. text_name: List[str] = ["text"],
  419. ):
  420. super().__init__(train)
  421. self.train = train
  422. self.speech_name = speech_name
  423. self.text_name = text_name
  424. if token_type is not None:
  425. if token_list is None:
  426. raise ValueError("token_list is required if token_type is not None")
  427. self.text_cleaner = TextCleaner(text_cleaner)
  428. self.tokenizer = build_tokenizer(
  429. token_type=token_type,
  430. bpemodel=bpemodel,
  431. delimiter=delimiter,
  432. space_symbol=space_symbol,
  433. non_linguistic_symbols=non_linguistic_symbols,
  434. g2p_type=g2p_type,
  435. )
  436. self.token_id_converter = TokenIDConverter(
  437. token_list=token_list,
  438. unk_symbol=unk_symbol,
  439. )
  440. else:
  441. self.text_cleaner = None
  442. self.tokenizer = None
  443. self.token_id_converter = None
  444. def _text_process(
  445. self, data: Dict[str, Union[str, np.ndarray]]
  446. ) -> Dict[str, np.ndarray]:
  447. for text_n in self.text_name:
  448. if text_n in data and self.tokenizer is not None:
  449. text = data[text_n]
  450. text = self.text_cleaner(text)
  451. tokens = self.tokenizer.text2tokens(text)
  452. text_ints = self.token_id_converter.tokens2ids(tokens)
  453. data[text_n] = np.array(text_ints, dtype=np.int64)
  454. assert check_return_type(data)
  455. return data
  456. def __call__(
  457. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  458. ) -> Dict[str, np.ndarray]:
  459. assert check_argument_types()
  460. if self.speech_name in data:
  461. # Nothing now: candidates:
  462. # - STFT
  463. # - Fbank
  464. # - CMVN
  465. # - Data augmentation
  466. pass
  467. data = self._text_process(data)
  468. return data
  469. class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
  470. def __init__(
  471. self,
  472. train: bool,
  473. token_type: List[str] = [None],
  474. token_list: List[Union[Path, str, Iterable[str]]] = [None],
  475. bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
  476. text_cleaner: Collection[str] = None,
  477. g2p_type: str = None,
  478. unk_symbol: str = "<unk>",
  479. space_symbol: str = "<space>",
  480. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  481. delimiter: str = None,
  482. rir_scp: str = None,
  483. rir_apply_prob: float = 1.0,
  484. noise_scp: str = None,
  485. noise_apply_prob: float = 1.0,
  486. noise_db_range: str = "3_10",
  487. speech_volume_normalize: float = None,
  488. speech_name: str = "speech",
  489. text_name: List[str] = ["text"],
  490. ):
  491. # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
  492. super().__init__(
  493. train=train,
  494. token_type=token_type[0],
  495. token_list=token_list[0],
  496. bpemodel=bpemodel[0],
  497. text_cleaner=text_cleaner,
  498. g2p_type=g2p_type,
  499. unk_symbol=unk_symbol,
  500. space_symbol=space_symbol,
  501. non_linguistic_symbols=non_linguistic_symbols,
  502. delimiter=delimiter,
  503. speech_name=speech_name,
  504. text_name=text_name[0],
  505. rir_scp=rir_scp,
  506. rir_apply_prob=rir_apply_prob,
  507. noise_scp=noise_scp,
  508. noise_apply_prob=noise_apply_prob,
  509. noise_db_range=noise_db_range,
  510. speech_volume_normalize=speech_volume_normalize,
  511. )
  512. assert (
  513. len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
  514. ), "token_type, token_list, bpemodel, or processing text_name mismatched"
  515. self.num_tokenizer = len(token_type)
  516. self.tokenizer = []
  517. self.token_id_converter = []
  518. for i in range(self.num_tokenizer):
  519. if token_type[i] is not None:
  520. if token_list[i] is None:
  521. raise ValueError("token_list is required if token_type is not None")
  522. self.tokenizer.append(
  523. build_tokenizer(
  524. token_type=token_type[i],
  525. bpemodel=bpemodel[i],
  526. delimiter=delimiter,
  527. space_symbol=space_symbol,
  528. non_linguistic_symbols=non_linguistic_symbols,
  529. g2p_type=g2p_type,
  530. )
  531. )
  532. self.token_id_converter.append(
  533. TokenIDConverter(
  534. token_list=token_list[i],
  535. unk_symbol=unk_symbol,
  536. )
  537. )
  538. else:
  539. self.tokenizer.append(None)
  540. self.token_id_converter.append(None)
  541. self.text_cleaner = TextCleaner(text_cleaner)
  542. self.text_name = text_name # override the text_name from CommonPreprocessor
  543. def _text_process(
  544. self, data: Dict[str, Union[str, np.ndarray]]
  545. ) -> Dict[str, np.ndarray]:
  546. for i in range(self.num_tokenizer):
  547. text_name = self.text_name[i]
  548. if text_name in data and self.tokenizer[i] is not None:
  549. text = data[text_name]
  550. text = self.text_cleaner(text)
  551. tokens = self.tokenizer[i].text2tokens(text)
  552. text_ints = self.token_id_converter[i].tokens2ids(tokens)
  553. data[text_name] = np.array(text_ints, dtype=np.int64)
  554. assert check_return_type(data)
  555. return data
  556. class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
  557. def __init__(
  558. self,
  559. train: bool,
  560. token_type: str = None,
  561. token_list: Union[Path, str, Iterable[str]] = None,
  562. bpemodel: Union[Path, str, Iterable[str]] = None,
  563. text_cleaner: Collection[str] = None,
  564. g2p_type: str = None,
  565. unk_symbol: str = "<unk>",
  566. space_symbol: str = "<space>",
  567. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  568. delimiter: str = None,
  569. rir_scp: str = None,
  570. rir_apply_prob: float = 1.0,
  571. noise_scp: str = None,
  572. noise_apply_prob: float = 1.0,
  573. noise_db_range: str = "3_10",
  574. speech_volume_normalize: float = None,
  575. speech_name: str = "speech",
  576. text_name: str = "text",
  577. split_text_name: str = "split_text",
  578. split_with_space: bool = False,
  579. seg_dict_file: str = None,
  580. ):
  581. super().__init__(
  582. train=train,
  583. # Force to use word.
  584. token_type="word",
  585. token_list=token_list,
  586. bpemodel=bpemodel,
  587. text_cleaner=text_cleaner,
  588. g2p_type=g2p_type,
  589. unk_symbol=unk_symbol,
  590. space_symbol=space_symbol,
  591. non_linguistic_symbols=non_linguistic_symbols,
  592. delimiter=delimiter,
  593. speech_name=speech_name,
  594. text_name=text_name,
  595. rir_scp=rir_scp,
  596. rir_apply_prob=rir_apply_prob,
  597. noise_scp=noise_scp,
  598. noise_apply_prob=noise_apply_prob,
  599. noise_db_range=noise_db_range,
  600. speech_volume_normalize=speech_volume_normalize,
  601. split_with_space=split_with_space,
  602. seg_dict_file=seg_dict_file,
  603. )
  604. # The data field name for split text.
  605. self.split_text_name = split_text_name
  606. @classmethod
  607. def split_words(cls, text: str):
  608. words = []
  609. segs = text.split()
  610. for seg in segs:
  611. # There is no space in seg.
  612. current_word = ""
  613. for c in seg:
  614. if len(c.encode()) == 1:
  615. # This is an ASCII char.
  616. current_word += c
  617. else:
  618. # This is a Chinese char.
  619. if len(current_word) > 0:
  620. words.append(current_word)
  621. current_word = ""
  622. words.append(c)
  623. if len(current_word) > 0:
  624. words.append(current_word)
  625. return words
  626. def __call__(
  627. self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
  628. ) -> Dict[str, Union[list, np.ndarray]]:
  629. assert check_argument_types()
  630. # Split words.
  631. if isinstance(data[self.text_name], str):
  632. split_text = self.split_words(data[self.text_name])
  633. else:
  634. split_text = data[self.text_name]
  635. data[self.text_name] = " ".join(split_text)
  636. data = self._speech_process(data)
  637. data = self._text_process(data)
  638. data[self.split_text_name] = split_text
  639. return data
  640. def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
  641. result = data[self.split_text_name]
  642. del data[self.split_text_name]
  643. return result
  644. class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
  645. def __init__(
  646. self,
  647. train: bool,
  648. token_type: List[str] = [None],
  649. token_list: List[Union[Path, str, Iterable[str]]] = [None],
  650. bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
  651. text_cleaner: Collection[str] = None,
  652. g2p_type: str = None,
  653. unk_symbol: str = "<unk>",
  654. space_symbol: str = "<space>",
  655. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  656. delimiter: str = None,
  657. rir_scp: str = None,
  658. rir_apply_prob: float = 1.0,
  659. noise_scp: str = None,
  660. noise_apply_prob: float = 1.0,
  661. noise_db_range: str = "3_10",
  662. speech_volume_normalize: float = None,
  663. speech_name: str = "speech",
  664. text_name: List[str] = ["text"],
  665. vad_name: str = "vad_indexes",
  666. ):
  667. # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
  668. super().__init__(
  669. train=train,
  670. token_type=token_type[0],
  671. token_list=token_list[0],
  672. bpemodel=bpemodel[0],
  673. text_cleaner=text_cleaner,
  674. g2p_type=g2p_type,
  675. unk_symbol=unk_symbol,
  676. space_symbol=space_symbol,
  677. non_linguistic_symbols=non_linguistic_symbols,
  678. delimiter=delimiter,
  679. speech_name=speech_name,
  680. text_name=text_name[0],
  681. rir_scp=rir_scp,
  682. rir_apply_prob=rir_apply_prob,
  683. noise_scp=noise_scp,
  684. noise_apply_prob=noise_apply_prob,
  685. noise_db_range=noise_db_range,
  686. speech_volume_normalize=speech_volume_normalize,
  687. )
  688. assert (
  689. len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
  690. ), "token_type, token_list, bpemodel, or processing text_name mismatched"
  691. self.num_tokenizer = len(token_type)
  692. self.tokenizer = []
  693. self.token_id_converter = []
  694. for i in range(self.num_tokenizer):
  695. if token_type[i] is not None:
  696. if token_list[i] is None:
  697. raise ValueError("token_list is required if token_type is not None")
  698. self.tokenizer.append(
  699. build_tokenizer(
  700. token_type=token_type[i],
  701. bpemodel=bpemodel[i],
  702. delimiter=delimiter,
  703. space_symbol=space_symbol,
  704. non_linguistic_symbols=non_linguistic_symbols,
  705. g2p_type=g2p_type,
  706. )
  707. )
  708. self.token_id_converter.append(
  709. TokenIDConverter(
  710. token_list=token_list[i],
  711. unk_symbol=unk_symbol,
  712. )
  713. )
  714. else:
  715. self.tokenizer.append(None)
  716. self.token_id_converter.append(None)
  717. self.text_cleaner = TextCleaner(text_cleaner)
  718. self.text_name = text_name # override the text_name from CommonPreprocessor
  719. self.vad_name = vad_name
  720. def _text_process(
  721. self, data: Dict[str, Union[str, np.ndarray]]
  722. ) -> Dict[str, np.ndarray]:
  723. for i in range(self.num_tokenizer):
  724. text_name = self.text_name[i]
  725. #import pdb; pdb.set_trace()
  726. if text_name in data and self.tokenizer[i] is not None:
  727. text = data[text_name]
  728. text = self.text_cleaner(text)
  729. tokens = self.tokenizer[i].text2tokens(text)
  730. if "vad:" in tokens[-1]:
  731. vad = tokens[-1][4:]
  732. tokens = tokens[:-1]
  733. if len(vad) == 0:
  734. vad = -1
  735. else:
  736. vad = int(vad)
  737. data[self.vad_name] = np.array([vad], dtype=np.int64)
  738. text_ints = self.token_id_converter[i].tokens2ids(tokens)
  739. data[text_name] = np.array(text_ints, dtype=np.int64)
  740. return data
  741. def split_to_mini_sentence(words: list, word_limit: int = 20):
  742. assert word_limit > 1
  743. if len(words) <= word_limit:
  744. return [words]
  745. sentences = []
  746. length = len(words)
  747. sentence_len = length // word_limit
  748. for i in range(sentence_len):
  749. sentences.append(words[i * word_limit:(i + 1) * word_limit])
  750. if length % word_limit > 0:
  751. sentences.append(words[sentence_len * word_limit:])
  752. return sentences