| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530 |
- import json
- import os
- import time
- from argparse import ArgumentParser
- from glob import glob
- from typing import List, Tuple
- import pynini
- from joblib import Parallel, delayed
- from fun_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process
- from fun_text_processing.text_normalization.normalize import Normalizer
- from pynini.lib import rewrite
- from tqdm import tqdm
- try:
- from nemo.collections.asr.metrics.wer import word_error_rate
- from nemo.collections.asr.models import ASRModel
- ASR_AVAILABLE = True
- except (ModuleNotFoundError, ImportError):
- ASR_AVAILABLE = False
- """
- The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output
- (most of the semiotic classes use deterministic=False flag).
- To run this script with a .json manifest file, the manifest file should contain the following fields:
- "audio_data" - path to the audio file
- "text" - raw text
- "pred_text" - ASR model prediction
- See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions
- When the manifest is ready, run:
- python normalize_with_audio.py \
- --audio_data PATH/TO/MANIFEST.JSON \
- --language en
- To run with a single audio file, specify path to audio and text with:
- python normalize_with_audio.py \
- --audio_data PATH/TO/AUDIO.WAV \
- --language en \
- --text raw text OR PATH/TO/.TXT/FILE
- --model QuartzNet15x5Base-En \
- --verbose
- To see possible normalization options for a text input without an audio file (could be used for debugging), run:
- python normalize_with_audio.py --text "RAW TEXT"
- Specify `--cache_dir` to generate .far grammars once and re-use them for faster inference
- """
class NormalizerWithAudio(Normalizer):
    """
    Normalizer that converts text from written to spoken form and keeps multiple
    normalization options (non-deterministic grammars), so the best one can later
    be selected by comparing against an ASR transcript. Useful for TTS preprocessing.

    Args:
        input_case: expected input capitalization
        lang: language
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        lm: set to True to score options with lattice weights (English only)
        post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
            Note: punct_post_process flag in normalize() supports all languages.
    """

    def __init__(
        self,
        input_case: str,
        lang: str = 'en',
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
        lm: bool = False,
        post_process: bool = True,
    ):
        super().__init__(
            input_case=input_case,
            lang=lang,
            deterministic=False,  # non-deterministic grammars: keep all normalization options
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            whitelist=whitelist,
            lm=lm,
            post_process=post_process,
        )
        self.lm = lm

    def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False):
        """
        Main function. Normalizes tokens from written to spoken form
        e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            set of normalized text options (usually there are multiple ways of normalizing
            a given semiotic class); in LM mode, a tuple (list of options, weights).

        Raises:
            ValueError: if the input exceeds 500 words, if LM mode is requested for an
                unsupported language, or if no normalization option could be produced.
        """
        if len(text.split()) > 500:
            raise ValueError(
                "Your input is too long. Please split up the input into sentences, "
                "or strings with fewer than 500 words"
            )

        original_text = text
        text = pre_process(text)  # to handle []
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        # escape so the raw string is a valid pynini input
        text = pynini.escape(text)
        # NOTE: a stray unconditional debug print of the escaped text was removed here

        if self.lm:
            if self.lang not in ["en"]:
                raise ValueError(f"{self.lang} is not supported in LM mode")
            if self.lang == "en":
                # this to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
                else:
                    try:
                        lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
                lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
                # keep (tagged_text, weight) pairs sorted by weight
                tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()]
                tagged_texts.sort(key=lambda x: x[1])
                tagged_texts, weights = list(zip(*tagged_texts))
        else:
            tagged_texts = self._get_tagged_text(text, n_tagged)

        # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
        if self.lang == "en":
            normalized_texts = [self.post_process(t) for t in tagged_texts]
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if len(normalized_texts) == 0:
            # original raised a message-less ValueError; include context for debugging
            raise ValueError(f"Normalization failed: no options produced for input: {original_text}")

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
            normalized_texts = [
                post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
            ]

        if self.lm:
            # deduplicate while keeping the lattice weights, re-sorted by weight
            remove_dup = sorted(list(set(zip(normalized_texts, weights))), key=lambda x: x[1])
            normalized_texts, weights = zip(*remove_dup)
            return list(normalized_texts), weights

        return set(normalized_texts)

    def _get_tagged_text(self, text, n_tagged):
        """
        Returns text options after tokenize-and-classify.

        Args:
            text: input text (already pynini-escaped)
            n_tagged: number of tagged options to consider, -1 - return all possible tagged options
        """
        if n_tagged == -1:
            if self.lang == "en":
                # this to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst)
                else:
                    try:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst)
            else:
                tagged_texts = rewrite.rewrites(text, self.tagger.fst)
        else:
            if self.lang == "en":
                # this to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
                else:
                    try:
                        # try self.tagger graph that produces output without digits
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
            else:
                tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
        return tagged_texts

    def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False):
        """
        Verbalizes tagged text, appending every verbalization option to normalized_texts.

        Args:
            tagged_text: text with tags
            normalized_texts: list of possible normalization options (mutated in place)
            verbose: if true prints intermediate classification results
        """

        def get_verbalized_text(tagged_text):
            return rewrite.rewrites(tagged_text, self.verbalizer.fst)

        self.parser(tagged_text)
        tokens = self.parser.parse()
        tags_reordered = self.generate_permutations(tokens)
        for tagged_text_reordered in tags_reordered:
            try:
                tagged_text_reordered = pynini.escape(tagged_text_reordered)
                normalized_texts.extend(get_verbalized_text(tagged_text_reordered))
                if verbose:
                    print(tagged_text_reordered)
            except pynini.lib.rewrite.Error:
                # some permutations are not verbalizable; skip them
                continue

    def select_best_match(
        self,
        normalized_texts: List[str],
        input_text: str,
        pred_text: str,
        verbose: bool = False,
        remove_punct: bool = False,
        cer_threshold: int = 100,
    ):
        """
        Selects the best normalization option based on the lowest CER.

        Args:
            normalized_texts: normalized text options
            input_text: input text
            pred_text: ASR model transcript of the audio file corresponding to the normalized text
            verbose: whether to print intermediate meta information
            remove_punct: whether to remove punctuation before calculating CER
            cer_threshold: if CER for pred_text is above the cer_threshold, no normalization will be performed

        Returns:
            normalized text with the lowest CER and CER value
        """
        if pred_text == "":
            # no transcript to compare against -> keep the input unchanged
            return input_text, cer_threshold

        normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct)
        normalized_texts_cer = sorted(normalized_texts_cer, key=lambda x: x[1])
        normalized_text, cer = normalized_texts_cer[0]
        if cer > cer_threshold:
            return input_text, cer

        if verbose:
            print('-' * 30)
            for option in normalized_texts:
                print(option)
            print('-' * 30)
        return normalized_text, cer
