preprocessor.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824
  1. import re
  2. from abc import ABC
  3. from abc import abstractmethod
  4. from pathlib import Path
  5. from typing import Collection
  6. from typing import Dict
  7. from typing import Iterable
  8. from typing import List
  9. from typing import Union
  10. import numpy as np
  11. import scipy.signal
  12. import soundfile
  13. from typeguard import check_argument_types
  14. from typeguard import check_return_type
  15. from funasr.text.build_tokenizer import build_tokenizer
  16. from funasr.text.cleaner import TextCleaner
  17. from funasr.text.token_id_converter import TokenIDConverter
  18. class AbsPreprocessor(ABC):
  19. def __init__(self, train: bool):
  20. self.train = train
  21. @abstractmethod
  22. def __call__(
  23. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  24. ) -> Dict[str, np.ndarray]:
  25. raise NotImplementedError
  26. def forward_segment(text, dic):
  27. word_list = []
  28. i = 0
  29. while i < len(text):
  30. longest_word = text[i]
  31. for j in range(i + 1, len(text) + 1):
  32. word = text[i:j]
  33. if word in dic:
  34. if len(word) > len(longest_word):
  35. longest_word = word
  36. word_list.append(longest_word)
  37. i += len(longest_word)
  38. return word_list
  39. def seg_tokenize(txt, seg_dict):
  40. pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
  41. out_txt = ""
  42. for word in txt:
  43. word = word.lower()
  44. if word in seg_dict:
  45. out_txt += seg_dict[word] + " "
  46. else:
  47. if pattern.match(word):
  48. for char in word:
  49. if char in seg_dict:
  50. out_txt += seg_dict[char] + " "
  51. else:
  52. out_txt += "<unk>" + " "
  53. else:
  54. out_txt += "<unk>" + " "
  55. return out_txt.strip().split()
  56. def seg_tokenize_wo_pattern(txt, seg_dict):
  57. out_txt = ""
  58. for word in txt:
  59. if word in seg_dict:
  60. out_txt += seg_dict[word] + " "
  61. else:
  62. out_txt += "<unk>" + " "
  63. return out_txt.strip().split()
  64. def framing(
  65. x,
  66. frame_length: int = 512,
  67. frame_shift: int = 256,
  68. centered: bool = True,
  69. padded: bool = True,
  70. ):
  71. if x.size == 0:
  72. raise ValueError("Input array size is zero")
  73. if frame_length < 1:
  74. raise ValueError("frame_length must be a positive integer")
  75. if frame_length > x.shape[-1]:
  76. raise ValueError("frame_length is greater than input length")
  77. if 0 >= frame_shift:
  78. raise ValueError("frame_shift must be greater than 0")
  79. if centered:
  80. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
  81. (frame_length // 2, frame_length // 2)
  82. ]
  83. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  84. if padded:
  85. # Pad to integer number of windowed segments
  86. # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep,
  87. # with integer nseg
  88. nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
  89. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
  90. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  91. # Created strided array of data segments
  92. if frame_length == 1 and frame_length == frame_shift:
  93. result = x[..., None]
  94. else:
  95. shape = x.shape[:-1] + (
  96. (x.shape[-1] - frame_length) // frame_shift + 1,
  97. frame_length,
  98. )
  99. strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
  100. result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
  101. return result
  102. def detect_non_silence(
  103. x: np.ndarray,
  104. threshold: float = 0.01,
  105. frame_length: int = 1024,
  106. frame_shift: int = 512,
  107. window: str = "boxcar",
  108. ) -> np.ndarray:
  109. """Power based voice activity detection.
  110. Args:
  111. x: (Channel, Time)
  112. >>> x = np.random.randn(1000)
  113. >>> detect = detect_non_silence(x)
  114. >>> assert x.shape == detect.shape
  115. >>> assert detect.dtype == np.bool
  116. """
  117. if x.shape[-1] < frame_length:
  118. return np.full(x.shape, fill_value=True, dtype=np.bool)
  119. if x.dtype.kind == "i":
  120. x = x.astype(np.float64)
  121. # framed_w: (C, T, F)
  122. framed_w = framing(
  123. x,
  124. frame_length=frame_length,
  125. frame_shift=frame_shift,
  126. centered=False,
  127. padded=True,
  128. )
  129. framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
  130. # power: (C, T)
  131. power = (framed_w ** 2).mean(axis=-1)
  132. # mean_power: (C, 1)
  133. mean_power = np.mean(power, axis=-1, keepdims=True)
  134. if np.all(mean_power == 0):
  135. return np.full(x.shape, fill_value=True, dtype=np.bool)
  136. # detect_frames: (C, T)
  137. detect_frames = power / mean_power > threshold
  138. # detects: (C, T, F)
  139. detects = np.broadcast_to(
  140. detect_frames[..., None], detect_frames.shape + (frame_shift,)
  141. )
  142. # detects: (C, TF)
  143. detects = detects.reshape(*detect_frames.shape[:-1], -1)
  144. # detects: (C, TF)
  145. return np.pad(
  146. detects,
  147. [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
  148. mode="edge",
  149. )
  150. class CommonPreprocessor(AbsPreprocessor):
  151. def __init__(
  152. self,
  153. train: bool,
  154. token_type: str = None,
  155. token_list: Union[Path, str, Iterable[str]] = None,
  156. bpemodel: Union[Path, str, Iterable[str]] = None,
  157. text_cleaner: Collection[str] = None,
  158. g2p_type: str = None,
  159. unk_symbol: str = "<unk>",
  160. space_symbol: str = "<space>",
  161. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  162. delimiter: str = None,
  163. rir_scp: str = None,
  164. rir_apply_prob: float = 1.0,
  165. noise_scp: str = None,
  166. noise_apply_prob: float = 1.0,
  167. noise_db_range: str = "3_10",
  168. speech_volume_normalize: float = None,
  169. speech_name: str = "speech",
  170. text_name: str = "text",
  171. split_with_space: bool = False,
  172. seg_dict_file: str = None,
  173. ):
  174. super().__init__(train)
  175. self.train = train
  176. self.speech_name = speech_name
  177. self.text_name = text_name
  178. self.speech_volume_normalize = speech_volume_normalize
  179. self.rir_apply_prob = rir_apply_prob
  180. self.noise_apply_prob = noise_apply_prob
  181. self.split_with_space = split_with_space
  182. self.seg_dict = None
  183. if seg_dict_file is not None:
  184. self.seg_dict = {}
  185. with open(seg_dict_file) as f:
  186. lines = f.readlines()
  187. for line in lines:
  188. s = line.strip().split()
  189. key = s[0]
  190. value = s[1:]
  191. self.seg_dict[key] = " ".join(value)
  192. if token_type is not None:
  193. if token_list is None:
  194. raise ValueError("token_list is required if token_type is not None")
  195. self.text_cleaner = TextCleaner(text_cleaner)
  196. self.tokenizer = build_tokenizer(
  197. token_type=token_type,
  198. bpemodel=bpemodel,
  199. delimiter=delimiter,
  200. space_symbol=space_symbol,
  201. non_linguistic_symbols=non_linguistic_symbols,
  202. g2p_type=g2p_type,
  203. )
  204. self.token_id_converter = TokenIDConverter(
  205. token_list=token_list,
  206. unk_symbol=unk_symbol,
  207. )
  208. else:
  209. self.text_cleaner = None
  210. self.tokenizer = None
  211. self.token_id_converter = None
  212. if train and rir_scp is not None:
  213. self.rirs = []
  214. with open(rir_scp, "r", encoding="utf-8") as f:
  215. for line in f:
  216. sps = line.strip().split(None, 1)
  217. if len(sps) == 1:
  218. self.rirs.append(sps[0])
  219. else:
  220. self.rirs.append(sps[1])
  221. else:
  222. self.rirs = None
  223. if train and noise_scp is not None:
  224. self.noises = []
  225. with open(noise_scp, "r", encoding="utf-8") as f:
  226. for line in f:
  227. sps = line.strip().split(None, 1)
  228. if len(sps) == 1:
  229. self.noises.append(sps[0])
  230. else:
  231. self.noises.append(sps[1])
  232. sps = noise_db_range.split("_")
  233. if len(sps) == 1:
  234. self.noise_db_low, self.noise_db_high = float(sps[0])
  235. elif len(sps) == 2:
  236. self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
  237. else:
  238. raise ValueError(
  239. "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
  240. )
  241. else:
  242. self.noises = None
  243. def _speech_process(
  244. self, data: Dict[str, Union[str, np.ndarray]]
  245. ) -> Dict[str, Union[str, np.ndarray]]:
  246. assert check_argument_types()
  247. if self.speech_name in data:
  248. if self.train and (self.rirs is not None or self.noises is not None):
  249. speech = data[self.speech_name]
  250. nsamples = len(speech)
  251. # speech: (Nmic, Time)
  252. if speech.ndim == 1:
  253. speech = speech[None, :]
  254. else:
  255. speech = speech.T
  256. # Calc power on non shlence region
  257. power = (speech[detect_non_silence(speech)] ** 2).mean()
  258. # 1. Convolve RIR
  259. if self.rirs is not None and self.rir_apply_prob >= np.random.random():
  260. rir_path = np.random.choice(self.rirs)
  261. if rir_path is not None:
  262. rir, _ = soundfile.read(
  263. rir_path, dtype=np.float64, always_2d=True
  264. )
  265. # rir: (Nmic, Time)
  266. rir = rir.T
  267. # speech: (Nmic, Time)
  268. # Note that this operation doesn't change the signal length
  269. speech = scipy.signal.convolve(speech, rir, mode="full")[
  270. :, : speech.shape[1]
  271. ]
  272. # Reverse mean power to the original power
  273. power2 = (speech[detect_non_silence(speech)] ** 2).mean()
  274. speech = np.sqrt(power / max(power2, 1e-10)) * speech
  275. # 2. Add Noise
  276. if (
  277. self.noises is not None
  278. and self.noise_apply_prob >= np.random.random()
  279. ):
  280. noise_path = np.random.choice(self.noises)
  281. if noise_path is not None:
  282. noise_db = np.random.uniform(
  283. self.noise_db_low, self.noise_db_high
  284. )
  285. with soundfile.SoundFile(noise_path) as f:
  286. if f.frames == nsamples:
  287. noise = f.read(dtype=np.float64, always_2d=True)
  288. elif f.frames < nsamples:
  289. offset = np.random.randint(0, nsamples - f.frames)
  290. # noise: (Time, Nmic)
  291. noise = f.read(dtype=np.float64, always_2d=True)
  292. # Repeat noise
  293. noise = np.pad(
  294. noise,
  295. [(offset, nsamples - f.frames - offset), (0, 0)],
  296. mode="wrap",
  297. )
  298. else:
  299. offset = np.random.randint(0, f.frames - nsamples)
  300. f.seek(offset)
  301. # noise: (Time, Nmic)
  302. noise = f.read(
  303. nsamples, dtype=np.float64, always_2d=True
  304. )
  305. if len(noise) != nsamples:
  306. raise RuntimeError(f"Something wrong: {noise_path}")
  307. # noise: (Nmic, Time)
  308. noise = noise.T
  309. noise_power = (noise ** 2).mean()
  310. scale = (
  311. 10 ** (-noise_db / 20)
  312. * np.sqrt(power)
  313. / np.sqrt(max(noise_power, 1e-10))
  314. )
  315. speech = speech + scale * noise
  316. speech = speech.T
  317. ma = np.max(np.abs(speech))
  318. if ma > 1.0:
  319. speech /= ma
  320. data[self.speech_name] = speech
  321. if self.speech_volume_normalize is not None:
  322. speech = data[self.speech_name]
  323. ma = np.max(np.abs(speech))
  324. data[self.speech_name] = speech * self.speech_volume_normalize / ma
  325. assert check_return_type(data)
  326. return data
  327. def _text_process(
  328. self, data: Dict[str, Union[str, np.ndarray]]
  329. ) -> Dict[str, np.ndarray]:
  330. if self.text_name in data and self.tokenizer is not None:
  331. text = data[self.text_name]
  332. text = self.text_cleaner(text)
  333. if self.split_with_space:
  334. tokens = text.strip().split(" ")
  335. if self.seg_dict is not None:
  336. tokens = seg_tokenize(tokens, self.seg_dict)
  337. else:
  338. tokens = self.tokenizer.text2tokens(text)
  339. text_ints = self.token_id_converter.tokens2ids(tokens)
  340. data[self.text_name] = np.array(text_ints, dtype=np.int64)
  341. assert check_return_type(data)
  342. return data
  343. def __call__(
  344. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  345. ) -> Dict[str, np.ndarray]:
  346. assert check_argument_types()
  347. data = self._speech_process(data)
  348. data = self._text_process(data)
  349. return data
  350. ## FIXME
  351. class LMPreprocessor(CommonPreprocessor):
  352. def __init__(
  353. self,
  354. train: bool,
  355. token_type: str = None,
  356. token_list: Union[Path, str, Iterable[str]] = None,
  357. bpemodel: Union[Path, str, Iterable[str]] = None,
  358. text_cleaner: Collection[str] = None,
  359. g2p_type: str = None,
  360. unk_symbol: str = "<unk>",
  361. space_symbol: str = "<space>",
  362. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  363. delimiter: str = None,
  364. rir_scp: str = None,
  365. rir_apply_prob: float = 1.0,
  366. noise_scp: str = None,
  367. noise_apply_prob: float = 1.0,
  368. noise_db_range: str = "3_10",
  369. speech_volume_normalize: float = None,
  370. speech_name: str = "speech",
  371. text_name: str = "text",
  372. split_with_space: bool = False,
  373. seg_dict_file: str = None,
  374. ):
  375. super().__init__(train,
  376. token_type,
  377. token_list,
  378. bpemodel,
  379. text_cleaner,
  380. g2p_type,
  381. unk_symbol,
  382. space_symbol,
  383. non_linguistic_symbols,
  384. delimiter,
  385. rir_scp,
  386. rir_apply_prob,
  387. noise_scp,
  388. noise_apply_prob,
  389. noise_db_range,
  390. speech_volume_normalize,
  391. speech_name,
  392. text_name,
  393. split_with_space,
  394. seg_dict_file,
  395. )
  396. def _text_process(
  397. self, data: Dict[str, Union[str, np.ndarray]]
  398. ) -> Dict[str, np.ndarray]:
  399. if self.text_name in data and self.tokenizer is not None:
  400. text = data[self.text_name]
  401. text = self.text_cleaner(text)
  402. if self.split_with_space:
  403. tokens = text.strip().split(" ")
  404. if self.seg_dict is not None:
  405. tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
  406. else:
  407. tokens = self.tokenizer.text2tokens(text)
  408. text_ints = self.token_id_converter.tokens2ids(tokens)
  409. data[self.text_name] = np.array(text_ints, dtype=np.int64)
  410. assert check_return_type(data)
  411. return data
  412. class CommonPreprocessor_multi(AbsPreprocessor):
  413. def __init__(
  414. self,
  415. train: bool,
  416. token_type: str = None,
  417. token_list: Union[Path, str, Iterable[str]] = None,
  418. bpemodel: Union[Path, str, Iterable[str]] = None,
  419. text_cleaner: Collection[str] = None,
  420. g2p_type: str = None,
  421. unk_symbol: str = "<unk>",
  422. space_symbol: str = "<space>",
  423. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  424. delimiter: str = None,
  425. speech_name: str = "speech",
  426. text_name: List[str] = ["text"],
  427. ):
  428. super().__init__(train)
  429. self.train = train
  430. self.speech_name = speech_name
  431. self.text_name = text_name
  432. if token_type is not None:
  433. if token_list is None:
  434. raise ValueError("token_list is required if token_type is not None")
  435. self.text_cleaner = TextCleaner(text_cleaner)
  436. self.tokenizer = build_tokenizer(
  437. token_type=token_type,
  438. bpemodel=bpemodel,
  439. delimiter=delimiter,
  440. space_symbol=space_symbol,
  441. non_linguistic_symbols=non_linguistic_symbols,
  442. g2p_type=g2p_type,
  443. )
  444. self.token_id_converter = TokenIDConverter(
  445. token_list=token_list,
  446. unk_symbol=unk_symbol,
  447. )
  448. else:
  449. self.text_cleaner = None
  450. self.tokenizer = None
  451. self.token_id_converter = None
  452. def _text_process(
  453. self, data: Dict[str, Union[str, np.ndarray]]
  454. ) -> Dict[str, np.ndarray]:
  455. for text_n in self.text_name:
  456. if text_n in data and self.tokenizer is not None:
  457. text = data[text_n]
  458. text = self.text_cleaner(text)
  459. tokens = self.tokenizer.text2tokens(text)
  460. text_ints = self.token_id_converter.tokens2ids(tokens)
  461. data[text_n] = np.array(text_ints, dtype=np.int64)
  462. assert check_return_type(data)
  463. return data
  464. def __call__(
  465. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  466. ) -> Dict[str, np.ndarray]:
  467. assert check_argument_types()
  468. if self.speech_name in data:
  469. # Nothing now: candidates:
  470. # - STFT
  471. # - Fbank
  472. # - CMVN
  473. # - Data augmentation
  474. pass
  475. data = self._text_process(data)
  476. return data
  477. class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
  478. def __init__(
  479. self,
  480. train: bool,
  481. token_type: List[str] = [None],
  482. token_list: List[Union[Path, str, Iterable[str]]] = [None],
  483. bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
  484. text_cleaner: Collection[str] = None,
  485. g2p_type: str = None,
  486. unk_symbol: str = "<unk>",
  487. space_symbol: str = "<space>",
  488. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  489. delimiter: str = None,
  490. rir_scp: str = None,
  491. rir_apply_prob: float = 1.0,
  492. noise_scp: str = None,
  493. noise_apply_prob: float = 1.0,
  494. noise_db_range: str = "3_10",
  495. speech_volume_normalize: float = None,
  496. speech_name: str = "speech",
  497. text_name: List[str] = ["text"],
  498. ):
  499. # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
  500. super().__init__(
  501. train=train,
  502. token_type=token_type[0],
  503. token_list=token_list[0],
  504. bpemodel=bpemodel[0],
  505. text_cleaner=text_cleaner,
  506. g2p_type=g2p_type,
  507. unk_symbol=unk_symbol,
  508. space_symbol=space_symbol,
  509. non_linguistic_symbols=non_linguistic_symbols,
  510. delimiter=delimiter,
  511. speech_name=speech_name,
  512. text_name=text_name[0],
  513. rir_scp=rir_scp,
  514. rir_apply_prob=rir_apply_prob,
  515. noise_scp=noise_scp,
  516. noise_apply_prob=noise_apply_prob,
  517. noise_db_range=noise_db_range,
  518. speech_volume_normalize=speech_volume_normalize,
  519. )
  520. assert (
  521. len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
  522. ), "token_type, token_list, bpemodel, or processing text_name mismatched"
  523. self.num_tokenizer = len(token_type)
  524. self.tokenizer = []
  525. self.token_id_converter = []
  526. for i in range(self.num_tokenizer):
  527. if token_type[i] is not None:
  528. if token_list[i] is None:
  529. raise ValueError("token_list is required if token_type is not None")
  530. self.tokenizer.append(
  531. build_tokenizer(
  532. token_type=token_type[i],
  533. bpemodel=bpemodel[i],
  534. delimiter=delimiter,
  535. space_symbol=space_symbol,
  536. non_linguistic_symbols=non_linguistic_symbols,
  537. g2p_type=g2p_type,
  538. )
  539. )
  540. self.token_id_converter.append(
  541. TokenIDConverter(
  542. token_list=token_list[i],
  543. unk_symbol=unk_symbol,
  544. )
  545. )
  546. else:
  547. self.tokenizer.append(None)
  548. self.token_id_converter.append(None)
  549. self.text_cleaner = TextCleaner(text_cleaner)
  550. self.text_name = text_name # override the text_name from CommonPreprocessor
  551. def _text_process(
  552. self, data: Dict[str, Union[str, np.ndarray]]
  553. ) -> Dict[str, np.ndarray]:
  554. for i in range(self.num_tokenizer):
  555. text_name = self.text_name[i]
  556. if text_name in data and self.tokenizer[i] is not None:
  557. text = data[text_name]
  558. text = self.text_cleaner(text)
  559. tokens = self.tokenizer[i].text2tokens(text)
  560. text_ints = self.token_id_converter[i].tokens2ids(tokens)
  561. data[text_name] = np.array(text_ints, dtype=np.int64)
  562. assert check_return_type(data)
  563. return data
  564. class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
  565. def __init__(
  566. self,
  567. train: bool,
  568. token_type: str = None,
  569. token_list: Union[Path, str, Iterable[str]] = None,
  570. bpemodel: Union[Path, str, Iterable[str]] = None,
  571. text_cleaner: Collection[str] = None,
  572. g2p_type: str = None,
  573. unk_symbol: str = "<unk>",
  574. space_symbol: str = "<space>",
  575. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  576. delimiter: str = None,
  577. rir_scp: str = None,
  578. rir_apply_prob: float = 1.0,
  579. noise_scp: str = None,
  580. noise_apply_prob: float = 1.0,
  581. noise_db_range: str = "3_10",
  582. speech_volume_normalize: float = None,
  583. speech_name: str = "speech",
  584. text_name: str = "text",
  585. split_text_name: str = "split_text",
  586. split_with_space: bool = False,
  587. seg_dict_file: str = None,
  588. ):
  589. super().__init__(
  590. train=train,
  591. # Force to use word.
  592. token_type="word",
  593. token_list=token_list,
  594. bpemodel=bpemodel,
  595. text_cleaner=text_cleaner,
  596. g2p_type=g2p_type,
  597. unk_symbol=unk_symbol,
  598. space_symbol=space_symbol,
  599. non_linguistic_symbols=non_linguistic_symbols,
  600. delimiter=delimiter,
  601. speech_name=speech_name,
  602. text_name=text_name,
  603. rir_scp=rir_scp,
  604. rir_apply_prob=rir_apply_prob,
  605. noise_scp=noise_scp,
  606. noise_apply_prob=noise_apply_prob,
  607. noise_db_range=noise_db_range,
  608. speech_volume_normalize=speech_volume_normalize,
  609. split_with_space=split_with_space,
  610. seg_dict_file=seg_dict_file,
  611. )
  612. # The data field name for split text.
  613. self.split_text_name = split_text_name
  614. @classmethod
  615. def split_words(cls, text: str):
  616. words = []
  617. segs = text.split()
  618. for seg in segs:
  619. # There is no space in seg.
  620. current_word = ""
  621. for c in seg:
  622. if len(c.encode()) == 1:
  623. # This is an ASCII char.
  624. current_word += c
  625. else:
  626. # This is a Chinese char.
  627. if len(current_word) > 0:
  628. words.append(current_word)
  629. current_word = ""
  630. words.append(c)
  631. if len(current_word) > 0:
  632. words.append(current_word)
  633. return words
  634. def __call__(
  635. self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
  636. ) -> Dict[str, Union[list, np.ndarray]]:
  637. assert check_argument_types()
  638. # Split words.
  639. if isinstance(data[self.text_name], str):
  640. split_text = self.split_words(data[self.text_name])
  641. else:
  642. split_text = data[self.text_name]
  643. data[self.text_name] = " ".join(split_text)
  644. data = self._speech_process(data)
  645. data = self._text_process(data)
  646. data[self.split_text_name] = split_text
  647. return data
  648. def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
  649. result = data[self.split_text_name]
  650. del data[self.split_text_name]
  651. return result
  652. class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
  653. def __init__(
  654. self,
  655. train: bool,
  656. token_type: List[str] = [None],
  657. token_list: List[Union[Path, str, Iterable[str]]] = [None],
  658. bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
  659. text_cleaner: Collection[str] = None,
  660. g2p_type: str = None,
  661. unk_symbol: str = "<unk>",
  662. space_symbol: str = "<space>",
  663. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  664. delimiter: str = None,
  665. rir_scp: str = None,
  666. rir_apply_prob: float = 1.0,
  667. noise_scp: str = None,
  668. noise_apply_prob: float = 1.0,
  669. noise_db_range: str = "3_10",
  670. speech_volume_normalize: float = None,
  671. speech_name: str = "speech",
  672. text_name: List[str] = ["text"],
  673. vad_name: str = "vad_indexes",
  674. ):
  675. # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
  676. super().__init__(
  677. train=train,
  678. token_type=token_type[0],
  679. token_list=token_list[0],
  680. bpemodel=bpemodel[0],
  681. text_cleaner=text_cleaner,
  682. g2p_type=g2p_type,
  683. unk_symbol=unk_symbol,
  684. space_symbol=space_symbol,
  685. non_linguistic_symbols=non_linguistic_symbols,
  686. delimiter=delimiter,
  687. speech_name=speech_name,
  688. text_name=text_name[0],
  689. rir_scp=rir_scp,
  690. rir_apply_prob=rir_apply_prob,
  691. noise_scp=noise_scp,
  692. noise_apply_prob=noise_apply_prob,
  693. noise_db_range=noise_db_range,
  694. speech_volume_normalize=speech_volume_normalize,
  695. )
  696. assert (
  697. len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
  698. ), "token_type, token_list, bpemodel, or processing text_name mismatched"
  699. self.num_tokenizer = len(token_type)
  700. self.tokenizer = []
  701. self.token_id_converter = []
  702. for i in range(self.num_tokenizer):
  703. if token_type[i] is not None:
  704. if token_list[i] is None:
  705. raise ValueError("token_list is required if token_type is not None")
  706. self.tokenizer.append(
  707. build_tokenizer(
  708. token_type=token_type[i],
  709. bpemodel=bpemodel[i],
  710. delimiter=delimiter,
  711. space_symbol=space_symbol,
  712. non_linguistic_symbols=non_linguistic_symbols,
  713. g2p_type=g2p_type,
  714. )
  715. )
  716. self.token_id_converter.append(
  717. TokenIDConverter(
  718. token_list=token_list[i],
  719. unk_symbol=unk_symbol,
  720. )
  721. )
  722. else:
  723. self.tokenizer.append(None)
  724. self.token_id_converter.append(None)
  725. self.text_cleaner = TextCleaner(text_cleaner)
  726. self.text_name = text_name # override the text_name from CommonPreprocessor
  727. self.vad_name = vad_name
  728. def _text_process(
  729. self, data: Dict[str, Union[str, np.ndarray]]
  730. ) -> Dict[str, np.ndarray]:
  731. for i in range(self.num_tokenizer):
  732. text_name = self.text_name[i]
  733. #import pdb; pdb.set_trace()
  734. if text_name in data and self.tokenizer[i] is not None:
  735. text = data[text_name]
  736. text = self.text_cleaner(text)
  737. tokens = self.tokenizer[i].text2tokens(text)
  738. if "vad:" in tokens[-1]:
  739. vad = tokens[-1][4:]
  740. tokens = tokens[:-1]
  741. if len(vad) == 0:
  742. vad = -1
  743. else:
  744. vad = int(vad)
  745. data[self.vad_name] = np.array([vad], dtype=np.int64)
  746. text_ints = self.token_id_converter[i].tokens2ids(tokens)
  747. data[text_name] = np.array(text_ints, dtype=np.int64)
  748. return data
  749. def split_to_mini_sentence(words: list, word_limit: int = 20):
  750. assert word_limit > 1
  751. if len(words) <= word_limit:
  752. return [words]
  753. sentences = []
  754. length = len(words)
  755. sentence_len = length // word_limit
  756. for i in range(sentence_len):
  757. sentences.append(words[i * word_limit:(i + 1) * word_limit])
  758. if length % word_limit > 0:
  759. sentences.append(words[sentence_len * word_limit:])
  760. return sentences