Kaynağa Gözat

support paraformer-large-contextual with vad and punc model

lzr265946 3 yıl önce
ebeveyn
işleme
267e2d09e6

+ 6 - 0
funasr/bin/asr_inference_paraformer_vad.py

@@ -167,6 +167,11 @@ def inference_modelscope(
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
+
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
     
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
@@ -206,6 +211,7 @@ def inference_modelscope(
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
     )
     speech2text = Speech2Text(**speech2text_kwargs)
     text2punc = None

+ 78 - 7
funasr/bin/asr_inference_paraformer_vad_punc.py

@@ -5,6 +5,10 @@ import argparse
 import logging
 import sys
 import time
+import os
+import codecs
+import tempfile
+import requests
 from pathlib import Path
 from typing import Optional
 from typing import Sequence
@@ -41,7 +45,7 @@ from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
 from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
 from funasr.bin.punctuation_infer import Text2Punc
-from funasr.models.e2e_asr_paraformer import BiCifParaformer
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
 
 header_colors = '\033[95m'
 end_colors = '\033[0m'
@@ -79,6 +83,7 @@ class Speech2Text:
             penalty: float = 0.0,
             nbest: int = 1,
             frontend_conf: dict = None,
+            hotword_list_or_file: str = None,
             **kwargs,
     ):
         assert check_argument_types()
@@ -169,6 +174,58 @@ class Speech2Text:
         self.asr_train_args = asr_train_args
         self.converter = converter
         self.tokenizer = tokenizer
+
+        # 6. [Optional] Build hotword list from str, local file or url
+        # for None
+        if hotword_list_or_file is None:
+            self.hotword_list = None
+        # for text str input
+        elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords as str...")
+            self.hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            self.hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file):
+            logging.info("Attempting to parse hotwords from local txt...")
+            self.hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                self.hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        else:
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            self.hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                self.hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                .format(hotword_list_or_file, hotword_str_list))
+
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
             beam_search = None
@@ -233,8 +290,15 @@ class Speech2Text:
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        if not isinstance(self.asr_model, ContextualParaformer):
+            if self.hotword_list:
+                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+        else:
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
             _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
@@ -282,10 +346,11 @@ class Speech2Text:
                 else:
                     text = None
 
-
-                timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
-                results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
-
+                if isinstance(self.asr_model, BiCifParaformer):
+                    timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
+                else:
+                    results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
 
         # assert check_return_type(results)
         return results
@@ -512,6 +577,11 @@ def inference_modelscope(
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
 
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
+
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
     else:
@@ -550,6 +620,7 @@ def inference_modelscope(
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
     )
     speech2text = Speech2Text(**speech2text_kwargs)
     text2punc = None