# translation table used to strip punctuation before CER computation
# (the original string listed '-' twice; deduplicated here)
_PUNCT_TABLE = str.maketrans("", "", "!?:;,.-()*+/<=>@^_")


def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER) of every option against the ASR transcript.

    Args:
        normalized_texts: normalized text options
        pred_text: ASR model output
        remove_punct: whether to strip punctuation from the option before scoring

    Returns:
        normalized options with corresponding CER, as (original option, CER percent) pairs
    """
    normalized_options = []
    for text in normalized_texts:
        # hyphens become spaces (they are spoken as word boundaries), rest lowercased
        text_clean = text.replace('-', ' ').lower()
        if remove_punct:
            # one C-level pass instead of chained .replace() calls
            text_clean = text_clean.translate(_PUNCT_TABLE)
        cer = round(word_error_rate([pred_text], [text_clean], use_cer=True) * 100, 2)
        normalized_options.append((text, cer))
    return normalized_options
def get_asr_model(asr_model):
    """
    Returns an ASR model loaded from a checkpoint path or a pretrained model name.

    Args:
        asr_model: path to a NeMo checkpoint file or a pretrained NeMo ASR model name

    Raises:
        ValueError: if the argument is neither an existing path nor an available model name
    """
    # BUG FIX: the original body referenced the global `args.model` instead of the
    # `asr_model` parameter, making the function depend on script-level state.
    if os.path.exists(asr_model):
        asr_model = ASRModel.restore_from(asr_model)
    elif asr_model in ASRModel.get_available_model_names():
        asr_model = ASRModel.from_pretrained(asr_model)
    else:
        raise ValueError(
            f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
        )
    return asr_model
def parse_args():
    """Builds and parses the command-line arguments for the script."""
    p = ArgumentParser()
    p.add_argument("--text", help="input string or path to a .txt file", default=None, type=str)
    p.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
    )
    p.add_argument(
        "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str
    )
    p.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest")
    p.add_argument(
        '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
    )
    p.add_argument(
        "--n_tagged",
        type=int,
        default=30,
        help="number of tagged options to consider, -1 - return all possible tagged options",
    )
    p.add_argument("--verbose", help="print info for debugging", action="store_true")
    p.add_argument(
        "--no_remove_punct_for_cer",
        help="Set to True to NOT remove punctuation before calculating CER",
        action="store_true",
    )
    p.add_argument(
        "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
    )
    p.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
    p.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
    p.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    p.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
    p.add_argument(
        "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now."
    )
    p.add_argument(
        "--cer_threshold",
        default=100,
        type=int,
        help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
    )
    p.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process")
    return p.parse_args()
- def _normalize_line(
- normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
- ):
- line = json.loads(line)
- pred_text = line["pred_text"]
- normalized_texts = normalizer.normalize(
- text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
- )
- normalized_texts = set(normalized_texts)
- normalized_text, cer = normalizer.select_best_match(
- normalized_texts=normalized_texts,
- input_text=line["text"],
- pred_text=pred_text,
- verbose=verbose,
- remove_punct=remove_punct,
- cer_threshold=cer_threshold,
- )
- line["nemo_normalized"] = normalized_text
- line["CER_nemo_normalized"] = cer
- return line
def normalize_manifest(
    normalizer,
    audio_data: str,
    n_jobs: int,
    n_tagged: int,
    remove_punct: bool,
    punct_post_process: bool,
    batch_size: int,
    cer_threshold: int,
):
    """
    Normalizes every line of a .json manifest and writes a *_normalized.json copy.

    Intermediate per-batch results are saved under a *_parts/ directory next to
    the output manifest and aggregated at the end.

    Args:
        normalizer: NormalizerWithAudio instance
        audio_data: path to .json manifest file
        n_jobs: the maximum number of concurrently running jobs
        n_tagged: number of tagged options to consider
        remove_punct: whether to remove punctuation before calculating CER
        punct_post_process: whether to normalize punctuation
        batch_size: number of examples for each process
        cer_threshold: if CER for pred_text is above this value, no normalization is performed
    """

    def __process_batch(batch_idx: int, batch: List[str], dir_name: str):
        """
        Normalizes a batch of manifest lines and saves the result to dir_name/<batch_idx>.json.

        Args:
            batch: list of JSON lines
            batch_idx: batch index
            dir_name: path to output directory to save results
        """
        normalized_lines = [
            _normalize_line(
                normalizer,
                n_tagged,
                verbose=False,
                line=line,
                remove_punct=remove_punct,
                punct_post_process=punct_post_process,
                cer_threshold=cer_threshold,
            )
            for line in tqdm(batch)
        ]
        with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out:
            for line in normalized_lines:
                f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
        print(f"Batch -- {batch_idx} -- is complete")

    # BUG FIX: str.replace('.json', ...) would also rewrite a '.json' occurring
    # earlier in the path; only the trailing extension is replaced here.
    suffix = '.json'
    if audio_data.endswith(suffix):
        manifest_out = audio_data[: -len(suffix)] + '_normalized.json'
    else:
        manifest_out = audio_data + '_normalized.json'

    with open(audio_data, 'r') as f:
        manifest_lines = f.readlines()
    print(f'Normalizing {len(manifest_lines)} lines of {audio_data}...')

    # to save intermediate results to a file; guard against an empty manifest
    # (min(0, batch_size) would make range() step 0 and raise)
    batch = max(1, min(len(manifest_lines), batch_size))
    tmp_dir = manifest_out[: -len(suffix)] + '_parts'
    os.makedirs(tmp_dir, exist_ok=True)

    Parallel(n_jobs=n_jobs)(
        delayed(__process_batch)(idx, manifest_lines[i : i + batch], tmp_dir)
        for idx, i in enumerate(range(0, len(manifest_lines), batch))
    )

    # aggregate all intermediate files
    with open(manifest_out, "w") as f_out:
        for batch_f in sorted(glob(f"{tmp_dir}/*.json")):
            with open(batch_f, "r") as f_in:
                f_out.write(f_in.read())
    print(f'Normalized version saved at {manifest_out}')
if __name__ == "__main__":
    args = parse_args()
    if not ASR_AVAILABLE and args.audio_data:
        raise ValueError("NeMo ASR collection is not installed.")
    start = time.time()
    args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None

    # shared usage message (BUG FIX: original concatenation lacked a space after "OR")
    usage_error = (
        "Provide either path to .json manifest in '--audio_data' OR "
        "'--audio_data' path to audio file and '--text' path to a text file OR "
        "'--text' string text (for debugging without audio)"
    )

    if args.text is not None:
        # single-string mode, optionally scored against one audio file
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
            lm=args.lm,
        )
        if os.path.exists(args.text):
            # --text may be a path to a .txt file instead of a raw string
            with open(args.text, 'r') as f:
                args.text = f.read().strip()
        normalized_texts = normalizer.normalize(
            text=args.text,
            verbose=args.verbose,
            n_tagged=args.n_tagged,
            punct_post_process=not args.no_punct_post_process,
        )
        if not normalizer.lm:
            normalized_texts = set(normalized_texts)
        if args.audio_data:
            asr_model = get_asr_model(args.model)
            pred_text = asr_model.transcribe([args.audio_data])[0]
            # NOTE(review): in --lm mode normalize() returns (texts, weights); combining
            # --lm with --audio_data passes that tuple here — TODO confirm intended use
            normalized_text, cer = normalizer.select_best_match(
                normalized_texts=normalized_texts,
                pred_text=pred_text,
                input_text=args.text,
                verbose=args.verbose,
                remove_punct=not args.no_remove_punct_for_cer,
                cer_threshold=args.cer_threshold,
            )
            print(f"Transcript: {pred_text}")
            print(f"Normalized: {normalized_text}")
        else:
            print("Normalization options:")
            for norm_text in normalized_texts:
                print(norm_text)
    elif args.audio_data is None:
        # BUG FIX: os.path.exists(None) raised TypeError when both --text and
        # --audio_data were omitted; emit the usage error instead.
        raise ValueError(usage_error)
    elif not os.path.exists(args.audio_data):
        raise ValueError(f"{args.audio_data} not found.")
    elif args.audio_data.endswith('.json'):
        # manifest mode: batch-normalize the whole .json manifest
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
        )
        normalize_manifest(
            normalizer=normalizer,
            audio_data=args.audio_data,
            n_jobs=args.n_jobs,
            n_tagged=args.n_tagged,
            remove_punct=not args.no_remove_punct_for_cer,
            punct_post_process=not args.no_punct_post_process,
            batch_size=args.batch_size,
            cer_threshold=args.cer_threshold,
        )
    else:
        raise ValueError(usage_error)

    print(f'Execution time: {round((time.time() - start)/60, 2)} min.')
|