# normalize.py
  1. # Copyright NeMo (https://github.com/NVIDIA/NeMo). All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import itertools
import os
import re
from argparse import ArgumentParser
from collections import OrderedDict
from math import factorial
from time import perf_counter
from typing import Dict, List, Union

import pynini
import regex
from joblib import Parallel, delayed
from fun_text_processing.text_normalization.data_loader_utils import (
    load_file,
    post_process_punct,
    pre_process,
    write_file,
)
from fun_text_processing.text_normalization.token_parser import PRESERVE_ORDER_KEY, TokenParser
from pynini.lib.rewrite import top_rewrite
from tqdm import tqdm

# Moses de-tokenization (used for punctuation post-processing) is optional:
# degrade gracefully when the NeMo collection is not installed.
try:
    from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

    NLP_AVAILABLE = True
except (ModuleNotFoundError, ImportError) as e:
    NLP_AVAILABLE = False

# Matches runs of two or more spaces; used to collapse duplicated spaces
# introduced while joining verbalized token groups.
SPACE_DUP = re.compile(' {2,}')
class Normalizer:
    """
    Normalizer class that converts text from written to spoken form.
    Useful for TTS preprocessing.

    Args:
        input_case: expected input capitalization ("lower_cased" or "cased")
        lang: language specifying the TN rules, by default: English
        deterministic: if True, uses the deterministic tagger (single output);
            otherwise a non-deterministic one (English only)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        lm: if True (with deterministic=False), uses the LM-oriented English tagger
        post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
            Note: punct_post_process flag in normalize() supports all languages.
    """

    def __init__(
        self,
        input_case: str,
        lang: str = 'en',
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
        lm: bool = False,
        post_process: bool = True,
    ):
        assert input_case in ["lower_cased", "cased"]
        self.post_processor = None
        # Grammar modules are imported lazily so only the requested language's
        # (potentially heavy) FST definitions are loaded.
        if lang == "en":
            from fun_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
            from fun_text_processing.text_normalization.en.verbalizers.post_processing import PostProcessingFst

            if post_process:
                self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)
            if deterministic:
                from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
            else:
                if lm:
                    from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import ClassifyFst
                else:
                    from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_with_audio import (
                        ClassifyFst,
                    )
        elif lang == 'ru':
            # Ru TN only support non-deterministic cases and produces multiple normalization options
            # use normalize_with_audio.py
            from fun_text_processing.text_normalization.ru.taggers.tokenize_and_classify import ClassifyFst
            from fun_text_processing.text_normalization.ru.verbalizers.verbalize_final import VerbalizeFinalFst
        elif lang == 'de':
            from fun_text_processing.text_normalization.de.taggers.tokenize_and_classify import ClassifyFst
            from fun_text_processing.text_normalization.de.verbalizers.verbalize_final import VerbalizeFinalFst
        elif lang == 'es':
            from fun_text_processing.text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
            from fun_text_processing.text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst
        elif lang == 'zh':
            from fun_text_processing.text_normalization.zh.taggers.tokenize_and_classify import ClassifyFst
            from fun_text_processing.text_normalization.zh.verbalizers.verbalize_final import VerbalizeFinalFst
        # NOTE(review): an unsupported `lang` falls through all branches above
        # and raises NameError on ClassifyFst here — consider an explicit error.
        self.tagger = ClassifyFst(
            input_case=input_case,
            deterministic=deterministic,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            whitelist=whitelist,
        )
        self.verbalizer = VerbalizeFinalFst(
            deterministic=deterministic, cache_dir=cache_dir, overwrite_cache=overwrite_cache
        )
        self.parser = TokenParser()
        self.lang = lang
        if NLP_AVAILABLE:
            self.processor = MosesProcessor(lang_id=lang)
        else:
            self.processor = None
            print("NeMo NLP is not available. Moses de-tokenization will be skipped.")
  111. def normalize_list(
  112. self,
  113. texts: List[str],
  114. verbose: bool = False,
  115. punct_pre_process: bool = False,
  116. punct_post_process: bool = False,
  117. batch_size: int = 1,
  118. n_jobs: int = 1,
  119. ):
  120. """
  121. NeMo text normalizer
  122. Args:
  123. texts: list of input strings
  124. verbose: whether to print intermediate meta information
  125. punct_pre_process: whether to do punctuation pre processing
  126. punct_post_process: whether to do punctuation post processing
  127. n_jobs: the maximum number of concurrently running jobs. If -1 all CPUs are used. If 1 is given,
  128. no parallel computing code is used at all, which is useful for debugging. For n_jobs below -1,
  129. (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one are used.
  130. batch_size: Number of examples for each process
  131. Returns converted list input strings
  132. """
  133. # to save intermediate results to a file
  134. batch = min(len(texts), batch_size)
  135. try:
  136. normalized_texts = Parallel(n_jobs=n_jobs)(
  137. delayed(self.process_batch)(texts[i : i + batch], verbose, punct_pre_process, punct_post_process)
  138. for i in range(0, len(texts), batch)
  139. )
  140. except BaseException as e:
  141. raise e
  142. normalized_texts = list(itertools.chain(*normalized_texts))
  143. return normalized_texts
  144. def process_batch(self, batch, verbose, punct_pre_process, punct_post_process):
  145. """
  146. Normalizes batch of text sequences
  147. Args:
  148. batch: list of texts
  149. verbose: whether to print intermediate meta information
  150. punct_pre_process: whether to do punctuation pre processing
  151. punct_post_process: whether to do punctuation post processing
  152. """
  153. normalized_lines = [
  154. self.normalize(
  155. text, verbose=verbose, punct_pre_process=punct_pre_process, punct_post_process=punct_post_process
  156. )
  157. for text in tqdm(batch)
  158. ]
  159. return normalized_lines
  160. def _estimate_number_of_permutations_in_nested_dict(
  161. self, token_group: Dict[str, Union[OrderedDict, str, bool]]
  162. ) -> int:
  163. num_perms = 1
  164. for k, inner in token_group.items():
  165. if isinstance(inner, dict):
  166. num_perms *= self._estimate_number_of_permutations_in_nested_dict(inner)
  167. num_perms *= factorial(len(token_group))
  168. return num_perms
  169. def _split_tokens_to_reduce_number_of_permutations(
  170. self, tokens: List[dict], max_number_of_permutations_per_split: int = 729
  171. ) -> List[List[dict]]:
  172. """
  173. Splits a sequence of tokens in a smaller sequences of tokens in a way that maximum number of composite
  174. tokens permutations does not exceed ``max_number_of_permutations_per_split``.
  175. For example,
  176. .. code-block:: python
  177. tokens = [
  178. {"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}},
  179. {"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}},
  180. ]
  181. split = normalizer._split_tokens_to_reduce_number_of_permutations(
  182. tokens, max_number_of_permutations_per_split=6
  183. )
  184. assert split == [
  185. [{"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}}],
  186. [{"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}}],
  187. ]
  188. Date tokens contain 3 items each which gives 6 permutations for every date. Since there are 2 dates, total
  189. number of permutations would be ``6 * 6 == 36``. Parameter ``max_number_of_permutations_per_split`` equals 6,
  190. so input sequence of tokens is split into 2 smaller sequences.
  191. Args:
  192. tokens (:obj:`List[dict]`): a list of dictionaries, possibly nested.
  193. max_number_of_permutations_per_split (:obj:`int`, `optional`, defaults to :obj:`243`): a maximum number
  194. of permutations which can be generated from input sequence of tokens.
  195. Returns:
  196. :obj:`List[List[dict]]`: a list of smaller sequences of tokens resulting from ``tokens`` split.
  197. """
  198. splits = []
  199. prev_end_of_split = 0
  200. current_number_of_permutations = 1
  201. for i, token_group in enumerate(tokens):
  202. n = self._estimate_number_of_permutations_in_nested_dict(token_group)
  203. if n * current_number_of_permutations > max_number_of_permutations_per_split:
  204. splits.append(tokens[prev_end_of_split:i])
  205. prev_end_of_split = i
  206. current_number_of_permutations = 1
  207. if n > max_number_of_permutations_per_split:
  208. raise ValueError(
  209. f"Could not split token list with respect to condition that every split can generate number of "
  210. f"permutations less or equal to "
  211. f"`max_number_of_permutations_per_split={max_number_of_permutations_per_split}`. "
  212. f"There is an unsplittable token group that generates more than "
  213. f"{max_number_of_permutations_per_split} permutations. Try to increase "
  214. f"`max_number_of_permutations_per_split` parameter."
  215. )
  216. current_number_of_permutations *= n
  217. splits.append(tokens[prev_end_of_split:])
  218. assert sum([len(s) for s in splits]) == len(tokens)
  219. return splits
    def normalize(
        self, text: str, verbose: bool = False, punct_pre_process: bool = False, punct_post_process: bool = False
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
        e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
            punct_post_process: whether to normalize punctuation

        Returns: spoken form
        """
        if len(text.split()) > 500:
            print(
                "WARNING! Your input is too long and could take a long time to normalize."
                "Use split_text_into_sentences() to make the input shorter and then call normalize_list()."
            )
        original_text = text
        if punct_pre_process:
            text = pre_process(text)
        text = text.strip()
        # Empty input short-circuits the whole pipeline.
        if not text:
            if verbose:
                print(text)
            return text
        # Escape pynini special symbols before composing text with the tagger FST.
        text = pynini.escape(text)
        tagged_lattice = self.find_tags(text)
        tagged_text = self.select_tag(tagged_lattice)
        if verbose:
            print(tagged_text)
        self.parser(tagged_text)
        tokens = self.parser.parse()
        # Cap the permutation search space so tag reordering stays tractable.
        split_tokens = self._split_tokens_to_reduce_number_of_permutations(tokens)
        output = ""
        for s in split_tokens:
            tags_reordered = self.generate_permutations(s)
            verbalizer_lattice = None
            # Try tag orderings until one yields a non-empty verbalization lattice.
            for tagged_text in tags_reordered:
                tagged_text = pynini.escape(tagged_text)
                verbalizer_lattice = self.find_verbalizer(tagged_text)
                if verbalizer_lattice.num_states() != 0:
                    break
            if verbalizer_lattice is None:
                raise ValueError(f"No permutations were generated from tokens {s}")
            output += ' ' + self.select_verbalizer(verbalizer_lattice)
        # Drop the leading space added by the join above and collapse duplicates.
        output = SPACE_DUP.sub(' ', output[1:])
        # NOTE(review): hasattr() is always True here since __init__ always sets
        # self.post_processor; the effective gate is the None check in post_process().
        if self.lang == "en" and hasattr(self, 'post_processor'):
            output = self.post_process(output)
        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                output = self.processor.moses_detokenizer.detokenize([output], unescape=False)
                output = post_process_punct(input=original_text, normalized_text=output)
            else:
                print("DAMO_NLP collection is not available: skipping punctuation post_processing")
        return output
  277. def split_text_into_sentences(self, text: str) -> List[str]:
  278. """
  279. Split text into sentences.
  280. Args:
  281. text: text
  282. Returns list of sentences
  283. """
  284. lower_case_unicode = ''
  285. upper_case_unicode = ''
  286. if self.lang == "ru":
  287. lower_case_unicode = '\u0430-\u04FF'
  288. upper_case_unicode = '\u0410-\u042F'
  289. # Read and split transcript by utterance (roughly, sentences)
  290. split_pattern = f"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”\!”)\s(?![0-9]+[a-z]*\.)"
  291. sentences = regex.split(split_pattern, text)
  292. return sentences
  293. def _permute(self, d: OrderedDict) -> List[str]:
  294. """
  295. Creates reorderings of dictionary elements and serializes as strings
  296. Args:
  297. d: (nested) dictionary of key value pairs
  298. Return permutations of different string serializations of key value pairs
  299. """
  300. l = []
  301. if PRESERVE_ORDER_KEY in d.keys():
  302. d_permutations = [d.items()]
  303. else:
  304. d_permutations = itertools.permutations(d.items())
  305. for perm in d_permutations:
  306. subl = [""]
  307. for k, v in perm:
  308. if isinstance(v, str):
  309. subl = ["".join(x) for x in itertools.product(subl, [f"{k}: \"{v}\" "])]
  310. elif isinstance(v, OrderedDict):
  311. rec = self._permute(v)
  312. subl = ["".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])]
  313. elif isinstance(v, bool):
  314. subl = ["".join(x) for x in itertools.product(subl, [f"{k}: true "])]
  315. else:
  316. raise ValueError()
  317. l.extend(subl)
  318. return l
  319. def generate_permutations(self, tokens: List[dict]):
  320. """
  321. Generates permutations of string serializations of list of dictionaries
  322. Args:
  323. tokens: list of dictionaries
  324. Returns string serialization of list of dictionaries
  325. """
  326. def _helper(prefix: str, tokens: List[dict], idx: int):
  327. """
  328. Generates permutations of string serializations of given dictionary
  329. Args:
  330. tokens: list of dictionaries
  331. prefix: prefix string
  332. idx: index of next dictionary
  333. Returns string serialization of dictionary
  334. """
  335. if idx == len(tokens):
  336. yield prefix
  337. return
  338. token_options = self._permute(tokens[idx])
  339. for token_option in token_options:
  340. yield from _helper(prefix + token_option, tokens, idx + 1)
  341. return _helper("", tokens, 0)
  342. def find_tags(self, text: str) -> 'pynini.FstLike':
  343. """
  344. Given text use tagger Fst to tag text
  345. Args:
  346. text: sentence
  347. Returns: tagged lattice
  348. """
  349. lattice = text @ self.tagger.fst
  350. return lattice
  351. def select_tag(self, lattice: 'pynini.FstLike') -> str:
  352. """
  353. Given tagged lattice return shortest path
  354. Args:
  355. tagged_text: tagged text
  356. Returns: shortest path
  357. """
  358. tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
  359. return tagged_text
  360. def find_verbalizer(self, tagged_text: str) -> 'pynini.FstLike':
  361. """
  362. Given tagged text creates verbalization lattice
  363. This is context-independent.
  364. Args:
  365. tagged_text: input text
  366. Returns: verbalized lattice
  367. """
  368. lattice = tagged_text @ self.verbalizer.fst
  369. return lattice
  370. def select_verbalizer(self, lattice: 'pynini.FstLike') -> str:
  371. """
  372. Given verbalized lattice return shortest path
  373. Args:
  374. lattice: verbalization lattice
  375. Returns: shortest path
  376. """
  377. output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
  378. # lattice = output @ self.verbalizer.punct_graph
  379. # output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
  380. return output
  381. def post_process(self, normalized_text: 'pynini.FstLike') -> str:
  382. """
  383. Runs post processing graph on normalized text
  384. Args:
  385. normalized_text: normalized text
  386. Returns: shortest path
  387. """
  388. normalized_text = normalized_text.strip()
  389. if not normalized_text:
  390. return normalized_text
  391. normalized_text = pynini.escape(normalized_text)
  392. if self.post_processor is not None:
  393. normalized_text = top_rewrite(normalized_text, self.post_processor.fst)
  394. return normalized_text
  395. def parse_args():
  396. parser = ArgumentParser()
  397. input = parser.add_mutually_exclusive_group()
  398. input.add_argument("--text", dest="input_string", help="input string", type=str)
  399. input.add_argument("--input_file", dest="input_file", help="input file path", type=str)
  400. parser.add_argument('--output_file', dest="output_file", help="output file path", type=str)
  401. parser.add_argument("--language", help="language", choices=["en", "de", "es", "zh"], default="en", type=str)
  402. parser.add_argument(
  403. "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
  404. )
  405. parser.add_argument("--verbose", help="print info for debugging", action='store_true')
  406. parser.add_argument(
  407. "--punct_post_process",
  408. help="set to True to enable punctuation post processing to match input.",
  409. action="store_true",
  410. )
  411. parser.add_argument(
  412. "--punct_pre_process", help="set to True to enable punctuation pre processing", action="store_true"
  413. )
  414. parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
  415. parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
  416. parser.add_argument(
  417. "--cache_dir",
  418. help="path to a dir with .far grammar file. Set to None to avoid using cache",
  419. default=None,
  420. type=str,
  421. )
  422. return parser.parse_args()
  423. if __name__ == "__main__":
  424. start_time = perf_counter()
  425. args = parse_args()
  426. whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
  427. if not args.input_string and not args.input_file:
  428. raise ValueError("Either `--text` or `--input_file` required")
  429. normalizer = Normalizer(
  430. input_case=args.input_case,
  431. cache_dir=args.cache_dir,
  432. overwrite_cache=args.overwrite_cache,
  433. whitelist=whitelist,
  434. lang=args.language,
  435. )
  436. if args.input_string:
  437. print(
  438. normalizer.normalize(
  439. args.input_string,
  440. verbose=args.verbose,
  441. punct_pre_process=args.punct_pre_process,
  442. punct_post_process=args.punct_post_process,
  443. )
  444. )
  445. elif args.input_file:
  446. print("Loading data: " + args.input_file)
  447. data = load_file(args.input_file)
  448. print("- Data: " + str(len(data)) + " sentences")
  449. normalizer_prediction = normalizer.normalize_list(
  450. data,
  451. verbose=args.verbose,
  452. punct_pre_process=args.punct_pre_process,
  453. punct_post_process=args.punct_post_process,
  454. )
  455. if args.output_file:
  456. write_file(args.output_file, normalizer_prediction)
  457. print(f"- Normalized. Writing out to {args.output_file}")
  458. else:
  459. print(normalizer_prediction)
  460. print(f"Execution time: {perf_counter() - start_time:.02f} sec")