normalize_with_audio.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. # Copyright NeMo (https://github.com/NVIDIA/NeMo). All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import time
  17. from argparse import ArgumentParser
  18. from glob import glob
  19. from typing import List, Tuple
  20. import pynini
  21. from joblib import Parallel, delayed
  22. from fun_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process
  23. from fun_text_processing.text_normalization.normalize import Normalizer
  24. from pynini.lib import rewrite
  25. from tqdm import tqdm
  26. try:
  27. from nemo.collections.asr.metrics.wer import word_error_rate
  28. from nemo.collections.asr.models import ASRModel
  29. ASR_AVAILABLE = True
  30. except (ModuleNotFoundError, ImportError):
  31. ASR_AVAILABLE = False
"""
The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output
(most of the semiotic classes use deterministic=False flag).

To run this script with a .json manifest file, the manifest file should contain the following fields:
    "audio_data" - path to the audio file
    "text" - raw text
    "pred_text" - ASR model prediction

    See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions

    When the manifest is ready, run:
        python normalize_with_audio.py \
            --audio_data PATH/TO/MANIFEST.JSON \
            --language en

To run with a single audio file, specify path to audio and text with:
    python normalize_with_audio.py \
        --audio_data PATH/TO/AUDIO.WAV \
        --language en \
        --text "raw text" OR PATH/TO/.TXT/FILE \
        --model QuartzNet15x5Base-En \
        --verbose

To see possible normalization options for a text input without an audio file (could be used for debugging), run:
    python normalize_with_audio.py --text "RAW TEXT"

Specify `--cache_dir` to generate .far grammars once and reuse them for faster inference.
"""
  55. class NormalizerWithAudio(Normalizer):
  56. """
  57. Normalizer class that converts text from written to spoken form.
  58. Useful for TTS preprocessing.
  59. Args:
  60. input_case: expected input capitalization
  61. lang: language
  62. cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
  63. overwrite_cache: set to True to overwrite .far files
  64. whitelist: path to a file with whitelist replacements
  65. post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
  66. Note: punct_post_process flag in normalize() supports all languages.
  67. """
  68. def __init__(
  69. self,
  70. input_case: str,
  71. lang: str = 'en',
  72. cache_dir: str = None,
  73. overwrite_cache: bool = False,
  74. whitelist: str = None,
  75. lm: bool = False,
  76. post_process: bool = True,
  77. ):
  78. super().__init__(
  79. input_case=input_case,
  80. lang=lang,
  81. deterministic=False,
  82. cache_dir=cache_dir,
  83. overwrite_cache=overwrite_cache,
  84. whitelist=whitelist,
  85. lm=lm,
  86. post_process=post_process,
  87. )
  88. self.lm = lm
  89. def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str:
  90. """
  91. Main function. Normalizes tokens from written to spoken form
  92. e.g. 12 kg -> twelve kilograms
  93. Args:
  94. text: string that may include semiotic classes
  95. n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
  96. punct_post_process: whether to normalize punctuation
  97. verbose: whether to print intermediate meta information
  98. Returns:
  99. normalized text options (usually there are multiple ways of normalizing a given semiotic class)
  100. """
  101. if len(text.split()) > 500:
  102. raise ValueError(
  103. "Your input is too long. Please split up the input into sentences, "
  104. "or strings with fewer than 500 words"
  105. )
  106. original_text = text
  107. text = pre_process(text) # to handle []
  108. text = text.strip()
  109. if not text:
  110. if verbose:
  111. print(text)
  112. return text
  113. text = pynini.escape(text)
  114. print(text)
  115. if self.lm:
  116. if self.lang not in ["en"]:
  117. raise ValueError(f"{self.lang} is not supported in LM mode")
  118. if self.lang == "en":
  119. # this to keep arpabet phonemes in the list of options
  120. if "[" in text and "]" in text:
  121. lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
  122. else:
  123. try:
  124. lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits)
  125. except pynini.lib.rewrite.Error:
  126. lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
  127. lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
  128. tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()]
  129. tagged_texts.sort(key=lambda x: x[1])
  130. tagged_texts, weights = list(zip(*tagged_texts))
  131. else:
  132. tagged_texts = self._get_tagged_text(text, n_tagged)
  133. # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
  134. if self.lang == "en":
  135. normalized_texts = tagged_texts
  136. normalized_texts = [self.post_process(text) for text in normalized_texts]
  137. else:
  138. normalized_texts = []
  139. for tagged_text in tagged_texts:
  140. self._verbalize(tagged_text, normalized_texts, verbose=verbose)
  141. if len(normalized_texts) == 0:
  142. raise ValueError()
  143. if punct_post_process:
  144. # do post-processing based on Moses detokenizer
  145. if self.processor:
  146. normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
  147. normalized_texts = [
  148. post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
  149. ]
  150. if self.lm:
  151. remove_dup = sorted(list(set(zip(normalized_texts, weights))), key=lambda x: x[1])
  152. normalized_texts, weights = zip(*remove_dup)
  153. return list(normalized_texts), weights
  154. normalized_texts = set(normalized_texts)
  155. return normalized_texts
  156. def _get_tagged_text(self, text, n_tagged):
  157. """
  158. Returns text after tokenize and classify
  159. Args;
  160. text: input text
  161. n_tagged: number of tagged options to consider, -1 - return all possible tagged options
  162. """
  163. if n_tagged == -1:
  164. if self.lang == "en":
  165. # this to keep arpabet phonemes in the list of options
  166. if "[" in text and "]" in text:
  167. tagged_texts = rewrite.rewrites(text, self.tagger.fst)
  168. else:
  169. try:
  170. tagged_texts = rewrite.rewrites(text, self.tagger.fst_no_digits)
  171. except pynini.lib.rewrite.Error:
  172. tagged_texts = rewrite.rewrites(text, self.tagger.fst)
  173. else:
  174. tagged_texts = rewrite.rewrites(text, self.tagger.fst)
  175. else:
  176. if self.lang == "en":
  177. # this to keep arpabet phonemes in the list of options
  178. if "[" in text and "]" in text:
  179. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
  180. else:
  181. try:
  182. # try self.tagger graph that produces output without digits
  183. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
  184. except pynini.lib.rewrite.Error:
  185. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
  186. else:
  187. tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
  188. return tagged_texts
  189. def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False):
  190. """
  191. Verbalizes tagged text
  192. Args:
  193. tagged_text: text with tags
  194. normalized_texts: list of possible normalization options
  195. verbose: if true prints intermediate classification results
  196. """
  197. def get_verbalized_text(tagged_text):
  198. return rewrite.rewrites(tagged_text, self.verbalizer.fst)
  199. self.parser(tagged_text)
  200. tokens = self.parser.parse()
  201. tags_reordered = self.generate_permutations(tokens)
  202. for tagged_text_reordered in tags_reordered:
  203. try:
  204. tagged_text_reordered = pynini.escape(tagged_text_reordered)
  205. normalized_texts.extend(get_verbalized_text(tagged_text_reordered))
  206. if verbose:
  207. print(tagged_text_reordered)
  208. except pynini.lib.rewrite.Error:
  209. continue
  210. def select_best_match(
  211. self,
  212. normalized_texts: List[str],
  213. input_text: str,
  214. pred_text: str,
  215. verbose: bool = False,
  216. remove_punct: bool = False,
  217. cer_threshold: int = 100,
  218. ):
  219. """
  220. Selects the best normalization option based on the lowest CER
  221. Args:
  222. normalized_texts: normalized text options
  223. input_text: input text
  224. pred_text: ASR model transcript of the audio file corresponding to the normalized text
  225. verbose: whether to print intermediate meta information
  226. remove_punct: whether to remove punctuation before calculating CER
  227. cer_threshold: if CER for pred_text is above the cer_threshold, no normalization will be performed
  228. Returns:
  229. normalized text with the lowest CER and CER value
  230. """
  231. if pred_text == "":
  232. return input_text, cer_threshold
  233. normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct)
  234. normalized_texts_cer = sorted(normalized_texts_cer, key=lambda x: x[1])
  235. normalized_text, cer = normalized_texts_cer[0]
  236. if cer > cer_threshold:
  237. return input_text, cer
  238. if verbose:
  239. print('-' * 30)
  240. for option in normalized_texts:
  241. print(option)
  242. print('-' * 30)
  243. return normalized_text, cer
  244. def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
  245. """
  246. Calculates character error rate (CER)
  247. Args:
  248. normalized_texts: normalized text options
  249. pred_text: ASR model output
  250. Returns: normalized options with corresponding CER
  251. """
  252. normalized_options = []
  253. for text in normalized_texts:
  254. text_clean = text.replace('-', ' ').lower()
  255. if remove_punct:
  256. for punct in "!?:;,.-()*+-/<=>@^_":
  257. text_clean = text_clean.replace(punct, "")
  258. cer = round(word_error_rate([pred_text], [text_clean], use_cer=True) * 100, 2)
  259. normalized_options.append((text, cer))
  260. return normalized_options
  261. def get_asr_model(asr_model):
  262. """
  263. Returns ASR Model
  264. Args:
  265. asr_model: NeMo ASR model
  266. """
  267. if os.path.exists(args.model):
  268. asr_model = ASRModel.restore_from(asr_model)
  269. elif args.model in ASRModel.get_available_model_names():
  270. asr_model = ASRModel.from_pretrained(asr_model)
  271. else:
  272. raise ValueError(
  273. f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
  274. )
  275. return asr_model
  276. def parse_args():
  277. parser = ArgumentParser()
  278. parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str)
  279. parser.add_argument(
  280. "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
  281. )
  282. parser.add_argument(
  283. "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str
  284. )
  285. parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest")
  286. parser.add_argument(
  287. '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
  288. )
  289. parser.add_argument(
  290. "--n_tagged",
  291. type=int,
  292. default=30,
  293. help="number of tagged options to consider, -1 - return all possible tagged options",
  294. )
  295. parser.add_argument("--verbose", help="print info for debugging", action="store_true")
  296. parser.add_argument(
  297. "--no_remove_punct_for_cer",
  298. help="Set to True to NOT remove punctuation before calculating CER",
  299. action="store_true",
  300. )
  301. parser.add_argument(
  302. "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
  303. )
  304. parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
  305. parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
  306. parser.add_argument(
  307. "--cache_dir",
  308. help="path to a dir with .far grammar file. Set to None to avoid using cache",
  309. default=None,
  310. type=str,
  311. )
  312. parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
  313. parser.add_argument(
  314. "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now."
  315. )
  316. parser.add_argument(
  317. "--cer_threshold",
  318. default=100,
  319. type=int,
  320. help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
  321. )
  322. parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process")
  323. return parser.parse_args()
  324. def _normalize_line(
  325. normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
  326. ):
  327. line = json.loads(line)
  328. pred_text = line["pred_text"]
  329. normalized_texts = normalizer.normalize(
  330. text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
  331. )
  332. normalized_texts = set(normalized_texts)
  333. normalized_text, cer = normalizer.select_best_match(
  334. normalized_texts=normalized_texts,
  335. input_text=line["text"],
  336. pred_text=pred_text,
  337. verbose=verbose,
  338. remove_punct=remove_punct,
  339. cer_threshold=cer_threshold,
  340. )
  341. line["nemo_normalized"] = normalized_text
  342. line["CER_nemo_normalized"] = cer
  343. return line
  344. def normalize_manifest(
  345. normalizer,
  346. audio_data: str,
  347. n_jobs: int,
  348. n_tagged: int,
  349. remove_punct: bool,
  350. punct_post_process: bool,
  351. batch_size: int,
  352. cer_threshold: int,
  353. ):
  354. """
  355. Args:
  356. args.audio_data: path to .json manifest file.
  357. """
  358. def __process_batch(batch_idx: int, batch: List[str], dir_name: str):
  359. """
  360. Normalizes batch of text sequences
  361. Args:
  362. batch: list of texts
  363. batch_idx: batch index
  364. dir_name: path to output directory to save results
  365. """
  366. normalized_lines = [
  367. _normalize_line(
  368. normalizer,
  369. n_tagged,
  370. verbose=False,
  371. line=line,
  372. remove_punct=remove_punct,
  373. punct_post_process=punct_post_process,
  374. cer_threshold=cer_threshold,
  375. )
  376. for line in tqdm(batch)
  377. ]
  378. with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out:
  379. for line in normalized_lines:
  380. f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
  381. print(f"Batch -- {batch_idx} -- is complete")
  382. manifest_out = audio_data.replace('.json', '_normalized.json')
  383. with open(audio_data, 'r') as f:
  384. lines = f.readlines()
  385. print(f'Normalizing {len(lines)} lines of {audio_data}...')
  386. # to save intermediate results to a file
  387. batch = min(len(lines), batch_size)
  388. tmp_dir = manifest_out.replace(".json", "_parts")
  389. os.makedirs(tmp_dir, exist_ok=True)
  390. Parallel(n_jobs=n_jobs)(
  391. delayed(__process_batch)(idx, lines[i : i + batch], tmp_dir)
  392. for idx, i in enumerate(range(0, len(lines), batch))
  393. )
  394. # aggregate all intermediate files
  395. with open(manifest_out, "w") as f_out:
  396. for batch_f in sorted(glob(f"{tmp_dir}/*.json")):
  397. with open(batch_f, "r") as f_in:
  398. lines = f_in.read()
  399. f_out.write(lines)
  400. print(f'Normalized version saved at {manifest_out}')
  401. if __name__ == "__main__":
  402. args = parse_args()
  403. if not ASR_AVAILABLE and args.audio_data:
  404. raise ValueError("NeMo ASR collection is not installed.")
  405. start = time.time()
  406. args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
  407. if args.text is not None:
  408. normalizer = NormalizerWithAudio(
  409. input_case=args.input_case,
  410. lang=args.language,
  411. cache_dir=args.cache_dir,
  412. overwrite_cache=args.overwrite_cache,
  413. whitelist=args.whitelist,
  414. lm=args.lm,
  415. )
  416. if os.path.exists(args.text):
  417. with open(args.text, 'r') as f:
  418. args.text = f.read().strip()
  419. normalized_texts = normalizer.normalize(
  420. text=args.text,
  421. verbose=args.verbose,
  422. n_tagged=args.n_tagged,
  423. punct_post_process=not args.no_punct_post_process,
  424. )
  425. if not normalizer.lm:
  426. normalized_texts = set(normalized_texts)
  427. if args.audio_data:
  428. asr_model = get_asr_model(args.model)
  429. pred_text = asr_model.transcribe([args.audio_data])[0]
  430. normalized_text, cer = normalizer.select_best_match(
  431. normalized_texts=normalized_texts,
  432. pred_text=pred_text,
  433. input_text=args.text,
  434. verbose=args.verbose,
  435. remove_punct=not args.no_remove_punct_for_cer,
  436. cer_threshold=args.cer_threshold,
  437. )
  438. print(f"Transcript: {pred_text}")
  439. print(f"Normalized: {normalized_text}")
  440. else:
  441. print("Normalization options:")
  442. for norm_text in normalized_texts:
  443. print(norm_text)
  444. elif not os.path.exists(args.audio_data):
  445. raise ValueError(f"{args.audio_data} not found.")
  446. elif args.audio_data.endswith('.json'):
  447. normalizer = NormalizerWithAudio(
  448. input_case=args.input_case,
  449. lang=args.language,
  450. cache_dir=args.cache_dir,
  451. overwrite_cache=args.overwrite_cache,
  452. whitelist=args.whitelist,
  453. )
  454. normalize_manifest(
  455. normalizer=normalizer,
  456. audio_data=args.audio_data,
  457. n_jobs=args.n_jobs,
  458. n_tagged=args.n_tagged,
  459. remove_punct=not args.no_remove_punct_for_cer,
  460. punct_post_process=not args.no_punct_post_process,
  461. batch_size=args.batch_size,
  462. cer_threshold=args.cer_threshold,
  463. )
  464. else:
  465. raise ValueError(
  466. "Provide either path to .json manifest in '--audio_data' OR "
  467. + "'--audio_data' path to audio file and '--text' path to a text file OR"
  468. "'--text' string text (for debugging without audio)"
  469. )
  470. print(f'Execution time: {round((time.time() - start)/60, 2)} min.')