# normalize.py
  1. import itertools
  2. import os
  3. import re
  4. from argparse import ArgumentParser
  5. from collections import OrderedDict
  6. from math import factorial
  7. from time import perf_counter
  8. from typing import Dict, List, Union
  9. import pynini
  10. import regex
  11. from joblib import Parallel, delayed
  12. from fun_text_processing.text_normalization.data_loader_utils import (
  13. load_file,
  14. post_process_punct,
  15. pre_process,
  16. write_file,
  17. )
  18. from fun_text_processing.text_normalization.token_parser import PRESERVE_ORDER_KEY, TokenParser
  19. from pynini.lib.rewrite import top_rewrite
  20. from tqdm import tqdm
  21. try:
  22. from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
  23. NLP_AVAILABLE = True
  24. except (ModuleNotFoundError, ImportError) as e:
  25. NLP_AVAILABLE = False
  26. SPACE_DUP = re.compile(' {2,}')
  27. class Normalizer:
  28. """
  29. Normalizer class that converts text from written to spoken form.
  30. Useful for TTS preprocessing.
  31. Args:
  32. input_case: expected input capitalization
  33. lang: language specifying the TN rules, by default: English
  34. cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
  35. overwrite_cache: set to True to overwrite .far files
  36. whitelist: path to a file with whitelist replacements
  37. post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
  38. Note: punct_post_process flag in normalize() supports all languages.
  39. """
  40. def __init__(
  41. self,
  42. input_case: str,
  43. lang: str = 'en',
  44. deterministic: bool = True,
  45. cache_dir: str = None,
  46. overwrite_cache: bool = False,
  47. whitelist: str = None,
  48. lm: bool = False,
  49. post_process: bool = True,
  50. ):
  51. assert input_case in ["lower_cased", "cased"]
  52. self.post_processor = None
  53. if lang == "en":
  54. from fun_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
  55. from fun_text_processing.text_normalization.en.verbalizers.post_processing import PostProcessingFst
  56. if post_process:
  57. self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)
  58. if deterministic:
  59. from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
  60. else:
  61. if lm:
  62. from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import ClassifyFst
  63. else:
  64. from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_with_audio import (
  65. ClassifyFst,
  66. )
  67. elif lang == 'ru':
  68. # Ru TN only support non-deterministic cases and produces multiple normalization options
  69. # use normalize_with_audio.py
  70. from fun_text_processing.text_normalization.ru.taggers.tokenize_and_classify import ClassifyFst
  71. from fun_text_processing.text_normalization.ru.verbalizers.verbalize_final import VerbalizeFinalFst
  72. elif lang == 'de':
  73. from fun_text_processing.text_normalization.de.taggers.tokenize_and_classify import ClassifyFst
  74. from fun_text_processing.text_normalization.de.verbalizers.verbalize_final import VerbalizeFinalFst
  75. elif lang == 'es':
  76. from fun_text_processing.text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
  77. from fun_text_processing.text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst
  78. elif lang == 'zh':
  79. from fun_text_processing.text_normalization.zh.taggers.tokenize_and_classify import ClassifyFst
  80. from fun_text_processing.text_normalization.zh.verbalizers.verbalize_final import VerbalizeFinalFst
  81. self.tagger = ClassifyFst(
  82. input_case=input_case,
  83. deterministic=deterministic,
  84. cache_dir=cache_dir,
  85. overwrite_cache=overwrite_cache,
  86. whitelist=whitelist,
  87. )
  88. self.verbalizer = VerbalizeFinalFst(
  89. deterministic=deterministic, cache_dir=cache_dir, overwrite_cache=overwrite_cache
  90. )
  91. self.parser = TokenParser()
  92. self.lang = lang
  93. if NLP_AVAILABLE:
  94. self.processor = MosesProcessor(lang_id=lang)
  95. else:
  96. self.processor = None
  97. print("NeMo NLP is not available. Moses de-tokenization will be skipped.")
  98. def normalize_list(
  99. self,
  100. texts: List[str],
  101. verbose: bool = False,
  102. punct_pre_process: bool = False,
  103. punct_post_process: bool = False,
  104. batch_size: int = 1,
  105. n_jobs: int = 1,
  106. ):
  107. """
  108. NeMo text normalizer
  109. Args:
  110. texts: list of input strings
  111. verbose: whether to print intermediate meta information
  112. punct_pre_process: whether to do punctuation pre processing
  113. punct_post_process: whether to do punctuation post processing
  114. n_jobs: the maximum number of concurrently running jobs. If -1 all CPUs are used. If 1 is given,
  115. no parallel computing code is used at all, which is useful for debugging. For n_jobs below -1,
  116. (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one are used.
  117. batch_size: Number of examples for each process
  118. Returns converted list input strings
  119. """
  120. # to save intermediate results to a file
  121. batch = min(len(texts), batch_size)
  122. try:
  123. normalized_texts = Parallel(n_jobs=n_jobs)(
  124. delayed(self.process_batch)(texts[i : i + batch], verbose, punct_pre_process, punct_post_process)
  125. for i in range(0, len(texts), batch)
  126. )
  127. except BaseException as e:
  128. raise e
  129. normalized_texts = list(itertools.chain(*normalized_texts))
  130. return normalized_texts
  131. def process_batch(self, batch, verbose, punct_pre_process, punct_post_process):
  132. """
  133. Normalizes batch of text sequences
  134. Args:
  135. batch: list of texts
  136. verbose: whether to print intermediate meta information
  137. punct_pre_process: whether to do punctuation pre processing
  138. punct_post_process: whether to do punctuation post processing
  139. """
  140. normalized_lines = [
  141. self.normalize(
  142. text, verbose=verbose, punct_pre_process=punct_pre_process, punct_post_process=punct_post_process
  143. )
  144. for text in tqdm(batch)
  145. ]
  146. return normalized_lines
  147. def _estimate_number_of_permutations_in_nested_dict(
  148. self, token_group: Dict[str, Union[OrderedDict, str, bool]]
  149. ) -> int:
  150. num_perms = 1
  151. for k, inner in token_group.items():
  152. if isinstance(inner, dict):
  153. num_perms *= self._estimate_number_of_permutations_in_nested_dict(inner)
  154. num_perms *= factorial(len(token_group))
  155. return num_perms
  156. def _split_tokens_to_reduce_number_of_permutations(
  157. self, tokens: List[dict], max_number_of_permutations_per_split: int = 729
  158. ) -> List[List[dict]]:
  159. """
  160. Splits a sequence of tokens in a smaller sequences of tokens in a way that maximum number of composite
  161. tokens permutations does not exceed ``max_number_of_permutations_per_split``.
  162. For example,
  163. .. code-block:: python
  164. tokens = [
  165. {"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}},
  166. {"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}},
  167. ]
  168. split = normalizer._split_tokens_to_reduce_number_of_permutations(
  169. tokens, max_number_of_permutations_per_split=6
  170. )
  171. assert split == [
  172. [{"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}}],
  173. [{"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}}],
  174. ]
  175. Date tokens contain 3 items each which gives 6 permutations for every date. Since there are 2 dates, total
  176. number of permutations would be ``6 * 6 == 36``. Parameter ``max_number_of_permutations_per_split`` equals 6,
  177. so input sequence of tokens is split into 2 smaller sequences.
  178. Args:
  179. tokens (:obj:`List[dict]`): a list of dictionaries, possibly nested.
  180. max_number_of_permutations_per_split (:obj:`int`, `optional`, defaults to :obj:`243`): a maximum number
  181. of permutations which can be generated from input sequence of tokens.
  182. Returns:
  183. :obj:`List[List[dict]]`: a list of smaller sequences of tokens resulting from ``tokens`` split.
  184. """
  185. splits = []
  186. prev_end_of_split = 0
  187. current_number_of_permutations = 1
  188. for i, token_group in enumerate(tokens):
  189. n = self._estimate_number_of_permutations_in_nested_dict(token_group)
  190. if n * current_number_of_permutations > max_number_of_permutations_per_split:
  191. splits.append(tokens[prev_end_of_split:i])
  192. prev_end_of_split = i
  193. current_number_of_permutations = 1
  194. if n > max_number_of_permutations_per_split:
  195. raise ValueError(
  196. f"Could not split token list with respect to condition that every split can generate number of "
  197. f"permutations less or equal to "
  198. f"`max_number_of_permutations_per_split={max_number_of_permutations_per_split}`. "
  199. f"There is an unsplittable token group that generates more than "
  200. f"{max_number_of_permutations_per_split} permutations. Try to increase "
  201. f"`max_number_of_permutations_per_split` parameter."
  202. )
  203. current_number_of_permutations *= n
  204. splits.append(tokens[prev_end_of_split:])
  205. assert sum([len(s) for s in splits]) == len(tokens)
  206. return splits
  207. def normalize(
  208. self, text: str, verbose: bool = False, punct_pre_process: bool = False, punct_post_process: bool = False
  209. ) -> str:
  210. """
  211. Main function. Normalizes tokens from written to spoken form
  212. e.g. 12 kg -> twelve kilograms
  213. Args:
  214. text: string that may include semiotic classes
  215. verbose: whether to print intermediate meta information
  216. punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
  217. punct_post_process: whether to normalize punctuation
  218. Returns: spoken form
  219. """
  220. if len(text.split()) > 500:
  221. print(
  222. "WARNING! Your input is too long and could take a long time to normalize."
  223. "Use split_text_into_sentences() to make the input shorter and then call normalize_list()."
  224. )
  225. original_text = text
  226. if punct_pre_process:
  227. text = pre_process(text)
  228. text = text.strip()
  229. if not text:
  230. if verbose:
  231. print(text)
  232. return text
  233. text = pynini.escape(text)
  234. tagged_lattice = self.find_tags(text)
  235. tagged_text = self.select_tag(tagged_lattice)
  236. if verbose:
  237. print(tagged_text)
  238. self.parser(tagged_text)
  239. tokens = self.parser.parse()
  240. split_tokens = self._split_tokens_to_reduce_number_of_permutations(tokens)
  241. output = ""
  242. for s in split_tokens:
  243. tags_reordered = self.generate_permutations(s)
  244. verbalizer_lattice = None
  245. for tagged_text in tags_reordered:
  246. tagged_text = pynini.escape(tagged_text)
  247. verbalizer_lattice = self.find_verbalizer(tagged_text)
  248. if verbalizer_lattice.num_states() != 0:
  249. break
  250. if verbalizer_lattice is None:
  251. raise ValueError(f"No permutations were generated from tokens {s}")
  252. output += ' ' + self.select_verbalizer(verbalizer_lattice)
  253. output = SPACE_DUP.sub(' ', output[1:])
  254. if self.lang == "en" and hasattr(self, 'post_processor'):
  255. output = self.post_process(output)
  256. if punct_post_process:
  257. # do post-processing based on Moses detokenizer
  258. if self.processor:
  259. output = self.processor.moses_detokenizer.detokenize([output], unescape=False)
  260. output = post_process_punct(input=original_text, normalized_text=output)
  261. else:
  262. print("DAMO_NLP collection is not available: skipping punctuation post_processing")
  263. return output
  264. def split_text_into_sentences(self, text: str) -> List[str]:
  265. """
  266. Split text into sentences.
  267. Args:
  268. text: text
  269. Returns list of sentences
  270. """
  271. lower_case_unicode = ''
  272. upper_case_unicode = ''
  273. if self.lang == "ru":
  274. lower_case_unicode = '\u0430-\u04FF'
  275. upper_case_unicode = '\u0410-\u042F'
  276. # Read and split transcript by utterance (roughly, sentences)
  277. split_pattern = f"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”\!”)\s(?![0-9]+[a-z]*\.)"
  278. sentences = regex.split(split_pattern, text)
  279. return sentences
  280. def _permute(self, d: OrderedDict) -> List[str]:
  281. """
  282. Creates reorderings of dictionary elements and serializes as strings
  283. Args:
  284. d: (nested) dictionary of key value pairs
  285. Return permutations of different string serializations of key value pairs
  286. """
  287. l = []
  288. if PRESERVE_ORDER_KEY in d.keys():
  289. d_permutations = [d.items()]
  290. else:
  291. d_permutations = itertools.permutations(d.items())
  292. for perm in d_permutations:
  293. subl = [""]
  294. for k, v in perm:
  295. if isinstance(v, str):
  296. subl = ["".join(x) for x in itertools.product(subl, [f"{k}: \"{v}\" "])]
  297. elif isinstance(v, OrderedDict):
  298. rec = self._permute(v)
  299. subl = ["".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])]
  300. elif isinstance(v, bool):
  301. subl = ["".join(x) for x in itertools.product(subl, [f"{k}: true "])]
  302. else:
  303. raise ValueError()
  304. l.extend(subl)
  305. return l
  306. def generate_permutations(self, tokens: List[dict]):
  307. """
  308. Generates permutations of string serializations of list of dictionaries
  309. Args:
  310. tokens: list of dictionaries
  311. Returns string serialization of list of dictionaries
  312. """
  313. def _helper(prefix: str, tokens: List[dict], idx: int):
  314. """
  315. Generates permutations of string serializations of given dictionary
  316. Args:
  317. tokens: list of dictionaries
  318. prefix: prefix string
  319. idx: index of next dictionary
  320. Returns string serialization of dictionary
  321. """
  322. if idx == len(tokens):
  323. yield prefix
  324. return
  325. token_options = self._permute(tokens[idx])
  326. for token_option in token_options:
  327. yield from _helper(prefix + token_option, tokens, idx + 1)
  328. return _helper("", tokens, 0)
  329. def find_tags(self, text: str) -> 'pynini.FstLike':
  330. """
  331. Given text use tagger Fst to tag text
  332. Args:
  333. text: sentence
  334. Returns: tagged lattice
  335. """
  336. lattice = text @ self.tagger.fst
  337. return lattice
  338. def select_tag(self, lattice: 'pynini.FstLike') -> str:
  339. """
  340. Given tagged lattice return shortest path
  341. Args:
  342. tagged_text: tagged text
  343. Returns: shortest path
  344. """
  345. tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
  346. return tagged_text
  347. def find_verbalizer(self, tagged_text: str) -> 'pynini.FstLike':
  348. """
  349. Given tagged text creates verbalization lattice
  350. This is context-independent.
  351. Args:
  352. tagged_text: input text
  353. Returns: verbalized lattice
  354. """
  355. lattice = tagged_text @ self.verbalizer.fst
  356. return lattice
  357. def select_verbalizer(self, lattice: 'pynini.FstLike') -> str:
  358. """
  359. Given verbalized lattice return shortest path
  360. Args:
  361. lattice: verbalization lattice
  362. Returns: shortest path
  363. """
  364. output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
  365. # lattice = output @ self.verbalizer.punct_graph
  366. # output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
  367. return output
  368. def post_process(self, normalized_text: 'pynini.FstLike') -> str:
  369. """
  370. Runs post processing graph on normalized text
  371. Args:
  372. normalized_text: normalized text
  373. Returns: shortest path
  374. """
  375. normalized_text = normalized_text.strip()
  376. if not normalized_text:
  377. return normalized_text
  378. normalized_text = pynini.escape(normalized_text)
  379. if self.post_processor is not None:
  380. normalized_text = top_rewrite(normalized_text, self.post_processor.fst)
  381. return normalized_text
  382. def parse_args():
  383. parser = ArgumentParser()
  384. input = parser.add_mutually_exclusive_group()
  385. input.add_argument("--text", dest="input_string", help="input string", type=str)
  386. input.add_argument("--input_file", dest="input_file", help="input file path", type=str)
  387. parser.add_argument('--output_file', dest="output_file", help="output file path", type=str)
  388. parser.add_argument("--language", help="language", choices=["en", "de", "es", "zh"], default="en", type=str)
  389. parser.add_argument(
  390. "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
  391. )
  392. parser.add_argument("--verbose", help="print info for debugging", action='store_true')
  393. parser.add_argument(
  394. "--punct_post_process",
  395. help="set to True to enable punctuation post processing to match input.",
  396. action="store_true",
  397. )
  398. parser.add_argument(
  399. "--punct_pre_process", help="set to True to enable punctuation pre processing", action="store_true"
  400. )
  401. parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
  402. parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
  403. parser.add_argument(
  404. "--cache_dir",
  405. help="path to a dir with .far grammar file. Set to None to avoid using cache",
  406. default=None,
  407. type=str,
  408. )
  409. return parser.parse_args()
  410. if __name__ == "__main__":
  411. start_time = perf_counter()
  412. args = parse_args()
  413. whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
  414. if not args.input_string and not args.input_file:
  415. raise ValueError("Either `--text` or `--input_file` required")
  416. normalizer = Normalizer(
  417. input_case=args.input_case,
  418. cache_dir=args.cache_dir,
  419. overwrite_cache=args.overwrite_cache,
  420. whitelist=whitelist,
  421. lang=args.language,
  422. )
  423. if args.input_string:
  424. print(
  425. normalizer.normalize(
  426. args.input_string,
  427. verbose=args.verbose,
  428. punct_pre_process=args.punct_pre_process,
  429. punct_post_process=args.punct_post_process,
  430. )
  431. )
  432. elif args.input_file:
  433. print("Loading data: " + args.input_file)
  434. data = load_file(args.input_file)
  435. print("- Data: " + str(len(data)) + " sentences")
  436. normalizer_prediction = normalizer.normalize_list(
  437. data,
  438. verbose=args.verbose,
  439. punct_pre_process=args.punct_pre_process,
  440. punct_post_process=args.punct_post_process,
  441. )
  442. if args.output_file:
  443. write_file(args.output_file, normalizer_prediction)
  444. print(f"- Normalized. Writing out to {args.output_file}")
  445. else:
  446. print(normalizer_prediction)
  447. print(f"Execution time: {perf_counter() - start_time:.02f} sec")