normalize_with_audio.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. import json
  2. import os
  3. import time
  4. from argparse import ArgumentParser
  5. from glob import glob
  6. from typing import List, Tuple
  7. import pynini
  8. from joblib import Parallel, delayed
  9. from fun_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process
  10. from fun_text_processing.text_normalization.normalize import Normalizer
  11. from pynini.lib import rewrite
  12. from tqdm import tqdm
  13. try:
  14. from nemo.collections.asr.metrics.wer import word_error_rate
  15. from nemo.collections.asr.models import ASRModel
  16. ASR_AVAILABLE = True
  17. except (ModuleNotFoundError, ImportError):
  18. ASR_AVAILABLE = False
  19. """
  20. The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output
  21. (most of the semiotic classes use deterministic=False flag).
  22. To run this script with a .json manifest file, the manifest file should contain the following fields:
  23. "audio_data" - path to the audio file
  24. "text" - raw text
  25. "pred_text" - ASR model prediction
  26. See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions
  27. When the manifest is ready, run:
  28. python normalize_with_audio.py \
  29. --audio_data PATH/TO/MANIFEST.JSON \
  30. --language en
  31. To run with a single audio file, specify path to audio and text with:
  32. python normalize_with_audio.py \
  33. --audio_data PATH/TO/AUDIO.WAV \
  34. --language en \
--text raw text OR PATH/TO/.TXT/FILE \
  36. --model QuartzNet15x5Base-En \
  37. --verbose
  38. To see possible normalization options for a text input without an audio file (could be used for debugging), run:
python normalize_with_audio.py --text "RAW TEXT"
Specify `--cache_dir` to generate .far grammars once and reuse them for faster inference
  41. """
  42. class NormalizerWithAudio(Normalizer):
  43. """
  44. Normalizer class that converts text from written to spoken form.
  45. Useful for TTS preprocessing.
  46. Args:
  47. input_case: expected input capitalization
  48. lang: language
  49. cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
  50. overwrite_cache: set to True to overwrite .far files
  51. whitelist: path to a file with whitelist replacements
  52. post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
  53. Note: punct_post_process flag in normalize() supports all languages.
  54. """
  55. def __init__(
  56. self,
  57. input_case: str,
  58. lang: str = 'en',
  59. cache_dir: str = None,
  60. overwrite_cache: bool = False,
  61. whitelist: str = None,
  62. lm: bool = False,
  63. post_process: bool = True,
  64. ):
  65. super().__init__(
  66. input_case=input_case,
  67. lang=lang,
  68. deterministic=False,
  69. cache_dir=cache_dir,
  70. overwrite_cache=overwrite_cache,
  71. whitelist=whitelist,
  72. lm=lm,
  73. post_process=post_process,
  74. )
  75. self.lm = lm
  76. def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str:
  77. """
  78. Main function. Normalizes tokens from written to spoken form
  79. e.g. 12 kg -> twelve kilograms
  80. Args:
  81. text: string that may include semiotic classes
  82. n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
  83. punct_post_process: whether to normalize punctuation
  84. verbose: whether to print intermediate meta information
  85. Returns:
  86. normalized text options (usually there are multiple ways of normalizing a given semiotic class)
  87. """
  88. if len(text.split()) > 500:
  89. raise ValueError(
  90. "Your input is too long. Please split up the input into sentences, "
  91. "or strings with fewer than 500 words"
  92. )
  93. original_text = text
  94. text = pre_process(text) # to handle []
  95. text = text.strip()
  96. if not text:
  97. if verbose:
  98. print(text)
  99. return text
  100. text = pynini.escape(text)
  101. print(text)
  102. if self.lm:
  103. if self.lang not in ["en"]:
  104. raise ValueError(f"{self.lang} is not supported in LM mode")
  105. if self.lang == "en":
  106. # this to keep arpabet phonemes in the list of options
  107. if "[" in text and "]" in text:
  108. lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
  109. else:
  110. try:
  111. lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits)
  112. except pynini.lib.rewrite.Error:
  113. lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
  114. lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
  115. tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()]
  116. tagged_texts.sort(key=lambda x: x[1])
  117. tagged_texts, weights = list(zip(*tagged_texts))
  118. else:
  119. tagged_texts = self._get_tagged_text(text, n_tagged)
  120. # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
  121. if self.lang == "en":
  122. normalized_texts = tagged_texts
  123. normalized_texts = [self.post_process(text) for text in normalized_texts]
  124. else:
  125. normalized_texts = []
  126. for tagged_text in tagged_texts:
  127. self._verbalize(tagged_text, normalized_texts, verbose=verbose)
  128. if len(normalized_texts) == 0:
  129. raise ValueError()
  130. if punct_post_process:
  131. # do post-processing based on Moses detokenizer
  132. if self.processor:
  133. normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
  134. normalized_texts = [
  135. post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
  136. ]
  137. if self.lm:
  138. remove_dup = sorted(list(set(zip(normalized_texts, weights))), key=lambda x: x[1])
  139. normalized_texts, weights = zip(*remove_dup)
  140. return list(normalized_texts), weights
  141. normalized_texts = set(normalized_texts)
  142. return normalized_texts
  143. def _get_tagged_text(self, text, n_tagged):
  144. """
  145. Returns text after tokenize and classify
  146. Args;
  147. text: input text
  148. n_tagged: number of tagged options to consider, -1 - return all possible tagged options
  149. """
  150. if n_tagged == -1:
  151. if self.lang == "en":
  152. # this to keep arpabet phonemes in the list of options
  153. if "[" in text and "]" in text:
  154. tagged_texts = rewrite.rewrites(text, self.tagger.fst)
  155. else:
  156. try:
  157. tagged_texts = rewrite.rewrites(text, self.tagger.fst_no_digits)
  158. except pynini.lib.rewrite.Error:
  159. tagged_texts = rewrite.rewrites(text, self.tagger.fst)
  160. else:
  161. tagged_texts = rewrite.rewrites(text, self.tagger.fst)
  162. else:
  163. if self.lang == "en":
  164. # this to keep arpabet phonemes in the list of options
  165. if "[" in text and "]" in text:
  166. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
  167. else:
  168. try:
  169. # try self.tagger graph that produces output without digits
  170. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
  171. except pynini.lib.rewrite.Error:
  172. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
  173. else:
  174. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
  175. return tagged_texts
  176. def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False):
  177. """
  178. Verbalizes tagged text
  179. Args:
  180. tagged_text: text with tags
  181. normalized_texts: list of possible normalization options
  182. verbose: if true prints intermediate classification results
  183. """
  184. def get_verbalized_text(tagged_text):
  185. return rewrite.rewrites(tagged_text, self.verbalizer.fst)
  186. self.parser(tagged_text)
  187. tokens = self.parser.parse()
  188. tags_reordered = self.generate_permutations(tokens)
  189. for tagged_text_reordered in tags_reordered:
  190. try:
  191. tagged_text_reordered = pynini.escape(tagged_text_reordered)
  192. normalized_texts.extend(get_verbalized_text(tagged_text_reordered))
  193. if verbose:
  194. print(tagged_text_reordered)
  195. except pynini.lib.rewrite.Error:
  196. continue
  197. def select_best_match(
  198. self,
  199. normalized_texts: List[str],
  200. input_text: str,
  201. pred_text: str,
  202. verbose: bool = False,
  203. remove_punct: bool = False,
  204. cer_threshold: int = 100,
  205. ):
  206. """
  207. Selects the best normalization option based on the lowest CER
  208. Args:
  209. normalized_texts: normalized text options
  210. input_text: input text
  211. pred_text: ASR model transcript of the audio file corresponding to the normalized text
  212. verbose: whether to print intermediate meta information
  213. remove_punct: whether to remove punctuation before calculating CER
  214. cer_threshold: if CER for pred_text is above the cer_threshold, no normalization will be performed
  215. Returns:
  216. normalized text with the lowest CER and CER value
  217. """
  218. if pred_text == "":
  219. return input_text, cer_threshold
  220. normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct)
  221. normalized_texts_cer = sorted(normalized_texts_cer, key=lambda x: x[1])
  222. normalized_text, cer = normalized_texts_cer[0]
  223. if cer > cer_threshold:
  224. return input_text, cer
  225. if verbose:
  226. print('-' * 30)
  227. for option in normalized_texts:
  228. print(option)
  229. print('-' * 30)
  230. return normalized_text, cer
  231. def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
  232. """
  233. Calculates character error rate (CER)
  234. Args:
  235. normalized_texts: normalized text options
  236. pred_text: ASR model output
  237. Returns: normalized options with corresponding CER
  238. """
  239. normalized_options = []
  240. for text in normalized_texts:
  241. text_clean = text.replace('-', ' ').lower()
  242. if remove_punct:
  243. for punct in "!?:;,.-()*+-/<=>@^_":
  244. text_clean = text_clean.replace(punct, "")
  245. cer = round(word_error_rate([pred_text], [text_clean], use_cer=True) * 100, 2)
  246. normalized_options.append((text, cer))
  247. return normalized_options
  248. def get_asr_model(asr_model):
  249. """
  250. Returns ASR Model
  251. Args:
  252. asr_model: NeMo ASR model
  253. """
  254. if os.path.exists(args.model):
  255. asr_model = ASRModel.restore_from(asr_model)
  256. elif args.model in ASRModel.get_available_model_names():
  257. asr_model = ASRModel.from_pretrained(asr_model)
  258. else:
  259. raise ValueError(
  260. f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
  261. )
  262. return asr_model
  263. def parse_args():
  264. parser = ArgumentParser()
  265. parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str)
  266. parser.add_argument(
  267. "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
  268. )
  269. parser.add_argument(
  270. "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str
  271. )
  272. parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest")
  273. parser.add_argument(
  274. '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
  275. )
  276. parser.add_argument(
  277. "--n_tagged",
  278. type=int,
  279. default=30,
  280. help="number of tagged options to consider, -1 - return all possible tagged options",
  281. )
  282. parser.add_argument("--verbose", help="print info for debugging", action="store_true")
  283. parser.add_argument(
  284. "--no_remove_punct_for_cer",
  285. help="Set to True to NOT remove punctuation before calculating CER",
  286. action="store_true",
  287. )
  288. parser.add_argument(
  289. "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
  290. )
  291. parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
  292. parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
  293. parser.add_argument(
  294. "--cache_dir",
  295. help="path to a dir with .far grammar file. Set to None to avoid using cache",
  296. default=None,
  297. type=str,
  298. )
  299. parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
  300. parser.add_argument(
  301. "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now."
  302. )
  303. parser.add_argument(
  304. "--cer_threshold",
  305. default=100,
  306. type=int,
  307. help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
  308. )
  309. parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process")
  310. return parser.parse_args()
  311. def _normalize_line(
  312. normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
  313. ):
  314. line = json.loads(line)
  315. pred_text = line["pred_text"]
  316. normalized_texts = normalizer.normalize(
  317. text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
  318. )
  319. normalized_texts = set(normalized_texts)
  320. normalized_text, cer = normalizer.select_best_match(
  321. normalized_texts=normalized_texts,
  322. input_text=line["text"],
  323. pred_text=pred_text,
  324. verbose=verbose,
  325. remove_punct=remove_punct,
  326. cer_threshold=cer_threshold,
  327. )
  328. line["nemo_normalized"] = normalized_text
  329. line["CER_nemo_normalized"] = cer
  330. return line
  331. def normalize_manifest(
  332. normalizer,
  333. audio_data: str,
  334. n_jobs: int,
  335. n_tagged: int,
  336. remove_punct: bool,
  337. punct_post_process: bool,
  338. batch_size: int,
  339. cer_threshold: int,
  340. ):
  341. """
  342. Args:
  343. args.audio_data: path to .json manifest file.
  344. """
  345. def __process_batch(batch_idx: int, batch: List[str], dir_name: str):
  346. """
  347. Normalizes batch of text sequences
  348. Args:
  349. batch: list of texts
  350. batch_idx: batch index
  351. dir_name: path to output directory to save results
  352. """
  353. normalized_lines = [
  354. _normalize_line(
  355. normalizer,
  356. n_tagged,
  357. verbose=False,
  358. line=line,
  359. remove_punct=remove_punct,
  360. punct_post_process=punct_post_process,
  361. cer_threshold=cer_threshold,
  362. )
  363. for line in tqdm(batch)
  364. ]
  365. with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out:
  366. for line in normalized_lines:
  367. f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
  368. print(f"Batch -- {batch_idx} -- is complete")
  369. manifest_out = audio_data.replace('.json', '_normalized.json')
  370. with open(audio_data, 'r') as f:
  371. lines = f.readlines()
  372. print(f'Normalizing {len(lines)} lines of {audio_data}...')
  373. # to save intermediate results to a file
  374. batch = min(len(lines), batch_size)
  375. tmp_dir = manifest_out.replace(".json", "_parts")
  376. os.makedirs(tmp_dir, exist_ok=True)
  377. Parallel(n_jobs=n_jobs)(
  378. delayed(__process_batch)(idx, lines[i : i + batch], tmp_dir)
  379. for idx, i in enumerate(range(0, len(lines), batch))
  380. )
  381. # aggregate all intermediate files
  382. with open(manifest_out, "w") as f_out:
  383. for batch_f in sorted(glob(f"{tmp_dir}/*.json")):
  384. with open(batch_f, "r") as f_in:
  385. lines = f_in.read()
  386. f_out.write(lines)
  387. print(f'Normalized version saved at {manifest_out}')
  388. if __name__ == "__main__":
  389. args = parse_args()
  390. if not ASR_AVAILABLE and args.audio_data:
  391. raise ValueError("NeMo ASR collection is not installed.")
  392. start = time.time()
  393. args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
  394. if args.text is not None:
  395. normalizer = NormalizerWithAudio(
  396. input_case=args.input_case,
  397. lang=args.language,
  398. cache_dir=args.cache_dir,
  399. overwrite_cache=args.overwrite_cache,
  400. whitelist=args.whitelist,
  401. lm=args.lm,
  402. )
  403. if os.path.exists(args.text):
  404. with open(args.text, 'r') as f:
  405. args.text = f.read().strip()
  406. normalized_texts = normalizer.normalize(
  407. text=args.text,
  408. verbose=args.verbose,
  409. n_tagged=args.n_tagged,
  410. punct_post_process=not args.no_punct_post_process,
  411. )
  412. if not normalizer.lm:
  413. normalized_texts = set(normalized_texts)
  414. if args.audio_data:
  415. asr_model = get_asr_model(args.model)
  416. pred_text = asr_model.transcribe([args.audio_data])[0]
  417. normalized_text, cer = normalizer.select_best_match(
  418. normalized_texts=normalized_texts,
  419. pred_text=pred_text,
  420. input_text=args.text,
  421. verbose=args.verbose,
  422. remove_punct=not args.no_remove_punct_for_cer,
  423. cer_threshold=args.cer_threshold,
  424. )
  425. print(f"Transcript: {pred_text}")
  426. print(f"Normalized: {normalized_text}")
  427. else:
  428. print("Normalization options:")
  429. for norm_text in normalized_texts:
  430. print(norm_text)
  431. elif not os.path.exists(args.audio_data):
  432. raise ValueError(f"{args.audio_data} not found.")
  433. elif args.audio_data.endswith('.json'):
  434. normalizer = NormalizerWithAudio(
  435. input_case=args.input_case,
  436. lang=args.language,
  437. cache_dir=args.cache_dir,
  438. overwrite_cache=args.overwrite_cache,
  439. whitelist=args.whitelist,
  440. )
  441. normalize_manifest(
  442. normalizer=normalizer,
  443. audio_data=args.audio_data,
  444. n_jobs=args.n_jobs,
  445. n_tagged=args.n_tagged,
  446. remove_punct=not args.no_remove_punct_for_cer,
  447. punct_post_process=not args.no_punct_post_process,
  448. batch_size=args.batch_size,
  449. cer_threshold=args.cer_threshold,
  450. )
  451. else:
  452. raise ValueError(
  453. "Provide either path to .json manifest in '--audio_data' OR "
  454. + "'--audio_data' path to audio file and '--text' path to a text file OR"
  455. "'--text' string text (for debugging without audio)"
  456. )
  457. print(f'Execution time: {round((time.time() - start)/60, 2)} min.')