# preprocessor.py
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from pathlib import Path
  4. from typing import Collection
  5. from typing import Dict
  6. from typing import Iterable
  7. from typing import List
  8. from typing import Union
  9. import numpy as np
  10. import scipy.signal
  11. import soundfile
  12. from typeguard import check_argument_types
  13. from typeguard import check_return_type
  14. from funasr.text.build_tokenizer import build_tokenizer
  15. from funasr.text.cleaner import TextCleaner
  16. from funasr.text.token_id_converter import TokenIDConverter
class AbsPreprocessor(ABC):
    """Abstract interface for sample-level data preprocessors.

    Subclasses transform one sample — a uid plus a name->value dict — into
    a dict of numpy arrays ready for batching.
    """

    def __init__(self, train: bool):
        # True when used in the training pipeline (subclasses use this to
        # enable training-only augmentations).
        self.train = train

    @abstractmethod
    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        raise NotImplementedError
  25. def forward_segment(text, dic):
  26. word_list = []
  27. i = 0
  28. while i < len(text):
  29. longest_word = text[i]
  30. for j in range(i + 1, len(text) + 1):
  31. word = text[i:j]
  32. if word in dic:
  33. if len(word) > len(longest_word):
  34. longest_word = word
  35. word_list.append(longest_word)
  36. i += len(longest_word)
  37. return word_list
  38. def seg_tokenize(txt, seg_dict):
  39. out_txt = ""
  40. for word in txt:
  41. if word in seg_dict:
  42. out_txt += seg_dict[word] + " "
  43. else:
  44. out_txt += "<unk>" + " "
  45. return out_txt.strip().split()
  46. def seg_tokenize_wo_pattern(txt, seg_dict):
  47. out_txt = ""
  48. for word in txt:
  49. if word in seg_dict:
  50. out_txt += seg_dict[word] + " "
  51. else:
  52. out_txt += "<unk>" + " "
  53. return out_txt.strip().split()
  54. def framing(
  55. x,
  56. frame_length: int = 512,
  57. frame_shift: int = 256,
  58. centered: bool = True,
  59. padded: bool = True,
  60. ):
  61. if x.size == 0:
  62. raise ValueError("Input array size is zero")
  63. if frame_length < 1:
  64. raise ValueError("frame_length must be a positive integer")
  65. if frame_length > x.shape[-1]:
  66. raise ValueError("frame_length is greater than input length")
  67. if 0 >= frame_shift:
  68. raise ValueError("frame_shift must be greater than 0")
  69. if centered:
  70. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
  71. (frame_length // 2, frame_length // 2)
  72. ]
  73. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  74. if padded:
  75. # Pad to integer number of windowed segments
  76. # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep,
  77. # with integer nseg
  78. nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
  79. pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
  80. x = np.pad(x, pad_shape, mode="constant", constant_values=0)
  81. # Created strided array of data segments
  82. if frame_length == 1 and frame_length == frame_shift:
  83. result = x[..., None]
  84. else:
  85. shape = x.shape[:-1] + (
  86. (x.shape[-1] - frame_length) // frame_shift + 1,
  87. frame_length,
  88. )
  89. strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
  90. result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
  91. return result
  92. def detect_non_silence(
  93. x: np.ndarray,
  94. threshold: float = 0.01,
  95. frame_length: int = 1024,
  96. frame_shift: int = 512,
  97. window: str = "boxcar",
  98. ) -> np.ndarray:
  99. """Power based voice activity detection.
  100. Args:
  101. x: (Channel, Time)
  102. >>> x = np.random.randn(1000)
  103. >>> detect = detect_non_silence(x)
  104. >>> assert x.shape == detect.shape
  105. >>> assert detect.dtype == np.bool
  106. """
  107. if x.shape[-1] < frame_length:
  108. return np.full(x.shape, fill_value=True, dtype=np.bool)
  109. if x.dtype.kind == "i":
  110. x = x.astype(np.float64)
  111. # framed_w: (C, T, F)
  112. framed_w = framing(
  113. x,
  114. frame_length=frame_length,
  115. frame_shift=frame_shift,
  116. centered=False,
  117. padded=True,
  118. )
  119. framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
  120. # power: (C, T)
  121. power = (framed_w ** 2).mean(axis=-1)
  122. # mean_power: (C, 1)
  123. mean_power = np.mean(power, axis=-1, keepdims=True)
  124. if np.all(mean_power == 0):
  125. return np.full(x.shape, fill_value=True, dtype=np.bool)
  126. # detect_frames: (C, T)
  127. detect_frames = power / mean_power > threshold
  128. # detects: (C, T, F)
  129. detects = np.broadcast_to(
  130. detect_frames[..., None], detect_frames.shape + (frame_shift,)
  131. )
  132. # detects: (C, TF)
  133. detects = detects.reshape(*detect_frames.shape[:-1], -1)
  134. # detects: (C, TF)
  135. return np.pad(
  136. detects,
  137. [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
  138. mode="edge",
  139. )
  140. class CommonPreprocessor(AbsPreprocessor):
  141. def __init__(
  142. self,
  143. train: bool,
  144. token_type: str = None,
  145. token_list: Union[Path, str, Iterable[str]] = None,
  146. bpemodel: Union[Path, str, Iterable[str]] = None,
  147. text_cleaner: Collection[str] = None,
  148. g2p_type: str = None,
  149. unk_symbol: str = "<unk>",
  150. space_symbol: str = "<space>",
  151. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  152. delimiter: str = None,
  153. rir_scp: str = None,
  154. rir_apply_prob: float = 1.0,
  155. noise_scp: str = None,
  156. noise_apply_prob: float = 1.0,
  157. noise_db_range: str = "3_10",
  158. speech_volume_normalize: float = None,
  159. speech_name: str = "speech",
  160. text_name: str = "text",
  161. split_with_space: bool = False,
  162. seg_dict_file: str = None,
  163. ):
  164. super().__init__(train)
  165. self.train = train
  166. self.speech_name = speech_name
  167. self.text_name = text_name
  168. self.speech_volume_normalize = speech_volume_normalize
  169. self.rir_apply_prob = rir_apply_prob
  170. self.noise_apply_prob = noise_apply_prob
  171. self.split_with_space = split_with_space
  172. self.seg_dict = None
  173. if seg_dict_file is not None:
  174. self.seg_dict = {}
  175. with open(seg_dict_file) as f:
  176. lines = f.readlines()
  177. for line in lines:
  178. s = line.strip().split()
  179. key = s[0]
  180. value = s[1:]
  181. self.seg_dict[key] = " ".join(value)
  182. if token_type is not None:
  183. if token_list is None:
  184. raise ValueError("token_list is required if token_type is not None")
  185. self.text_cleaner = TextCleaner(text_cleaner)
  186. self.tokenizer = build_tokenizer(
  187. token_type=token_type,
  188. bpemodel=bpemodel,
  189. delimiter=delimiter,
  190. space_symbol=space_symbol,
  191. non_linguistic_symbols=non_linguistic_symbols,
  192. g2p_type=g2p_type,
  193. )
  194. self.token_id_converter = TokenIDConverter(
  195. token_list=token_list,
  196. unk_symbol=unk_symbol,
  197. )
  198. else:
  199. self.text_cleaner = None
  200. self.tokenizer = None
  201. self.token_id_converter = None
  202. if train and rir_scp is not None:
  203. self.rirs = []
  204. with open(rir_scp, "r", encoding="utf-8") as f:
  205. for line in f:
  206. sps = line.strip().split(None, 1)
  207. if len(sps) == 1:
  208. self.rirs.append(sps[0])
  209. else:
  210. self.rirs.append(sps[1])
  211. else:
  212. self.rirs = None
  213. if train and noise_scp is not None:
  214. self.noises = []
  215. with open(noise_scp, "r", encoding="utf-8") as f:
  216. for line in f:
  217. sps = line.strip().split(None, 1)
  218. if len(sps) == 1:
  219. self.noises.append(sps[0])
  220. else:
  221. self.noises.append(sps[1])
  222. sps = noise_db_range.split("_")
  223. if len(sps) == 1:
  224. self.noise_db_low, self.noise_db_high = float(sps[0])
  225. elif len(sps) == 2:
  226. self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
  227. else:
  228. raise ValueError(
  229. "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
  230. )
  231. else:
  232. self.noises = None
  233. def _speech_process(
  234. self, data: Dict[str, Union[str, np.ndarray]]
  235. ) -> Dict[str, Union[str, np.ndarray]]:
  236. assert check_argument_types()
  237. if self.speech_name in data:
  238. if self.train and (self.rirs is not None or self.noises is not None):
  239. speech = data[self.speech_name]
  240. nsamples = len(speech)
  241. # speech: (Nmic, Time)
  242. if speech.ndim == 1:
  243. speech = speech[None, :]
  244. else:
  245. speech = speech.T
  246. # Calc power on non shlence region
  247. power = (speech[detect_non_silence(speech)] ** 2).mean()
  248. # 1. Convolve RIR
  249. if self.rirs is not None and self.rir_apply_prob >= np.random.random():
  250. rir_path = np.random.choice(self.rirs)
  251. if rir_path is not None:
  252. rir, _ = soundfile.read(
  253. rir_path, dtype=np.float64, always_2d=True
  254. )
  255. # rir: (Nmic, Time)
  256. rir = rir.T
  257. # speech: (Nmic, Time)
  258. # Note that this operation doesn't change the signal length
  259. speech = scipy.signal.convolve(speech, rir, mode="full")[
  260. :, : speech.shape[1]
  261. ]
  262. # Reverse mean power to the original power
  263. power2 = (speech[detect_non_silence(speech)] ** 2).mean()
  264. speech = np.sqrt(power / max(power2, 1e-10)) * speech
  265. # 2. Add Noise
  266. if (
  267. self.noises is not None
  268. and self.noise_apply_prob >= np.random.random()
  269. ):
  270. noise_path = np.random.choice(self.noises)
  271. if noise_path is not None:
  272. noise_db = np.random.uniform(
  273. self.noise_db_low, self.noise_db_high
  274. )
  275. with soundfile.SoundFile(noise_path) as f:
  276. if f.frames == nsamples:
  277. noise = f.read(dtype=np.float64, always_2d=True)
  278. elif f.frames < nsamples:
  279. offset = np.random.randint(0, nsamples - f.frames)
  280. # noise: (Time, Nmic)
  281. noise = f.read(dtype=np.float64, always_2d=True)
  282. # Repeat noise
  283. noise = np.pad(
  284. noise,
  285. [(offset, nsamples - f.frames - offset), (0, 0)],
  286. mode="wrap",
  287. )
  288. else:
  289. offset = np.random.randint(0, f.frames - nsamples)
  290. f.seek(offset)
  291. # noise: (Time, Nmic)
  292. noise = f.read(
  293. nsamples, dtype=np.float64, always_2d=True
  294. )
  295. if len(noise) != nsamples:
  296. raise RuntimeError(f"Something wrong: {noise_path}")
  297. # noise: (Nmic, Time)
  298. noise = noise.T
  299. noise_power = (noise ** 2).mean()
  300. scale = (
  301. 10 ** (-noise_db / 20)
  302. * np.sqrt(power)
  303. / np.sqrt(max(noise_power, 1e-10))
  304. )
  305. speech = speech + scale * noise
  306. speech = speech.T
  307. ma = np.max(np.abs(speech))
  308. if ma > 1.0:
  309. speech /= ma
  310. data[self.speech_name] = speech
  311. if self.speech_volume_normalize is not None:
  312. speech = data[self.speech_name]
  313. ma = np.max(np.abs(speech))
  314. data[self.speech_name] = speech * self.speech_volume_normalize / ma
  315. assert check_return_type(data)
  316. return data
  317. def _text_process(
  318. self, data: Dict[str, Union[str, np.ndarray]]
  319. ) -> Dict[str, np.ndarray]:
  320. if self.text_name in data and self.tokenizer is not None:
  321. text = data[self.text_name]
  322. text = self.text_cleaner(text)
  323. if self.split_with_space:
  324. tokens = text.strip().split(" ")
  325. if self.seg_dict is not None:
  326. tokens = forward_segment("".join(tokens), self.seg_dict)
  327. tokens = seg_tokenize(tokens, self.seg_dict)
  328. else:
  329. tokens = self.tokenizer.text2tokens(text)
  330. text_ints = self.token_id_converter.tokens2ids(tokens)
  331. data[self.text_name] = np.array(text_ints, dtype=np.int64)
  332. assert check_return_type(data)
  333. return data
  334. def __call__(
  335. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  336. ) -> Dict[str, np.ndarray]:
  337. assert check_argument_types()
  338. data = self._speech_process(data)
  339. data = self._text_process(data)
  340. return data
# FIXME: LMPreprocessor below repeats CommonPreprocessor's full parameter
# list; keep the two signatures in sync when either changes.
  342. class LMPreprocessor(CommonPreprocessor):
  343. def __init__(
  344. self,
  345. train: bool,
  346. token_type: str = None,
  347. token_list: Union[Path, str, Iterable[str]] = None,
  348. bpemodel: Union[Path, str, Iterable[str]] = None,
  349. text_cleaner: Collection[str] = None,
  350. g2p_type: str = None,
  351. unk_symbol: str = "<unk>",
  352. space_symbol: str = "<space>",
  353. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  354. delimiter: str = None,
  355. rir_scp: str = None,
  356. rir_apply_prob: float = 1.0,
  357. noise_scp: str = None,
  358. noise_apply_prob: float = 1.0,
  359. noise_db_range: str = "3_10",
  360. speech_volume_normalize: float = None,
  361. speech_name: str = "speech",
  362. text_name: str = "text",
  363. split_with_space: bool = False,
  364. seg_dict_file: str = None,
  365. ):
  366. super().__init__(train,
  367. token_type,
  368. token_list,
  369. bpemodel,
  370. text_cleaner,
  371. g2p_type,
  372. unk_symbol,
  373. space_symbol,
  374. non_linguistic_symbols,
  375. delimiter,
  376. rir_scp,
  377. rir_apply_prob,
  378. noise_scp,
  379. noise_apply_prob,
  380. noise_db_range,
  381. speech_volume_normalize,
  382. speech_name,
  383. text_name,
  384. split_with_space,
  385. seg_dict_file,
  386. )
  387. def _text_process(
  388. self, data: Dict[str, Union[str, np.ndarray]]
  389. ) -> Dict[str, np.ndarray]:
  390. if self.text_name in data and self.tokenizer is not None:
  391. text = data[self.text_name]
  392. text = self.text_cleaner(text)
  393. if self.split_with_space:
  394. tokens = text.strip().split(" ")
  395. if self.seg_dict is not None:
  396. tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
  397. else:
  398. tokens = self.tokenizer.text2tokens(text)
  399. text_ints = self.token_id_converter.tokens2ids(tokens)
  400. data[self.text_name] = np.array(text_ints, dtype=np.int64)
  401. assert check_return_type(data)
  402. return data
  403. class CommonPreprocessor_multi(AbsPreprocessor):
  404. def __init__(
  405. self,
  406. train: bool,
  407. token_type: str = None,
  408. token_list: Union[Path, str, Iterable[str]] = None,
  409. bpemodel: Union[Path, str, Iterable[str]] = None,
  410. text_cleaner: Collection[str] = None,
  411. g2p_type: str = None,
  412. unk_symbol: str = "<unk>",
  413. space_symbol: str = "<space>",
  414. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  415. delimiter: str = None,
  416. speech_name: str = "speech",
  417. text_name: List[str] = ["text"],
  418. ):
  419. super().__init__(train)
  420. self.train = train
  421. self.speech_name = speech_name
  422. self.text_name = text_name
  423. if token_type is not None:
  424. if token_list is None:
  425. raise ValueError("token_list is required if token_type is not None")
  426. self.text_cleaner = TextCleaner(text_cleaner)
  427. self.tokenizer = build_tokenizer(
  428. token_type=token_type,
  429. bpemodel=bpemodel,
  430. delimiter=delimiter,
  431. space_symbol=space_symbol,
  432. non_linguistic_symbols=non_linguistic_symbols,
  433. g2p_type=g2p_type,
  434. )
  435. self.token_id_converter = TokenIDConverter(
  436. token_list=token_list,
  437. unk_symbol=unk_symbol,
  438. )
  439. else:
  440. self.text_cleaner = None
  441. self.tokenizer = None
  442. self.token_id_converter = None
  443. def _text_process(
  444. self, data: Dict[str, Union[str, np.ndarray]]
  445. ) -> Dict[str, np.ndarray]:
  446. for text_n in self.text_name:
  447. if text_n in data and self.tokenizer is not None:
  448. text = data[text_n]
  449. text = self.text_cleaner(text)
  450. tokens = self.tokenizer.text2tokens(text)
  451. text_ints = self.token_id_converter.tokens2ids(tokens)
  452. data[text_n] = np.array(text_ints, dtype=np.int64)
  453. assert check_return_type(data)
  454. return data
  455. def __call__(
  456. self, uid: str, data: Dict[str, Union[str, np.ndarray]]
  457. ) -> Dict[str, np.ndarray]:
  458. assert check_argument_types()
  459. if self.speech_name in data:
  460. # Nothing now: candidates:
  461. # - STFT
  462. # - Fbank
  463. # - CMVN
  464. # - Data augmentation
  465. pass
  466. data = self._text_process(data)
  467. return data
  468. class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
  469. def __init__(
  470. self,
  471. train: bool,
  472. token_type: List[str] = [None],
  473. token_list: List[Union[Path, str, Iterable[str]]] = [None],
  474. bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
  475. text_cleaner: Collection[str] = None,
  476. g2p_type: str = None,
  477. unk_symbol: str = "<unk>",
  478. space_symbol: str = "<space>",
  479. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  480. delimiter: str = None,
  481. rir_scp: str = None,
  482. rir_apply_prob: float = 1.0,
  483. noise_scp: str = None,
  484. noise_apply_prob: float = 1.0,
  485. noise_db_range: str = "3_10",
  486. speech_volume_normalize: float = None,
  487. speech_name: str = "speech",
  488. text_name: List[str] = ["text"],
  489. ):
  490. # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
  491. super().__init__(
  492. train=train,
  493. token_type=token_type[0],
  494. token_list=token_list[0],
  495. bpemodel=bpemodel[0],
  496. text_cleaner=text_cleaner,
  497. g2p_type=g2p_type,
  498. unk_symbol=unk_symbol,
  499. space_symbol=space_symbol,
  500. non_linguistic_symbols=non_linguistic_symbols,
  501. delimiter=delimiter,
  502. speech_name=speech_name,
  503. text_name=text_name[0],
  504. rir_scp=rir_scp,
  505. rir_apply_prob=rir_apply_prob,
  506. noise_scp=noise_scp,
  507. noise_apply_prob=noise_apply_prob,
  508. noise_db_range=noise_db_range,
  509. speech_volume_normalize=speech_volume_normalize,
  510. )
  511. assert (
  512. len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
  513. ), "token_type, token_list, bpemodel, or processing text_name mismatched"
  514. self.num_tokenizer = len(token_type)
  515. self.tokenizer = []
  516. self.token_id_converter = []
  517. for i in range(self.num_tokenizer):
  518. if token_type[i] is not None:
  519. if token_list[i] is None:
  520. raise ValueError("token_list is required if token_type is not None")
  521. self.tokenizer.append(
  522. build_tokenizer(
  523. token_type=token_type[i],
  524. bpemodel=bpemodel[i],
  525. delimiter=delimiter,
  526. space_symbol=space_symbol,
  527. non_linguistic_symbols=non_linguistic_symbols,
  528. g2p_type=g2p_type,
  529. )
  530. )
  531. self.token_id_converter.append(
  532. TokenIDConverter(
  533. token_list=token_list[i],
  534. unk_symbol=unk_symbol,
  535. )
  536. )
  537. else:
  538. self.tokenizer.append(None)
  539. self.token_id_converter.append(None)
  540. self.text_cleaner = TextCleaner(text_cleaner)
  541. self.text_name = text_name # override the text_name from CommonPreprocessor
  542. def _text_process(
  543. self, data: Dict[str, Union[str, np.ndarray]]
  544. ) -> Dict[str, np.ndarray]:
  545. for i in range(self.num_tokenizer):
  546. text_name = self.text_name[i]
  547. if text_name in data and self.tokenizer[i] is not None:
  548. text = data[text_name]
  549. text = self.text_cleaner(text)
  550. tokens = self.tokenizer[i].text2tokens(text)
  551. text_ints = self.token_id_converter[i].tokens2ids(tokens)
  552. data[text_name] = np.array(text_ints, dtype=np.int64)
  553. assert check_return_type(data)
  554. return data
  555. class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
  556. def __init__(
  557. self,
  558. train: bool,
  559. token_type: str = None,
  560. token_list: Union[Path, str, Iterable[str]] = None,
  561. bpemodel: Union[Path, str, Iterable[str]] = None,
  562. text_cleaner: Collection[str] = None,
  563. g2p_type: str = None,
  564. unk_symbol: str = "<unk>",
  565. space_symbol: str = "<space>",
  566. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  567. delimiter: str = None,
  568. rir_scp: str = None,
  569. rir_apply_prob: float = 1.0,
  570. noise_scp: str = None,
  571. noise_apply_prob: float = 1.0,
  572. noise_db_range: str = "3_10",
  573. speech_volume_normalize: float = None,
  574. speech_name: str = "speech",
  575. text_name: str = "text",
  576. split_text_name: str = "split_text",
  577. split_with_space: bool = False,
  578. seg_dict_file: str = None,
  579. ):
  580. super().__init__(
  581. train=train,
  582. # Force to use word.
  583. token_type="word",
  584. token_list=token_list,
  585. bpemodel=bpemodel,
  586. text_cleaner=text_cleaner,
  587. g2p_type=g2p_type,
  588. unk_symbol=unk_symbol,
  589. space_symbol=space_symbol,
  590. non_linguistic_symbols=non_linguistic_symbols,
  591. delimiter=delimiter,
  592. speech_name=speech_name,
  593. text_name=text_name,
  594. rir_scp=rir_scp,
  595. rir_apply_prob=rir_apply_prob,
  596. noise_scp=noise_scp,
  597. noise_apply_prob=noise_apply_prob,
  598. noise_db_range=noise_db_range,
  599. speech_volume_normalize=speech_volume_normalize,
  600. split_with_space=split_with_space,
  601. seg_dict_file=seg_dict_file,
  602. )
  603. # The data field name for split text.
  604. self.split_text_name = split_text_name
  605. @classmethod
  606. def split_words(cls, text: str):
  607. words = []
  608. segs = text.split()
  609. for seg in segs:
  610. # There is no space in seg.
  611. current_word = ""
  612. for c in seg:
  613. if len(c.encode()) == 1:
  614. # This is an ASCII char.
  615. current_word += c
  616. else:
  617. # This is a Chinese char.
  618. if len(current_word) > 0:
  619. words.append(current_word)
  620. current_word = ""
  621. words.append(c)
  622. if len(current_word) > 0:
  623. words.append(current_word)
  624. return words
  625. def __call__(
  626. self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
  627. ) -> Dict[str, Union[list, np.ndarray]]:
  628. assert check_argument_types()
  629. # Split words.
  630. if isinstance(data[self.text_name], str):
  631. split_text = self.split_words(data[self.text_name])
  632. else:
  633. split_text = data[self.text_name]
  634. data[self.text_name] = " ".join(split_text)
  635. data = self._speech_process(data)
  636. data = self._text_process(data)
  637. data[self.split_text_name] = split_text
  638. return data
  639. def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
  640. result = data[self.split_text_name]
  641. del data[self.split_text_name]
  642. return result
  643. class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
  644. def __init__(
  645. self,
  646. train: bool,
  647. token_type: List[str] = [None],
  648. token_list: List[Union[Path, str, Iterable[str]]] = [None],
  649. bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
  650. text_cleaner: Collection[str] = None,
  651. g2p_type: str = None,
  652. unk_symbol: str = "<unk>",
  653. space_symbol: str = "<space>",
  654. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  655. delimiter: str = None,
  656. rir_scp: str = None,
  657. rir_apply_prob: float = 1.0,
  658. noise_scp: str = None,
  659. noise_apply_prob: float = 1.0,
  660. noise_db_range: str = "3_10",
  661. speech_volume_normalize: float = None,
  662. speech_name: str = "speech",
  663. text_name: List[str] = ["text"],
  664. vad_name: str = "vad_indexes",
  665. ):
  666. # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
  667. super().__init__(
  668. train=train,
  669. token_type=token_type[0],
  670. token_list=token_list[0],
  671. bpemodel=bpemodel[0],
  672. text_cleaner=text_cleaner,
  673. g2p_type=g2p_type,
  674. unk_symbol=unk_symbol,
  675. space_symbol=space_symbol,
  676. non_linguistic_symbols=non_linguistic_symbols,
  677. delimiter=delimiter,
  678. speech_name=speech_name,
  679. text_name=text_name[0],
  680. rir_scp=rir_scp,
  681. rir_apply_prob=rir_apply_prob,
  682. noise_scp=noise_scp,
  683. noise_apply_prob=noise_apply_prob,
  684. noise_db_range=noise_db_range,
  685. speech_volume_normalize=speech_volume_normalize,
  686. )
  687. assert (
  688. len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
  689. ), "token_type, token_list, bpemodel, or processing text_name mismatched"
  690. self.num_tokenizer = len(token_type)
  691. self.tokenizer = []
  692. self.token_id_converter = []
  693. for i in range(self.num_tokenizer):
  694. if token_type[i] is not None:
  695. if token_list[i] is None:
  696. raise ValueError("token_list is required if token_type is not None")
  697. self.tokenizer.append(
  698. build_tokenizer(
  699. token_type=token_type[i],
  700. bpemodel=bpemodel[i],
  701. delimiter=delimiter,
  702. space_symbol=space_symbol,
  703. non_linguistic_symbols=non_linguistic_symbols,
  704. g2p_type=g2p_type,
  705. )
  706. )
  707. self.token_id_converter.append(
  708. TokenIDConverter(
  709. token_list=token_list[i],
  710. unk_symbol=unk_symbol,
  711. )
  712. )
  713. else:
  714. self.tokenizer.append(None)
  715. self.token_id_converter.append(None)
  716. self.text_cleaner = TextCleaner(text_cleaner)
  717. self.text_name = text_name # override the text_name from CommonPreprocessor
  718. self.vad_name = vad_name
  719. def _text_process(
  720. self, data: Dict[str, Union[str, np.ndarray]]
  721. ) -> Dict[str, np.ndarray]:
  722. for i in range(self.num_tokenizer):
  723. text_name = self.text_name[i]
  724. if text_name in data and self.tokenizer[i] is not None:
  725. text = data[text_name]
  726. text = self.text_cleaner(text)
  727. tokens = self.tokenizer[i].text2tokens(text)
  728. if "vad:" in tokens[-1]:
  729. vad = tokens[-1][4:]
  730. tokens = tokens[:-1]
  731. if len(vad) == 0:
  732. vad = -1
  733. else:
  734. vad = int(vad)
  735. data[self.vad_name] = np.array([vad], dtype=np.int64)
  736. text_ints = self.token_id_converter[i].tokens2ids(tokens)
  737. data[text_name] = np.array(text_ints, dtype=np.int64)
  738. def split_to_mini_sentence(words: list, word_limit: int = 20):
  739. assert word_limit > 1
  740. if len(words) <= word_limit:
  741. return [words]
  742. sentences = []
  743. length = len(words)
  744. sentence_len = length // word_limit
  745. for i in range(sentence_len):
  746. sentences.append(words[i * word_limit:(i + 1) * word_limit])
  747. if length % word_limit > 0:
  748. sentences.append(words[sentence_len * word_limit:])
  749. return sentences
  750. def build_preprocess(args, train):
  751. if not args.use_preprocessor:
  752. return None
  753. if args.task_name in ["asr", "data2vec", "diar", "sv"]:
  754. retval = CommonPreprocessor(
  755. train=train,
  756. token_type=args.token_type,
  757. token_list=args.token_list,
  758. bpemodel=args.bpemodel,
  759. non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
  760. text_cleaner=args.cleaner,
  761. g2p_type=args.g2p,
  762. split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  763. seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  764. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  765. rir_apply_prob=args.rir_apply_prob if hasattr(args, "rir_apply_prob") else 1.0,
  766. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  767. noise_apply_prob=args.noise_apply_prob if hasattr(args, "noise_apply_prob") else 1.0,
  768. noise_db_range=args.noise_db_range if hasattr(args, "noise_db_range") else "13_15",
  769. speech_volume_normalize=args.speech_volume_normalize if hasattr(args, "rir_scp") else None,
  770. )
  771. elif args.task_name == "punc":
  772. token_types = [args.token_type, args.token_type]
  773. token_lists = [args.token_list, args.punc_list]
  774. bpemodels = [args.bpemodel, args.bpemodel]
  775. text_names = ["text", "punc"]
  776. retval = PuncTrainTokenizerCommonPreprocessor(
  777. train=train,
  778. token_type=token_types,
  779. token_list=token_lists,
  780. bpemodel=bpemodels,
  781. text_cleaner=args.cleaner,
  782. g2p_type=args.g2p,
  783. text_name=text_names,
  784. non_linguistic_symbols=args.non_linguistic_symbols,
  785. )
  786. elif args.task_name == "lm":
  787. retval = LMPreprocessor(
  788. train=train,
  789. token_type=args.token_type,
  790. token_list=args.token_list,
  791. bpemodel=args.bpemodel,
  792. text_cleaner=args.cleaner,
  793. g2p_type=args.g2p,
  794. text_name="text",
  795. non_linguistic_symbols=args.non_linguistic_symbols,
  796. split_with_space=args.split_with_space,
  797. seg_dict_file=args.seg_dict_file
  798. )
  799. elif args.task_name == "vad":
  800. retval = None
  801. else:
  802. raise ValueError(f"Not supported task={args.task_name}")
  803. return retval