preprocessor.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878
  1. import re
  2. from abc import ABC
  3. from abc import abstractmethod
  4. from pathlib import Path
  5. from typing import Collection
  6. from typing import Dict
  7. from typing import Iterable
  8. from typing import List
  9. from typing import Union
  10. import numpy as np
  11. import scipy.signal
  12. import soundfile
  13. import jieba
  14. from funasr.text.build_tokenizer import build_tokenizer
  15. from funasr.text.cleaner import TextCleaner
  16. from funasr.text.token_id_converter import TokenIDConverter
  class AbsPreprocessor(ABC):
      """Interface for per-example preprocessors.

      A concrete subclass is called as ``preprocessor(uid, data)`` and returns
      the processed ``data`` dict. The base class only remembers whether it is
      used for training.
      """

      def __init__(self, train: bool):
          # Subclasses read this flag (e.g. CommonPreprocessor gates its
          # data augmentation on it).
          self.train = train

      @abstractmethod
      def __call__(
          self, uid: str, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          """Process the example ``data`` identified by ``uid``."""
          raise NotImplementedError
  def forward_segment(text, dic):
      """Segment ``text`` by greedy forward maximum matching against ``dic``.

      Starting from the left, the longest slice of ``text`` that is contained in
      ``dic`` is taken as the next piece. When no slice longer than one element
      is found, the single leading element is emitted on its own.

      Args:
          text: Sequence to segment (typically a ``str``).
          dic: Any container supporting ``in`` that holds the known words.

      Returns:
          List of consecutive pieces covering ``text`` from left to right.
      """
      pieces = []
      start, total = 0, len(text)
      while start < total:
          best = text[start]  # fallback when nothing longer is in the dictionary
          for stop in range(start + 1, total + 1):
              candidate = text[start:stop]
              if candidate in dic and len(candidate) > len(best):
                  best = candidate
          pieces.append(best)
          start += len(best)
      return pieces
  def seg_tokenize(txt, seg_dict):
      """Map words to sub-word units through ``seg_dict``.

      Each word is lower-cased and looked up in ``seg_dict``. A word that is
      missing but consists only of characters in U+4E00..U+9FA5 or ASCII digits
      is looked up character by character. Anything still unresolved becomes
      ``"<unk>"``.

      Args:
          txt: Iterable of words.
          seg_dict: Mapping from word/character to a whitespace-separated string
              of units.

      Returns:
          Flat list of units.
      """
      cjk_or_digit = re.compile(r'^[\u4E00-\u9FA50-9]+$')
      units = []
      for raw in txt:
          key = raw.lower()
          if key in seg_dict:
              units.extend(seg_dict[key].split())
          elif cjk_or_digit.match(key):
              # Fall back to per-character lookup for CJK / digit strings.
              for ch in key:
                  units.extend(seg_dict[ch].split() if ch in seg_dict else ["<unk>"])
          else:
              units.append("<unk>")
      return units
  def seg_tokenize_wo_pattern(txt, seg_dict):
      """Map each word through ``seg_dict`` with exact (case-sensitive) lookup.

      Unlike ``seg_tokenize`` there is no lower-casing and no per-character
      fallback: every word missing from ``seg_dict`` becomes ``"<unk>"``.

      Returns:
          Flat list of the whitespace-split units.
      """
      return [
          unit
          for word in txt
          for unit in (seg_dict[word].split() if word in seg_dict else ["<unk>"])
      ]
  def framing(
      x,
      frame_length: int = 512,
      frame_shift: int = 256,
      centered: bool = True,
      padded: bool = True,
  ):
      """Slice the last axis of ``x`` into (possibly overlapping) frames.

      Args:
          x: numpy array; frames are taken along its last axis.
          frame_length: Number of samples per frame.
          frame_shift: Hop between the starts of consecutive frames.
          centered: If True, first zero-pad ``frame_length // 2`` samples on
              both ends of the last axis.
          padded: If True, zero-pad the end of the last axis before framing
              (see the comment on the padding formula below).

      Returns:
          Array of shape ``x.shape[:-1] + (n_frames, frame_length)`` built with
          ``as_strided`` (or ``x[..., None]`` when length and shift are both 1).
          It shares memory with the (padded) input, and overlapping frames share
          memory with each other, so in-place writes affect several frames.

      Raises:
          ValueError: for empty input, ``frame_length < 1``, ``frame_length``
              longer than the last axis, or ``frame_shift <= 0``.
      """
      if x.size == 0:
          raise ValueError("Input array size is zero")
      if frame_length < 1:
          raise ValueError("frame_length must be a positive integer")
      if frame_length > x.shape[-1]:
          raise ValueError("frame_length is greater than input length")
      if 0 >= frame_shift:
          raise ValueError("frame_shift must be greater than 0")

      lead = [(0, 0)] * (x.ndim - 1)  # never pad the leading axes
      if centered:
          half = frame_length // 2
          x = np.pad(x, lead + [(half, half)], mode="constant", constant_values=0)
      if padded:
          # When frame_shift <= frame_length this makes
          # (length - frame_length) an exact multiple of frame_shift, i.e.
          # length == frame_length + (n_frames - 1) * frame_shift.
          tail = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
          x = np.pad(x, lead + [(0, tail)], mode="constant", constant_values=0)

      if frame_length == 1 and frame_shift == 1:
          return x[..., None]

      n_frames = (x.shape[-1] - frame_length) // frame_shift + 1
      step = x.strides[-1]
      return np.lib.stride_tricks.as_strided(
          x,
          shape=x.shape[:-1] + (n_frames, frame_length),
          strides=x.strides[:-1] + (frame_shift * step, step),
      )
  def detect_non_silence(
      x: np.ndarray,
      threshold: float = 0.01,
      frame_length: int = 1024,
      frame_shift: int = 512,
      window: str = "boxcar",
  ) -> np.ndarray:
      """Power based voice activity detection.

      The last axis is framed, the mean windowed power of every frame is
      computed, and a frame is marked non-silent when its power exceeds
      ``threshold`` times the mean frame power of its channel. The per-frame
      decision is then expanded back to per-sample resolution.

      Args:
          x: (Channel, Time) or (Time,) signal.
          threshold: Relative power threshold.
          frame_length: Samples per analysis frame.
          frame_shift: Hop between analysis frames.
          window: Window name passed to ``scipy.signal.get_window``.

      Returns:
          Boolean array with the same shape as ``x``. Inputs shorter than
          ``frame_length`` and all-zero inputs are reported as entirely
          non-silent.

      >>> x = np.random.randn(1000)
      >>> detect = detect_non_silence(x)
      >>> assert x.shape == detect.shape
      >>> assert detect.dtype == bool
      """
      if x.shape[-1] < frame_length:
          # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the dtype.
          return np.full(x.shape, fill_value=True, dtype=bool)
      if x.dtype.kind in "iu":
          # Integer PCM (signed or unsigned) would overflow when squared below.
          x = x.astype(np.float64)
      # framed_w: (C, T, F) -- a strided view whose overlapping frames share memory
      framed_w = framing(
          x,
          frame_length=frame_length,
          frame_shift=frame_shift,
          centered=False,
          padded=True,
      )
      # Multiply out of place: an in-place ``*=`` on the overlapping view would
      # apply the window repeatedly to the samples that frames share.
      framed_w = framed_w * scipy.signal.get_window(window, frame_length).astype(
          framed_w.dtype
      )
      # power: (C, T)
      power = (framed_w ** 2).mean(axis=-1)
      # mean_power: (C, 1)
      mean_power = np.mean(power, axis=-1, keepdims=True)
      if np.all(mean_power == 0):
          return np.full(x.shape, fill_value=True, dtype=bool)
      # detect_frames: (C, T)
      detect_frames = power / mean_power > threshold
      # detects: (C, T * frame_shift) -- each frame decision repeated frame_shift times
      detects = np.broadcast_to(
          detect_frames[..., None], detect_frames.shape + (frame_shift,)
      ).reshape(*detect_frames.shape[:-1], -1)
      # With frame_length < 2 * frame_shift the expansion can exceed the input
      # length; trim so the edge padding below never gets a negative width.
      detects = detects[..., : x.shape[-1]]
      # detects: (C, Time)
      return np.pad(
          detects,
          [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
          mode="edge",
      )
  class CommonPreprocessor(AbsPreprocessor):
      """Default speech + text preprocessor.

      * Speech (``data[speech_name]``, shape ``(Time,)`` or ``(Time, Nmic)``):
        when ``train`` is True and RIR/noise lists were loaded, a random RIR is
        convolved and random noise is mixed in at a random SNR. The result is
        then divided by its peak if that peak exceeds 1.0. Independently,
        ``speech_volume_normalize`` rescales the peak to the given value.
      * Text (``data[text_name]``): cleaned, tokenized (with the configured
        tokenizer, or by a plain space split optionally mapped through
        ``seg_dict``) and converted to an ``int64`` id array.
      """

      def __init__(
          self,
          train: bool,
          token_type: str = None,
          token_list: Union[Path, str, Iterable[str]] = None,
          bpemodel: Union[Path, str, Iterable[str]] = None,
          text_cleaner: Collection[str] = None,
          g2p_type: str = None,
          unk_symbol: str = "<unk>",
          space_symbol: str = "<space>",
          non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
          delimiter: str = None,
          rir_scp: str = None,
          rir_apply_prob: float = 1.0,
          noise_scp: str = None,
          noise_apply_prob: float = 1.0,
          noise_db_range: str = "3_10",
          speech_volume_normalize: float = None,
          speech_name: str = "speech",
          text_name: str = "text",
          split_with_space: bool = False,
          seg_dict_file: str = None,
      ):
          """
          Args:
              train: Whether this instance is used for training. RIR/noise lists
                  are only loaded, and augmentation only applied, when True.
              token_type: Tokenizer type passed to ``build_tokenizer``. ``None``
                  disables text processing.
              token_list: Vocabulary for ``TokenIDConverter``; required when
                  ``token_type`` is set.
              rir_scp / noise_scp: Files listing audio paths, one per line, each
                  line being ``path`` or ``key path``.
              rir_apply_prob / noise_apply_prob: Per-call probability of applying
                  the corresponding augmentation.
              noise_db_range: ``"low_high"`` SNR range in dB (e.g. ``"3_10"``),
                  or a single value for a fixed SNR.
              speech_volume_normalize: If set, the speech peak is rescaled to it.
              split_with_space: Split text on single spaces instead of using the
                  tokenizer.
              seg_dict_file: Optional ``word unit1 unit2 ...`` file used to map
                  space-split words to sub-word units.

          Raises:
              ValueError: if ``token_type`` is given without ``token_list``, or
                  ``noise_db_range`` has more than two ``_``-separated parts.
          """
          super().__init__(train)
          self.train = train
          self.speech_name = speech_name
          self.text_name = text_name
          self.speech_volume_normalize = speech_volume_normalize
          self.rir_apply_prob = rir_apply_prob
          self.noise_apply_prob = noise_apply_prob
          self.split_with_space = split_with_space

          self.seg_dict = None
          if seg_dict_file is not None:
              self.seg_dict = self._load_seg_dict(seg_dict_file)

          if token_type is not None:
              if token_list is None:
                  raise ValueError("token_list is required if token_type is not None")
              self.text_cleaner = TextCleaner(text_cleaner)
              self.tokenizer = build_tokenizer(
                  token_type=token_type,
                  bpemodel=bpemodel,
                  delimiter=delimiter,
                  space_symbol=space_symbol,
                  non_linguistic_symbols=non_linguistic_symbols,
                  g2p_type=g2p_type,
              )
              self.token_id_converter = TokenIDConverter(
                  token_list=token_list,
                  unk_symbol=unk_symbol,
              )
          else:
              self.text_cleaner = None
              self.tokenizer = None
              self.token_id_converter = None

          if train and rir_scp is not None:
              self.rirs = self._load_scp_paths(rir_scp)
          else:
              self.rirs = None

          if train and noise_scp is not None:
              self.noises = self._load_scp_paths(noise_scp)
              self.noise_db_low, self.noise_db_high = self._parse_db_range(
                  noise_db_range
              )
          else:
              self.noises = None

      @staticmethod
      def _load_seg_dict(path: str) -> Dict[str, str]:
          """Read ``word unit1 unit2 ...`` lines into ``{word: "unit1 unit2 ..."}``."""
          seg_dict = {}
          with open(path, "r", encoding="utf8") as f:
              for line in f:
                  fields = line.strip().split()
                  if not fields:  # blank line: nothing to map (used to raise IndexError)
                      continue
                  seg_dict[fields[0]] = " ".join(fields[1:])
          return seg_dict

      @staticmethod
      def _load_scp_paths(path: str) -> List[str]:
          """Read an scp list whose lines are ``path`` or ``key path``; return the paths."""
          paths = []
          with open(path, "r", encoding="utf-8") as f:
              for line in f:
                  fields = line.strip().split(None, 1)
                  if not fields:  # skip blank lines instead of raising IndexError
                      continue
                  paths.append(fields[-1])
          return paths

      @staticmethod
      def _parse_db_range(noise_db_range: str):
          """Parse ``"low_high"`` (or a single ``"value"``) into a ``(low, high)`` pair."""
          parts = noise_db_range.split("_")
          if len(parts) == 1:
              # A single value means a fixed SNR: low == high.
              db = float(parts[0])
              return db, db
          if len(parts) == 2:
              return float(parts[0]), float(parts[1])
          raise ValueError(
              f"Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
          )

      def _convolve_rir(self, speech: np.ndarray, power: float) -> np.ndarray:
          """Convolve (Nmic, Time) ``speech`` with a random RIR from ``self.rirs``.

          The output keeps the input length and is rescaled to ``power``, the
          mean power of the non-silent region before convolution.
          """
          rir_path = np.random.choice(self.rirs)
          if rir_path is None:
              return speech
          rir, _ = soundfile.read(rir_path, dtype=np.float64, always_2d=True)
          # rir: (Nmic, Time)
          rir = rir.T
          # Truncating the "full" convolution keeps the signal length unchanged.
          speech = scipy.signal.convolve(speech, rir, mode="full")[
              :, : speech.shape[1]
          ]
          # Restore the original mean power of the non-silent region.
          power2 = (speech[detect_non_silence(speech)] ** 2).mean()
          return np.sqrt(power / max(power2, 1e-10)) * speech

      def _add_noise(
          self, speech: np.ndarray, power: float, nsamples: int
      ) -> np.ndarray:
          """Add a random noise clip from ``self.noises`` to (Nmic, Time) ``speech``.

          The noise is cropped (random offset) or wrap-padded to ``nsamples`` and
          scaled so that speech ``power`` to noise power equals a random SNR drawn
          from ``[noise_db_low, noise_db_high]`` dB.
          """
          noise_path = np.random.choice(self.noises)
          if noise_path is None:
              return speech
          noise_db = np.random.uniform(self.noise_db_low, self.noise_db_high)
          with soundfile.SoundFile(noise_path) as f:
              if f.frames == nsamples:
                  noise = f.read(dtype=np.float64, always_2d=True)
              elif f.frames < nsamples:
                  offset = np.random.randint(0, nsamples - f.frames)
                  # noise: (Time, Nmic)
                  noise = f.read(dtype=np.float64, always_2d=True)
                  # Repeat the short clip to cover the whole utterance.
                  noise = np.pad(
                      noise,
                      [(offset, nsamples - f.frames - offset), (0, 0)],
                      mode="wrap",
                  )
              else:
                  offset = np.random.randint(0, f.frames - nsamples)
                  f.seek(offset)
                  # noise: (Time, Nmic)
                  noise = f.read(nsamples, dtype=np.float64, always_2d=True)
                  if len(noise) != nsamples:
                      raise RuntimeError(f"Something wrong: {noise_path}")
          # noise: (Nmic, Time)
          noise = noise.T
          noise_power = (noise ** 2).mean()
          scale = (
              10 ** (-noise_db / 20)
              * np.sqrt(power)
              / np.sqrt(max(noise_power, 1e-10))
          )
          return speech + scale * noise

      def _speech_process(
          self, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, Union[str, np.ndarray]]:
          """Apply training-time augmentation and optional volume normalization."""
          if self.speech_name in data:
              if self.train and (self.rirs is not None or self.noises is not None):
                  speech = data[self.speech_name]
                  nsamples = len(speech)
                  # speech: (Nmic, Time)
                  if speech.ndim == 1:
                      speech = speech[None, :]
                  else:
                      speech = speech.T
                  # Power of the non-silent region, the reference for both steps.
                  power = (speech[detect_non_silence(speech)] ** 2).mean()
                  # 1. Convolve RIR
                  if self.rirs is not None and self.rir_apply_prob >= np.random.random():
                      speech = self._convolve_rir(speech, power)
                  # 2. Add noise
                  if (
                      self.noises is not None
                      and self.noise_apply_prob >= np.random.random()
                  ):
                      speech = self._add_noise(speech, power, nsamples)
                  speech = speech.T
                  # Prevent clipping introduced by the augmentation.
                  ma = np.max(np.abs(speech))
                  if ma > 1.0:
                      speech /= ma
                  data[self.speech_name] = speech
              if self.speech_volume_normalize is not None:
                  speech = data[self.speech_name]
                  ma = np.max(np.abs(speech))
                  if ma > 0:  # an all-zero signal would otherwise become NaN (0 / 0)
                      data[self.speech_name] = speech * self.speech_volume_normalize / ma
          return data

      def _text_process(
          self, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          """Replace ``data[text_name]`` with its int64 token-id array."""
          if self.text_name in data and self.tokenizer is not None:
              text = data[self.text_name]
              text = self.text_cleaner(text)
              if self.split_with_space:
                  tokens = text.strip().split(" ")
                  if self.seg_dict is not None:
                      tokens = seg_tokenize(tokens, self.seg_dict)
              else:
                  tokens = self.tokenizer.text2tokens(text)
              text_ints = self.token_id_converter.tokens2ids(tokens)
              data[self.text_name] = np.array(text_ints, dtype=np.int64)
          return data

      def __call__(
          self, uid: str, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          data = self._speech_process(data)
          data = self._text_process(data)
          return data
  345. ## FIXME
  class LMPreprocessor(CommonPreprocessor):
      """Preprocessor for language-model text.

      Behaves like ``CommonPreprocessor`` except in one case: when
      ``split_with_space`` is set and a ``seg_dict`` is loaded, words are mapped
      with ``seg_tokenize_wo_pattern``. That lookup is exact and case-sensitive,
      with no per-character fallback.
      """

      def __init__(
          self,
          train: bool,
          token_type: str = None,
          token_list: Union[Path, str, Iterable[str]] = None,
          bpemodel: Union[Path, str, Iterable[str]] = None,
          text_cleaner: Collection[str] = None,
          g2p_type: str = None,
          unk_symbol: str = "<unk>",
          space_symbol: str = "<space>",
          non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
          delimiter: str = None,
          rir_scp: str = None,
          rir_apply_prob: float = 1.0,
          noise_scp: str = None,
          noise_apply_prob: float = 1.0,
          noise_db_range: str = "3_10",
          speech_volume_normalize: float = None,
          speech_name: str = "speech",
          text_name: str = "text",
          split_with_space: bool = False,
          seg_dict_file: str = None,
      ):
          # Forward every option by name so the call cannot silently drift if
          # the parent's parameter order changes.
          super().__init__(
              train=train,
              token_type=token_type,
              token_list=token_list,
              bpemodel=bpemodel,
              text_cleaner=text_cleaner,
              g2p_type=g2p_type,
              unk_symbol=unk_symbol,
              space_symbol=space_symbol,
              non_linguistic_symbols=non_linguistic_symbols,
              delimiter=delimiter,
              rir_scp=rir_scp,
              rir_apply_prob=rir_apply_prob,
              noise_scp=noise_scp,
              noise_apply_prob=noise_apply_prob,
              noise_db_range=noise_db_range,
              speech_volume_normalize=speech_volume_normalize,
              speech_name=speech_name,
              text_name=text_name,
              split_with_space=split_with_space,
              seg_dict_file=seg_dict_file,
          )

      def _text_process(
          self, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          """Replace ``data[text_name]`` with its int64 token-id array."""
          if self.text_name not in data or self.tokenizer is None:
              return data
          cleaned = self.text_cleaner(data[self.text_name])
          if not self.split_with_space:
              tokens = self.tokenizer.text2tokens(cleaned)
          else:
              tokens = cleaned.strip().split(" ")
              if self.seg_dict is not None:
                  tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
          ids = self.token_id_converter.tokens2ids(tokens)
          data[self.text_name] = np.array(ids, dtype=np.int64)
          return data
  class CommonPreprocessor_multi(AbsPreprocessor):
      """Tokenize several text fields with one shared tokenizer.

      Every key in ``text_name`` that is present in ``data`` is cleaned,
      tokenized and replaced by an int64 id array. Speech is currently passed
      through untouched.
      """

      def __init__(
          self,
          train: bool,
          token_type: str = None,
          token_list: Union[Path, str, Iterable[str]] = None,
          bpemodel: Union[Path, str, Iterable[str]] = None,
          text_cleaner: Collection[str] = None,
          g2p_type: str = None,
          unk_symbol: str = "<unk>",
          space_symbol: str = "<space>",
          non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
          delimiter: str = None,
          speech_name: str = "speech",
          text_name: List[str] = ["text"],
      ):
          super().__init__(train)
          self.train = train
          self.speech_name = speech_name
          self.text_name = text_name

          if token_type is None:
              # No tokenizer configured: text fields are left as they are.
              self.text_cleaner = None
              self.tokenizer = None
              self.token_id_converter = None
              return

          if token_list is None:
              raise ValueError("token_list is required if token_type is not None")
          self.text_cleaner = TextCleaner(text_cleaner)
          self.tokenizer = build_tokenizer(
              token_type=token_type,
              bpemodel=bpemodel,
              delimiter=delimiter,
              space_symbol=space_symbol,
              non_linguistic_symbols=non_linguistic_symbols,
              g2p_type=g2p_type,
          )
          self.token_id_converter = TokenIDConverter(
              token_list=token_list,
              unk_symbol=unk_symbol,
          )

      def _text_process(
          self, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          """Convert each configured text field found in ``data`` to token ids."""
          for field in self.text_name:
              if field not in data or self.tokenizer is None:
                  continue
              tokens = self.tokenizer.text2tokens(self.text_cleaner(data[field]))
              data[field] = np.array(
                  self.token_id_converter.tokens2ids(tokens), dtype=np.int64
              )
          return data

      def __call__(
          self, uid: str, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          # Speech under self.speech_name is intentionally not processed yet.
          # Candidate future steps: STFT, Fbank, CMVN, data augmentation.
          return self._text_process(data)
  class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
      """CommonPreprocessor variant with one tokenizer per text field.

      ``token_type``, ``token_list``, ``bpemodel`` and ``text_name`` are
      parallel lists. Entry ``i`` tokenizes ``data[text_name[i]]``, and a
      ``None`` token type leaves that field untouched.
      """

      def __init__(
          self,
          train: bool,
          token_type: List[str] = [None],
          token_list: List[Union[Path, str, Iterable[str]]] = [None],
          bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
          text_cleaner: Collection[str] = None,
          g2p_type: str = None,
          unk_symbol: str = "<unk>",
          space_symbol: str = "<space>",
          non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
          delimiter: str = None,
          rir_scp: str = None,
          rir_apply_prob: float = 1.0,
          noise_scp: str = None,
          noise_apply_prob: float = 1.0,
          noise_db_range: str = "3_10",
          speech_volume_normalize: float = None,
          speech_name: str = "speech",
          text_name: List[str] = ["text"],
      ):
          # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
          # The parent is configured with the first entry of each list; its
          # tokenizer attributes are replaced by per-field lists below.
          super().__init__(
              train=train,
              token_type=token_type[0],
              token_list=token_list[0],
              bpemodel=bpemodel[0],
              text_cleaner=text_cleaner,
              g2p_type=g2p_type,
              unk_symbol=unk_symbol,
              space_symbol=space_symbol,
              non_linguistic_symbols=non_linguistic_symbols,
              delimiter=delimiter,
              speech_name=speech_name,
              text_name=text_name[0],
              rir_scp=rir_scp,
              rir_apply_prob=rir_apply_prob,
              noise_scp=noise_scp,
              noise_apply_prob=noise_apply_prob,
              noise_db_range=noise_db_range,
              speech_volume_normalize=speech_volume_normalize,
          )
          assert (
              len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
          ), "token_type, token_list, bpemodel, or processing text_name mismatched"
          self.num_tokenizer = len(token_type)
          self.tokenizer = []
          self.token_id_converter = []
          for idx, kind in enumerate(token_type):
              if kind is None:
                  self.tokenizer.append(None)
                  self.token_id_converter.append(None)
                  continue
              vocab = token_list[idx]
              if vocab is None:
                  raise ValueError("token_list is required if token_type is not None")
              self.tokenizer.append(
                  build_tokenizer(
                      token_type=kind,
                      bpemodel=bpemodel[idx],
                      delimiter=delimiter,
                      space_symbol=space_symbol,
                      non_linguistic_symbols=non_linguistic_symbols,
                      g2p_type=g2p_type,
                  )
              )
              self.token_id_converter.append(
                  TokenIDConverter(
                      token_list=vocab,
                      unk_symbol=unk_symbol,
                  )
              )
          self.text_cleaner = TextCleaner(text_cleaner)
          self.text_name = text_name  # override the text_name from CommonPreprocessor

      def _text_process(
          self, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          """Tokenize each configured field with its own tokenizer/converter."""
          for i in range(self.num_tokenizer):
              field = self.text_name[i]
              tokenizer = self.tokenizer[i]
              if field not in data or tokenizer is None:
                  continue
              tokens = tokenizer.text2tokens(self.text_cleaner(data[field]))
              ids = self.token_id_converter[i].tokens2ids(tokens)
              data[field] = np.array(ids, dtype=np.int64)
          return data
  555. class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
  556. def __init__(
  557. self,
  558. train: bool,
  559. token_type: str = None,
  560. token_list: Union[Path, str, Iterable[str]] = None,
  561. bpemodel: Union[Path, str, Iterable[str]] = None,
  562. text_cleaner: Collection[str] = None,
  563. g2p_type: str = None,
  564. unk_symbol: str = "<unk>",
  565. space_symbol: str = "<space>",
  566. non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
  567. delimiter: str = None,
  568. rir_scp: str = None,
  569. rir_apply_prob: float = 1.0,
  570. noise_scp: str = None,
  571. noise_apply_prob: float = 1.0,
  572. noise_db_range: str = "3_10",
  573. speech_volume_normalize: float = None,
  574. speech_name: str = "speech",
  575. text_name: str = "text",
  576. split_text_name: str = "split_text",
  577. split_with_space: bool = False,
  578. seg_jieba: bool = False,
  579. seg_dict_file: str = None,
  580. ):
  581. super().__init__(
  582. train=train,
  583. # Force to use word.
  584. token_type="word",
  585. token_list=token_list,
  586. bpemodel=bpemodel,
  587. text_cleaner=text_cleaner,
  588. g2p_type=g2p_type,
  589. unk_symbol=unk_symbol,
  590. space_symbol=space_symbol,
  591. non_linguistic_symbols=non_linguistic_symbols,
  592. delimiter=delimiter,
  593. speech_name=speech_name,
  594. text_name=text_name,
  595. rir_scp=rir_scp,
  596. rir_apply_prob=rir_apply_prob,
  597. noise_scp=noise_scp,
  598. noise_apply_prob=noise_apply_prob,
  599. noise_db_range=noise_db_range,
  600. speech_volume_normalize=speech_volume_normalize,
  601. split_with_space=split_with_space,
  602. seg_dict_file=seg_dict_file,
  603. )
  604. # The data field name for split text.
  605. self.split_text_name = split_text_name
  606. self.seg_jieba = seg_jieba
  607. if self.seg_jieba:
  608. jieba.load_userdict(seg_dict_file)
  609. @classmethod
  610. def split_words(cls, text: str):
  611. words = []
  612. segs = text.split()
  613. for seg in segs:
  614. # There is no space in seg.
  615. current_word = ""
  616. for c in seg:
  617. if len(c.encode()) == 1:
  618. # This is an ASCII char.
  619. current_word += c
  620. else:
  621. # This is a Chinese char.
  622. if len(current_word) > 0:
  623. words.append(current_word)
  624. current_word = ""
  625. words.append(c)
  626. if len(current_word) > 0:
  627. words.append(current_word)
  628. return words
  629. @classmethod
  630. def isEnglish(cls, text:str):
  631. if re.search('^[a-zA-Z\']+$', text):
  632. return True
  633. else:
  634. return False
  635. @classmethod
  636. def join_chinese_and_english(cls, input_list):
  637. line = ''
  638. for token in input_list:
  639. if cls.isEnglish(token):
  640. line = line + ' ' + token
  641. else:
  642. line = line + token
  643. line = line.strip()
  644. return line
  645. @classmethod
  646. def split_words_jieba(cls, text: str):
  647. input_list = text.split()
  648. token_list_all = []
  649. langauge_list = []
  650. token_list_tmp = []
  651. language_flag = None
  652. for token in input_list:
  653. if cls.isEnglish(token) and language_flag == 'Chinese':
  654. token_list_all.append(token_list_tmp)
  655. langauge_list.append('Chinese')
  656. token_list_tmp = []
  657. elif not cls.isEnglish(token) and language_flag == 'English':
  658. token_list_all.append(token_list_tmp)
  659. langauge_list.append('English')
  660. token_list_tmp = []
  661. token_list_tmp.append(token)
  662. if cls.isEnglish(token):
  663. language_flag = 'English'
  664. else:
  665. language_flag = 'Chinese'
  666. if token_list_tmp:
  667. token_list_all.append(token_list_tmp)
  668. langauge_list.append(language_flag)
  669. result_list = []
  670. for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
  671. if language_flag == 'English':
  672. result_list.extend(token_list_tmp)
  673. else:
  674. seg_list = jieba.cut(cls.join_chinese_and_english(token_list_tmp), HMM=False)
  675. result_list.extend(seg_list)
  676. return result_list
  677. def __call__(
  678. self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
  679. ) -> Dict[str, Union[list, np.ndarray]]:
  680. # Split words.
  681. if isinstance(data[self.text_name], str):
  682. if self.seg_jieba:
  683. # jieba.load_userdict(seg_dict_file)
  684. split_text = self.split_words_jieba(data[self.text_name])
  685. else:
  686. split_text = self.split_words(data[self.text_name])
  687. else:
  688. split_text = data[self.text_name]
  689. data[self.text_name] = " ".join(split_text)
  690. data = self._speech_process(data)
  691. data = self._text_process(data)
  692. data[self.split_text_name] = split_text
  693. return data
  694. def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
  695. result = data[self.split_text_name]
  696. del data[self.split_text_name]
  697. return result
  class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
      """Multi-tokenizer preprocessor that also extracts a trailing VAD marker.

      ``token_type``, ``token_list``, ``bpemodel`` and ``text_name`` are
      parallel lists, with one tokenizer per text field. If the last token of a
      field is ``vad:<int>`` (or a bare ``vad:``), it is removed from the
      tokens and its value (``-1`` when empty) is stored as a 1-element int64
      array under ``data[vad_name]``.
      """

      def __init__(
          self,
          train: bool,
          token_type: List[str] = [None],
          token_list: List[Union[Path, str, Iterable[str]]] = [None],
          bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
          text_cleaner: Collection[str] = None,
          g2p_type: str = None,
          unk_symbol: str = "<unk>",
          space_symbol: str = "<space>",
          non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
          delimiter: str = None,
          rir_scp: str = None,
          rir_apply_prob: float = 1.0,
          noise_scp: str = None,
          noise_apply_prob: float = 1.0,
          noise_db_range: str = "3_10",
          speech_volume_normalize: float = None,
          speech_name: str = "speech",
          text_name: List[str] = ["text"],
          vad_name: str = "vad_indexes",
      ):
          # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
          super().__init__(
              train=train,
              token_type=token_type[0],
              token_list=token_list[0],
              bpemodel=bpemodel[0],
              text_cleaner=text_cleaner,
              g2p_type=g2p_type,
              unk_symbol=unk_symbol,
              space_symbol=space_symbol,
              non_linguistic_symbols=non_linguistic_symbols,
              delimiter=delimiter,
              speech_name=speech_name,
              text_name=text_name[0],
              rir_scp=rir_scp,
              rir_apply_prob=rir_apply_prob,
              noise_scp=noise_scp,
              noise_apply_prob=noise_apply_prob,
              noise_db_range=noise_db_range,
              speech_volume_normalize=speech_volume_normalize,
          )
          assert (
              len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
          ), "token_type, token_list, bpemodel, or processing text_name mismatched"
          self.num_tokenizer = len(token_type)
          self.tokenizer = []
          self.token_id_converter = []
          for i in range(self.num_tokenizer):
              if token_type[i] is not None:
                  if token_list[i] is None:
                      raise ValueError("token_list is required if token_type is not None")
                  self.tokenizer.append(
                      build_tokenizer(
                          token_type=token_type[i],
                          bpemodel=bpemodel[i],
                          delimiter=delimiter,
                          space_symbol=space_symbol,
                          non_linguistic_symbols=non_linguistic_symbols,
                          g2p_type=g2p_type,
                      )
                  )
                  self.token_id_converter.append(
                      TokenIDConverter(
                          token_list=token_list[i],
                          unk_symbol=unk_symbol,
                      )
                  )
              else:
                  self.tokenizer.append(None)
                  self.token_id_converter.append(None)
          self.text_cleaner = TextCleaner(text_cleaner)
          self.text_name = text_name  # override the text_name from CommonPreprocessor
          self.vad_name = vad_name

      def _text_process(
          self, data: Dict[str, Union[str, np.ndarray]]
      ) -> Dict[str, np.ndarray]:
          """Tokenize each configured field, extracting a trailing ``vad:`` token."""
          vad_prefix = "vad:"
          for i in range(self.num_tokenizer):
              text_name = self.text_name[i]
              if text_name in data and self.tokenizer[i] is not None:
                  text = self.text_cleaner(data[text_name])
                  tokens = self.tokenizer[i].text2tokens(text)
                  # Guard against empty token lists (empty text) and only treat
                  # a token that *starts* with the marker as the VAD index.
                  if tokens and tokens[-1].startswith(vad_prefix):
                      vad_str = tokens[-1][len(vad_prefix):]
                      tokens = tokens[:-1]
                      vad = int(vad_str) if vad_str else -1
                      data[self.vad_name] = np.array([vad], dtype=np.int64)
                  text_ints = self.token_id_converter[i].tokens2ids(tokens)
                  data[text_name] = np.array(text_ints, dtype=np.int64)
          return data
  def split_to_mini_sentence(words: list, word_limit: int = 20):
      """Chunk ``words`` into consecutive pieces of at most ``word_limit`` items.

      If the input already fits (including an empty input), it is returned
      unchanged as the single element of a list. Otherwise every chunk holds
      exactly ``word_limit`` items except possibly the last.
      """
      assert word_limit > 1
      if len(words) <= word_limit:
          return [words]
      return [
          words[start:start + word_limit]
          for start in range(0, len(words), word_limit)
      ]