|
@@ -47,327 +47,323 @@ from funasr.bin.vad_inference import Speech2VadSegment
|
|
|
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
|
|
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
|
|
|
from funasr.bin.punctuation_infer import Text2Punc
|
|
from funasr.bin.punctuation_infer import Text2Punc
|
|
|
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
|
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-header_colors = '\033[95m'
|
|
|
|
|
-end_colors = '\033[0m'
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-class Speech2Text:
|
|
|
|
|
- """Speech2Text class
|
|
|
|
|
-
|
|
|
|
|
- Examples:
|
|
|
|
|
- >>> import soundfile
|
|
|
|
|
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
|
|
|
|
- >>> audio, rate = soundfile.read("speech.wav")
|
|
|
|
|
- >>> speech2text(audio)
|
|
|
|
|
- [(text, token, token_int, hypothesis object), ...]
|
|
|
|
|
-
|
|
|
|
|
- """
|
|
|
|
|
-
|
|
|
|
|
- def __init__(
|
|
|
|
|
- self,
|
|
|
|
|
- asr_train_config: Union[Path, str] = None,
|
|
|
|
|
- asr_model_file: Union[Path, str] = None,
|
|
|
|
|
- cmvn_file: Union[Path, str] = None,
|
|
|
|
|
- lm_train_config: Union[Path, str] = None,
|
|
|
|
|
- lm_file: Union[Path, str] = None,
|
|
|
|
|
- token_type: str = None,
|
|
|
|
|
- bpemodel: str = None,
|
|
|
|
|
- device: str = "cpu",
|
|
|
|
|
- maxlenratio: float = 0.0,
|
|
|
|
|
- minlenratio: float = 0.0,
|
|
|
|
|
- dtype: str = "float32",
|
|
|
|
|
- beam_size: int = 20,
|
|
|
|
|
- ctc_weight: float = 0.5,
|
|
|
|
|
- lm_weight: float = 1.0,
|
|
|
|
|
- ngram_weight: float = 0.9,
|
|
|
|
|
- penalty: float = 0.0,
|
|
|
|
|
- nbest: int = 1,
|
|
|
|
|
- frontend_conf: dict = None,
|
|
|
|
|
- hotword_list_or_file: str = None,
|
|
|
|
|
- **kwargs,
|
|
|
|
|
- ):
|
|
|
|
|
- assert check_argument_types()
|
|
|
|
|
-
|
|
|
|
|
- # 1. Build ASR model
|
|
|
|
|
- scorers = {}
|
|
|
|
|
- asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
- asr_train_config, asr_model_file, cmvn_file=cmvn_file, device=device
|
|
|
|
|
- )
|
|
|
|
|
- frontend = None
|
|
|
|
|
- if asr_model.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
|
|
|
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
|
|
|
|
-
|
|
|
|
|
- # logging.info("asr_model: {}".format(asr_model))
|
|
|
|
|
- # logging.info("asr_train_args: {}".format(asr_train_args))
|
|
|
|
|
- asr_model.to(dtype=getattr(torch, dtype)).eval()
|
|
|
|
|
-
|
|
|
|
|
- if asr_model.ctc != None:
|
|
|
|
|
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
|
|
|
|
- scorers.update(
|
|
|
|
|
- ctc=ctc
|
|
|
|
|
- )
|
|
|
|
|
- token_list = asr_model.token_list
|
|
|
|
|
- scorers.update(
|
|
|
|
|
- length_bonus=LengthBonus(len(token_list)),
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- # 2. Build Language model
|
|
|
|
|
- if lm_train_config is not None:
|
|
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
- lm_train_config, lm_file, device
|
|
|
|
|
- )
|
|
|
|
|
- scorers["lm"] = lm.lm
|
|
|
|
|
-
|
|
|
|
|
- # 3. Build ngram model
|
|
|
|
|
- # ngram is not supported now
|
|
|
|
|
- ngram = None
|
|
|
|
|
- scorers["ngram"] = ngram
|
|
|
|
|
-
|
|
|
|
|
- # 4. Build BeamSearch object
|
|
|
|
|
- # transducer is not supported now
|
|
|
|
|
- beam_search_transducer = None
|
|
|
|
|
-
|
|
|
|
|
- weights = dict(
|
|
|
|
|
- decoder=1.0 - ctc_weight,
|
|
|
|
|
- ctc=ctc_weight,
|
|
|
|
|
- lm=lm_weight,
|
|
|
|
|
- ngram=ngram_weight,
|
|
|
|
|
- length_bonus=penalty,
|
|
|
|
|
- )
|
|
|
|
|
- beam_search = BeamSearch(
|
|
|
|
|
- beam_size=beam_size,
|
|
|
|
|
- weights=weights,
|
|
|
|
|
- scorers=scorers,
|
|
|
|
|
- sos=asr_model.sos,
|
|
|
|
|
- eos=asr_model.eos,
|
|
|
|
|
- vocab_size=len(token_list),
|
|
|
|
|
- token_list=token_list,
|
|
|
|
|
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
|
|
|
|
|
- for scorer in scorers.values():
|
|
|
|
|
- if isinstance(scorer, torch.nn.Module):
|
|
|
|
|
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
|
|
|
|
|
-
|
|
|
|
|
- logging.info(f"Decoding device={device}, dtype={dtype}")
|
|
|
|
|
-
|
|
|
|
|
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
|
|
|
|
- if token_type is None:
|
|
|
|
|
- token_type = asr_train_args.token_type
|
|
|
|
|
- if bpemodel is None:
|
|
|
|
|
- bpemodel = asr_train_args.bpemodel
|
|
|
|
|
-
|
|
|
|
|
- if token_type is None:
|
|
|
|
|
- tokenizer = None
|
|
|
|
|
- elif token_type == "bpe":
|
|
|
|
|
- if bpemodel is not None:
|
|
|
|
|
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
|
|
|
|
|
- else:
|
|
|
|
|
- tokenizer = None
|
|
|
|
|
- else:
|
|
|
|
|
- tokenizer = build_tokenizer(token_type=token_type)
|
|
|
|
|
- converter = TokenIDConverter(token_list=token_list)
|
|
|
|
|
- logging.info(f"Text tokenizer: {tokenizer}")
|
|
|
|
|
-
|
|
|
|
|
- self.asr_model = asr_model
|
|
|
|
|
- self.asr_train_args = asr_train_args
|
|
|
|
|
- self.converter = converter
|
|
|
|
|
- self.tokenizer = tokenizer
|
|
|
|
|
-
|
|
|
|
|
- # 6. [Optional] Build hotword list from str, local file or url
|
|
|
|
|
- self.hotword_list = None
|
|
|
|
|
- self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
|
|
|
|
|
-
|
|
|
|
|
- is_use_lm = lm_weight != 0.0 and lm_file is not None
|
|
|
|
|
- if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
|
|
|
|
|
- beam_search = None
|
|
|
|
|
- self.beam_search = beam_search
|
|
|
|
|
- logging.info(f"Beam_search: {self.beam_search}")
|
|
|
|
|
- self.beam_search_transducer = beam_search_transducer
|
|
|
|
|
- self.maxlenratio = maxlenratio
|
|
|
|
|
- self.minlenratio = minlenratio
|
|
|
|
|
- self.device = device
|
|
|
|
|
- self.dtype = dtype
|
|
|
|
|
- self.nbest = nbest
|
|
|
|
|
- self.frontend = frontend
|
|
|
|
|
- self.encoder_downsampling_factor = 1
|
|
|
|
|
- if asr_train_args.encoder_conf["input_layer"] == "conv2d":
|
|
|
|
|
- self.encoder_downsampling_factor = 4
|
|
|
|
|
-
|
|
|
|
|
- @torch.no_grad()
|
|
|
|
|
- def __call__(
|
|
|
|
|
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
|
|
|
|
|
- begin_time: int = 0, end_time: int = None,
|
|
|
|
|
- ):
|
|
|
|
|
- """Inference
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- speech: Input speech data
|
|
|
|
|
- Returns:
|
|
|
|
|
- text, token, token_int, hyp
|
|
|
|
|
-
|
|
|
|
|
- """
|
|
|
|
|
- assert check_argument_types()
|
|
|
|
|
-
|
|
|
|
|
- # Input as audio signal
|
|
|
|
|
- if isinstance(speech, np.ndarray):
|
|
|
|
|
- speech = torch.tensor(speech)
|
|
|
|
|
-
|
|
|
|
|
- if self.frontend is not None:
|
|
|
|
|
- # feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
|
|
|
|
- # fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
|
|
|
|
|
- feats, feats_len = self.frontend.forward_lfr_cmvn(speech, speech_lengths)
|
|
|
|
|
- feats = to_device(feats, device=self.device)
|
|
|
|
|
- feats_len = feats_len.int()
|
|
|
|
|
- self.asr_model.frontend = None
|
|
|
|
|
- else:
|
|
|
|
|
- feats = speech
|
|
|
|
|
- feats_len = speech_lengths
|
|
|
|
|
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
|
|
|
|
- batch = {"speech": feats, "speech_lengths": feats_len}
|
|
|
|
|
-
|
|
|
|
|
- # a. To device
|
|
|
|
|
- batch = to_device(batch, device=self.device)
|
|
|
|
|
-
|
|
|
|
|
- # b. Forward Encoder
|
|
|
|
|
- enc, enc_len = self.asr_model.encode(**batch)
|
|
|
|
|
- if isinstance(enc, tuple):
|
|
|
|
|
- enc = enc[0]
|
|
|
|
|
- # assert len(enc) == 1, len(enc)
|
|
|
|
|
- enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
|
|
|
|
|
-
|
|
|
|
|
- predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
|
|
|
|
|
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
|
|
|
|
|
- predictor_outs[2], predictor_outs[3]
|
|
|
|
|
- pre_token_length = pre_token_length.round().long()
|
|
|
|
|
- if torch.max(pre_token_length) < 1:
|
|
|
|
|
- return []
|
|
|
|
|
-
|
|
|
|
|
- if not isinstance(self.asr_model, ContextualParaformer):
|
|
|
|
|
- if self.hotword_list:
|
|
|
|
|
- logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
|
|
|
|
|
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
|
|
|
|
|
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
|
|
|
|
- else:
|
|
|
|
|
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
|
|
|
|
|
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
|
|
|
|
-
|
|
|
|
|
- if isinstance(self.asr_model, BiCifParaformer):
|
|
|
|
|
- _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
|
|
|
|
|
- pre_token_length) # test no bias cif2
|
|
|
|
|
-
|
|
|
|
|
- results = []
|
|
|
|
|
- b, n, d = decoder_out.size()
|
|
|
|
|
- for i in range(b):
|
|
|
|
|
- x = enc[i, :enc_len[i], :]
|
|
|
|
|
- am_scores = decoder_out[i, :pre_token_length[i], :]
|
|
|
|
|
- if self.beam_search is not None:
|
|
|
|
|
- nbest_hyps = self.beam_search(
|
|
|
|
|
- x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- nbest_hyps = nbest_hyps[: self.nbest]
|
|
|
|
|
- else:
|
|
|
|
|
- yseq = am_scores.argmax(dim=-1)
|
|
|
|
|
- score = am_scores.max(dim=-1)[0]
|
|
|
|
|
- score = torch.sum(score, dim=-1)
|
|
|
|
|
- # pad with mask tokens to ensure compatibility with sos/eos tokens
|
|
|
|
|
- yseq = torch.tensor(
|
|
|
|
|
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
|
|
|
|
|
- )
|
|
|
|
|
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
|
|
|
|
|
-
|
|
|
|
|
- for hyp in nbest_hyps:
|
|
|
|
|
- assert isinstance(hyp, (Hypothesis)), type(hyp)
|
|
|
|
|
-
|
|
|
|
|
- # remove sos/eos and get results
|
|
|
|
|
- last_pos = -1
|
|
|
|
|
- if isinstance(hyp.yseq, list):
|
|
|
|
|
- token_int = hyp.yseq[1:last_pos]
|
|
|
|
|
- else:
|
|
|
|
|
- token_int = hyp.yseq[1:last_pos].tolist()
|
|
|
|
|
-
|
|
|
|
|
- # remove blank symbol id, which is assumed to be 0
|
|
|
|
|
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
|
|
|
|
|
- if len(token_int) == 0:
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
- # Change integer-ids to tokens
|
|
|
|
|
- token = self.converter.ids2tokens(token_int)
|
|
|
|
|
-
|
|
|
|
|
- if self.tokenizer is not None:
|
|
|
|
|
- text = self.tokenizer.tokens2text(token)
|
|
|
|
|
- else:
|
|
|
|
|
- text = None
|
|
|
|
|
-
|
|
|
|
|
- if isinstance(self.asr_model, BiCifParaformer):
|
|
|
|
|
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
|
|
|
|
|
- us_peaks[i],
|
|
|
|
|
- copy.copy(token),
|
|
|
|
|
- vad_offset=begin_time)
|
|
|
|
|
- results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
|
|
|
|
|
- else:
|
|
|
|
|
- results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
|
|
|
|
|
-
|
|
|
|
|
- # assert check_return_type(results)
|
|
|
|
|
- return results
|
|
|
|
|
-
|
|
|
|
|
- def generate_hotwords_list(self, hotword_list_or_file):
|
|
|
|
|
- # for None
|
|
|
|
|
- if hotword_list_or_file is None:
|
|
|
|
|
- hotword_list = None
|
|
|
|
|
- # for local txt inputs
|
|
|
|
|
- elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
|
|
|
|
|
- logging.info("Attempting to parse hotwords from local txt...")
|
|
|
|
|
- hotword_list = []
|
|
|
|
|
- hotword_str_list = []
|
|
|
|
|
- with codecs.open(hotword_list_or_file, 'r') as fin:
|
|
|
|
|
- for line in fin.readlines():
|
|
|
|
|
- hw = line.strip()
|
|
|
|
|
- hotword_str_list.append(hw)
|
|
|
|
|
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
|
|
- hotword_list.append([self.asr_model.sos])
|
|
|
|
|
- hotword_str_list.append('<s>')
|
|
|
|
|
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
|
|
|
|
- .format(hotword_list_or_file, hotword_str_list))
|
|
|
|
|
- # for url, download and generate txt
|
|
|
|
|
- elif hotword_list_or_file.startswith('http'):
|
|
|
|
|
- logging.info("Attempting to parse hotwords from url...")
|
|
|
|
|
- work_dir = tempfile.TemporaryDirectory().name
|
|
|
|
|
- if not os.path.exists(work_dir):
|
|
|
|
|
- os.makedirs(work_dir)
|
|
|
|
|
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
|
|
|
|
|
- local_file = requests.get(hotword_list_or_file)
|
|
|
|
|
- open(text_file_path, "wb").write(local_file.content)
|
|
|
|
|
- hotword_list_or_file = text_file_path
|
|
|
|
|
- hotword_list = []
|
|
|
|
|
- hotword_str_list = []
|
|
|
|
|
- with codecs.open(hotword_list_or_file, 'r') as fin:
|
|
|
|
|
- for line in fin.readlines():
|
|
|
|
|
- hw = line.strip()
|
|
|
|
|
- hotword_str_list.append(hw)
|
|
|
|
|
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
|
|
- hotword_list.append([self.asr_model.sos])
|
|
|
|
|
- hotword_str_list.append('<s>')
|
|
|
|
|
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
|
|
|
|
- .format(hotword_list_or_file, hotword_str_list))
|
|
|
|
|
- # for text str input
|
|
|
|
|
- elif not hotword_list_or_file.endswith('.txt'):
|
|
|
|
|
- logging.info("Attempting to parse hotwords as str...")
|
|
|
|
|
- hotword_list = []
|
|
|
|
|
- hotword_str_list = []
|
|
|
|
|
- for hw in hotword_list_or_file.strip().split():
|
|
|
|
|
- hotword_str_list.append(hw)
|
|
|
|
|
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
|
|
- hotword_list.append([self.asr_model.sos])
|
|
|
|
|
- hotword_str_list.append('<s>')
|
|
|
|
|
- logging.info("Hotword list: {}.".format(hotword_str_list))
|
|
|
|
|
- else:
|
|
|
|
|
- hotword_list = None
|
|
|
|
|
- return hotword_list
|
|
|
|
|
|
|
+from funasr.utils.vad_utils import slice_padding_fbank
|
|
|
|
|
+from funasr.bin.asr_inference_paraformer import Speech2Text
|
|
|
|
|
+# class Speech2Text:
|
|
|
|
|
+# """Speech2Text class
|
|
|
|
|
+#
|
|
|
|
|
+# Examples:
|
|
|
|
|
+# >>> import soundfile
|
|
|
|
|
+# >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
|
|
|
|
|
+# >>> audio, rate = soundfile.read("speech.wav")
|
|
|
|
|
+# >>> speech2text(audio)
|
|
|
|
|
+# [(text, token, token_int, hypothesis object), ...]
|
|
|
|
|
+#
|
|
|
|
|
+# """
|
|
|
|
|
+#
|
|
|
|
|
+# def __init__(
|
|
|
|
|
+# self,
|
|
|
|
|
+# asr_train_config: Union[Path, str] = None,
|
|
|
|
|
+# asr_model_file: Union[Path, str] = None,
|
|
|
|
|
+# cmvn_file: Union[Path, str] = None,
|
|
|
|
|
+# lm_train_config: Union[Path, str] = None,
|
|
|
|
|
+# lm_file: Union[Path, str] = None,
|
|
|
|
|
+# token_type: str = None,
|
|
|
|
|
+# bpemodel: str = None,
|
|
|
|
|
+# device: str = "cpu",
|
|
|
|
|
+# maxlenratio: float = 0.0,
|
|
|
|
|
+# minlenratio: float = 0.0,
|
|
|
|
|
+# dtype: str = "float32",
|
|
|
|
|
+# beam_size: int = 20,
|
|
|
|
|
+# ctc_weight: float = 0.5,
|
|
|
|
|
+# lm_weight: float = 1.0,
|
|
|
|
|
+# ngram_weight: float = 0.9,
|
|
|
|
|
+# penalty: float = 0.0,
|
|
|
|
|
+# nbest: int = 1,
|
|
|
|
|
+# frontend_conf: dict = None,
|
|
|
|
|
+# hotword_list_or_file: str = None,
|
|
|
|
|
+# **kwargs,
|
|
|
|
|
+# ):
|
|
|
|
|
+# assert check_argument_types()
|
|
|
|
|
+#
|
|
|
|
|
+# # 1. Build ASR model
|
|
|
|
|
+# scorers = {}
|
|
|
|
|
+# asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
+# asr_train_config, asr_model_file, cmvn_file=cmvn_file, device=device
|
|
|
|
|
+# )
|
|
|
|
|
+# frontend = None
|
|
|
|
|
+# if asr_model.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
|
|
|
+# frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
|
|
|
|
+#
|
|
|
|
|
+# # logging.info("asr_model: {}".format(asr_model))
|
|
|
|
|
+# # logging.info("asr_train_args: {}".format(asr_train_args))
|
|
|
|
|
+# asr_model.to(dtype=getattr(torch, dtype)).eval()
|
|
|
|
|
+#
|
|
|
|
|
+# if asr_model.ctc != None:
|
|
|
|
|
+# ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
|
|
|
|
+# scorers.update(
|
|
|
|
|
+# ctc=ctc
|
|
|
|
|
+# )
|
|
|
|
|
+# token_list = asr_model.token_list
|
|
|
|
|
+# scorers.update(
|
|
|
|
|
+# length_bonus=LengthBonus(len(token_list)),
|
|
|
|
|
+# )
|
|
|
|
|
+#
|
|
|
|
|
+# # 2. Build Language model
|
|
|
|
|
+# if lm_train_config is not None:
|
|
|
|
|
+# lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
+# lm_train_config, lm_file, device
|
|
|
|
|
+# )
|
|
|
|
|
+# scorers["lm"] = lm.lm
|
|
|
|
|
+#
|
|
|
|
|
+# # 3. Build ngram model
|
|
|
|
|
+# # ngram is not supported now
|
|
|
|
|
+# ngram = None
|
|
|
|
|
+# scorers["ngram"] = ngram
|
|
|
|
|
+#
|
|
|
|
|
+# # 4. Build BeamSearch object
|
|
|
|
|
+# # transducer is not supported now
|
|
|
|
|
+# beam_search_transducer = None
|
|
|
|
|
+#
|
|
|
|
|
+# weights = dict(
|
|
|
|
|
+# decoder=1.0 - ctc_weight,
|
|
|
|
|
+# ctc=ctc_weight,
|
|
|
|
|
+# lm=lm_weight,
|
|
|
|
|
+# ngram=ngram_weight,
|
|
|
|
|
+# length_bonus=penalty,
|
|
|
|
|
+# )
|
|
|
|
|
+# beam_search = BeamSearch(
|
|
|
|
|
+# beam_size=beam_size,
|
|
|
|
|
+# weights=weights,
|
|
|
|
|
+# scorers=scorers,
|
|
|
|
|
+# sos=asr_model.sos,
|
|
|
|
|
+# eos=asr_model.eos,
|
|
|
|
|
+# vocab_size=len(token_list),
|
|
|
|
|
+# token_list=token_list,
|
|
|
|
|
+# pre_beam_score_key=None if ctc_weight == 1.0 else "full",
|
|
|
|
|
+# )
|
|
|
|
|
+#
|
|
|
|
|
+# beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
|
|
|
|
|
+# for scorer in scorers.values():
|
|
|
|
|
+# if isinstance(scorer, torch.nn.Module):
|
|
|
|
|
+# scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
|
|
|
|
|
+#
|
|
|
|
|
+# logging.info(f"Decoding device={device}, dtype={dtype}")
|
|
|
|
|
+#
|
|
|
|
|
+# # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
|
|
|
|
+# if token_type is None:
|
|
|
|
|
+# token_type = asr_train_args.token_type
|
|
|
|
|
+# if bpemodel is None:
|
|
|
|
|
+# bpemodel = asr_train_args.bpemodel
|
|
|
|
|
+#
|
|
|
|
|
+# if token_type is None:
|
|
|
|
|
+# tokenizer = None
|
|
|
|
|
+# elif token_type == "bpe":
|
|
|
|
|
+# if bpemodel is not None:
|
|
|
|
|
+# tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
|
|
|
|
|
+# else:
|
|
|
|
|
+# tokenizer = None
|
|
|
|
|
+# else:
|
|
|
|
|
+# tokenizer = build_tokenizer(token_type=token_type)
|
|
|
|
|
+# converter = TokenIDConverter(token_list=token_list)
|
|
|
|
|
+# logging.info(f"Text tokenizer: {tokenizer}")
|
|
|
|
|
+#
|
|
|
|
|
+# self.asr_model = asr_model
|
|
|
|
|
+# self.asr_train_args = asr_train_args
|
|
|
|
|
+# self.converter = converter
|
|
|
|
|
+# self.tokenizer = tokenizer
|
|
|
|
|
+#
|
|
|
|
|
+# # 6. [Optional] Build hotword list from str, local file or url
|
|
|
|
|
+# self.hotword_list = None
|
|
|
|
|
+# self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
|
|
|
|
|
+#
|
|
|
|
|
+# is_use_lm = lm_weight != 0.0 and lm_file is not None
|
|
|
|
|
+# if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
|
|
|
|
|
+# beam_search = None
|
|
|
|
|
+# self.beam_search = beam_search
|
|
|
|
|
+# logging.info(f"Beam_search: {self.beam_search}")
|
|
|
|
|
+# self.beam_search_transducer = beam_search_transducer
|
|
|
|
|
+# self.maxlenratio = maxlenratio
|
|
|
|
|
+# self.minlenratio = minlenratio
|
|
|
|
|
+# self.device = device
|
|
|
|
|
+# self.dtype = dtype
|
|
|
|
|
+# self.nbest = nbest
|
|
|
|
|
+# self.frontend = frontend
|
|
|
|
|
+# self.encoder_downsampling_factor = 1
|
|
|
|
|
+# if asr_train_args.encoder_conf["input_layer"] == "conv2d":
|
|
|
|
|
+# self.encoder_downsampling_factor = 4
|
|
|
|
|
+#
|
|
|
|
|
+# @torch.no_grad()
|
|
|
|
|
+# def __call__(
|
|
|
|
|
+# self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
|
|
|
|
|
+# begin_time: int = 0, end_time: int = None,
|
|
|
|
|
+# ):
|
|
|
|
|
+# """Inference
|
|
|
|
|
+#
|
|
|
|
|
+# Args:
|
|
|
|
|
+# speech: Input speech data
|
|
|
|
|
+# Returns:
|
|
|
|
|
+# text, token, token_int, hyp
|
|
|
|
|
+#
|
|
|
|
|
+# """
|
|
|
|
|
+# assert check_argument_types()
|
|
|
|
|
+#
|
|
|
|
|
+# # Input as audio signal
|
|
|
|
|
+# if isinstance(speech, np.ndarray):
|
|
|
|
|
+# speech = torch.tensor(speech)
|
|
|
|
|
+#
|
|
|
|
|
+# if self.frontend is not None:
|
|
|
|
|
+# feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
|
|
|
|
+# # fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
|
|
|
|
|
+# # feats, feats_len = self.frontend.forward_lfr_cmvn(speech, speech_lengths)
|
|
|
|
|
+# feats = to_device(feats, device=self.device)
|
|
|
|
|
+# feats_len = feats_len.int()
|
|
|
|
|
+# self.asr_model.frontend = None
|
|
|
|
|
+# else:
|
|
|
|
|
+# feats = speech
|
|
|
|
|
+# feats_len = speech_lengths
|
|
|
|
|
+# lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
|
|
|
|
+# batch = {"speech": feats, "speech_lengths": feats_len}
|
|
|
|
|
+#
|
|
|
|
|
+# # a. To device
|
|
|
|
|
+# batch = to_device(batch, device=self.device)
|
|
|
|
|
+#
|
|
|
|
|
+# # b. Forward Encoder
|
|
|
|
|
+# enc, enc_len = self.asr_model.encode(**batch)
|
|
|
|
|
+# if isinstance(enc, tuple):
|
|
|
|
|
+# enc = enc[0]
|
|
|
|
|
+# # assert len(enc) == 1, len(enc)
|
|
|
|
|
+# enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
|
|
|
|
|
+#
|
|
|
|
|
+# predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
|
|
|
|
|
+# pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
|
|
|
|
|
+# predictor_outs[2], predictor_outs[3]
|
|
|
|
|
+# pre_token_length = pre_token_length.round().long()
|
|
|
|
|
+# if torch.max(pre_token_length) < 1:
|
|
|
|
|
+# return []
|
|
|
|
|
+#
|
|
|
|
|
+# if not isinstance(self.asr_model, ContextualParaformer):
|
|
|
|
|
+# if self.hotword_list:
|
|
|
|
|
+# logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
|
|
|
|
|
+# decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
|
|
|
|
|
+# decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
|
|
|
|
+# else:
|
|
|
|
|
+# decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
|
|
|
|
|
+# decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
|
|
|
|
+#
|
|
|
|
|
+# if isinstance(self.asr_model, BiCifParaformer):
|
|
|
|
|
+# _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
|
|
|
|
|
+# pre_token_length) # test no bias cif2
|
|
|
|
|
+#
|
|
|
|
|
+# results = []
|
|
|
|
|
+# b, n, d = decoder_out.size()
|
|
|
|
|
+# for i in range(b):
|
|
|
|
|
+# x = enc[i, :enc_len[i], :]
|
|
|
|
|
+# am_scores = decoder_out[i, :pre_token_length[i], :]
|
|
|
|
|
+# if self.beam_search is not None:
|
|
|
|
|
+# nbest_hyps = self.beam_search(
|
|
|
|
|
+# x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
|
|
|
|
|
+# )
|
|
|
|
|
+#
|
|
|
|
|
+# nbest_hyps = nbest_hyps[: self.nbest]
|
|
|
|
|
+# else:
|
|
|
|
|
+# yseq = am_scores.argmax(dim=-1)
|
|
|
|
|
+# score = am_scores.max(dim=-1)[0]
|
|
|
|
|
+# score = torch.sum(score, dim=-1)
|
|
|
|
|
+# # pad with mask tokens to ensure compatibility with sos/eos tokens
|
|
|
|
|
+# yseq = torch.tensor(
|
|
|
|
|
+# [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
|
|
|
|
|
+# )
|
|
|
|
|
+# nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
|
|
|
|
|
+#
|
|
|
|
|
+# for hyp in nbest_hyps:
|
|
|
|
|
+# assert isinstance(hyp, (Hypothesis)), type(hyp)
|
|
|
|
|
+#
|
|
|
|
|
+# # remove sos/eos and get results
|
|
|
|
|
+# last_pos = -1
|
|
|
|
|
+# if isinstance(hyp.yseq, list):
|
|
|
|
|
+# token_int = hyp.yseq[1:last_pos]
|
|
|
|
|
+# else:
|
|
|
|
|
+# token_int = hyp.yseq[1:last_pos].tolist()
|
|
|
|
|
+#
|
|
|
|
|
+# # remove blank symbol id, which is assumed to be 0
|
|
|
|
|
+# token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
|
|
|
|
|
+# if len(token_int) == 0:
|
|
|
|
|
+# continue
|
|
|
|
|
+#
|
|
|
|
|
+# # Change integer-ids to tokens
|
|
|
|
|
+# token = self.converter.ids2tokens(token_int)
|
|
|
|
|
+#
|
|
|
|
|
+# if self.tokenizer is not None:
|
|
|
|
|
+# text = self.tokenizer.tokens2text(token)
|
|
|
|
|
+# else:
|
|
|
|
|
+# text = None
|
|
|
|
|
+#
|
|
|
|
|
+# if isinstance(self.asr_model, BiCifParaformer):
|
|
|
|
|
+# _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
|
|
|
|
|
+# us_peaks[i],
|
|
|
|
|
+# copy.copy(token),
|
|
|
|
|
+# vad_offset=begin_time)
|
|
|
|
|
+# results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
|
|
|
|
|
+# else:
|
|
|
|
|
+# results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
|
|
|
|
|
+#
|
|
|
|
|
+# # assert check_return_type(results)
|
|
|
|
|
+# return results
|
|
|
|
|
+#
|
|
|
|
|
+# def generate_hotwords_list(self, hotword_list_or_file):
|
|
|
|
|
+# # for None
|
|
|
|
|
+# if hotword_list_or_file is None:
|
|
|
|
|
+# hotword_list = None
|
|
|
|
|
+# # for local txt inputs
|
|
|
|
|
+# elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
|
|
|
|
|
+# logging.info("Attempting to parse hotwords from local txt...")
|
|
|
|
|
+# hotword_list = []
|
|
|
|
|
+# hotword_str_list = []
|
|
|
|
|
+# with codecs.open(hotword_list_or_file, 'r') as fin:
|
|
|
|
|
+# for line in fin.readlines():
|
|
|
|
|
+# hw = line.strip()
|
|
|
|
|
+# hotword_str_list.append(hw)
|
|
|
|
|
+# hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
|
|
+# hotword_list.append([self.asr_model.sos])
|
|
|
|
|
+# hotword_str_list.append('<s>')
|
|
|
|
|
+# logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
|
|
|
|
+# .format(hotword_list_or_file, hotword_str_list))
|
|
|
|
|
+# # for url, download and generate txt
|
|
|
|
|
+# elif hotword_list_or_file.startswith('http'):
|
|
|
|
|
+# logging.info("Attempting to parse hotwords from url...")
|
|
|
|
|
+# work_dir = tempfile.TemporaryDirectory().name
|
|
|
|
|
+# if not os.path.exists(work_dir):
|
|
|
|
|
+# os.makedirs(work_dir)
|
|
|
|
|
+# text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
|
|
|
|
|
+# local_file = requests.get(hotword_list_or_file)
|
|
|
|
|
+# open(text_file_path, "wb").write(local_file.content)
|
|
|
|
|
+# hotword_list_or_file = text_file_path
|
|
|
|
|
+# hotword_list = []
|
|
|
|
|
+# hotword_str_list = []
|
|
|
|
|
+# with codecs.open(hotword_list_or_file, 'r') as fin:
|
|
|
|
|
+# for line in fin.readlines():
|
|
|
|
|
+# hw = line.strip()
|
|
|
|
|
+# hotword_str_list.append(hw)
|
|
|
|
|
+# hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
|
|
+# hotword_list.append([self.asr_model.sos])
|
|
|
|
|
+# hotword_str_list.append('<s>')
|
|
|
|
|
+# logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
|
|
|
|
+# .format(hotword_list_or_file, hotword_str_list))
|
|
|
|
|
+# # for text str input
|
|
|
|
|
+# elif not hotword_list_or_file.endswith('.txt'):
|
|
|
|
|
+# logging.info("Attempting to parse hotwords as str...")
|
|
|
|
|
+# hotword_list = []
|
|
|
|
|
+# hotword_str_list = []
|
|
|
|
|
+# for hw in hotword_list_or_file.strip().split():
|
|
|
|
|
+# hotword_str_list.append(hw)
|
|
|
|
|
+# hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
|
|
+# hotword_list.append([self.asr_model.sos])
|
|
|
|
|
+# hotword_str_list.append('<s>')
|
|
|
|
|
+# logging.info("Hotword list: {}.".format(hotword_str_list))
|
|
|
|
|
+# else:
|
|
|
|
|
+# hotword_list = None
|
|
|
|
|
+# return hotword_list
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(
|
|
def inference(
|
|
@@ -611,15 +607,17 @@ def inference_modelscope(
|
|
|
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
|
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
|
|
|
|
|
|
|
vad_results = speech2vadsegment(**batch)
|
|
vad_results = speech2vadsegment(**batch)
|
|
|
- fbanks, vadsegments = vad_results[0], vad_results[1]
|
|
|
|
|
|
|
+ _, vadsegments = vad_results[0], vad_results[1]
|
|
|
|
|
+ speech, speech_lengths = batch["speech"], batch["speech_lengths"]
|
|
|
for i, segments in enumerate(vadsegments):
|
|
for i, segments in enumerate(vadsegments):
|
|
|
result_segments = [["", [], [], []]]
|
|
result_segments = [["", [], [], []]]
|
|
|
- for j, segment_idx in enumerate(segments):
|
|
|
|
|
- bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10)
|
|
|
|
|
- segment = fbanks[:, bed_idx:end_idx, :].to(device)
|
|
|
|
|
- speech_lengths = torch.Tensor([end_idx - bed_idx]).int().to(device)
|
|
|
|
|
- batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0],
|
|
|
|
|
- "end_time": vadsegments[i][j][1]}
|
|
|
|
|
|
|
+ # for j, segment_idx in enumerate(segments):
|
|
|
|
|
+ for j, beg_idx in enumerate(range(0, len(segments), batch_size)):
|
|
|
|
|
+ end_idx = min(len(segments), beg_idx + batch_size)
|
|
|
|
|
+ speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, segments[beg_idx:end_idx])
|
|
|
|
|
+
|
|
|
|
|
+ batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
|
|
|
|
|
+ batch = to_device(batch, device=device)
|
|
|
results = speech2text(**batch)
|
|
results = speech2text(**batch)
|
|
|
if len(results) < 1:
|
|
if len(results) < 1:
|
|
|
continue
|
|
continue
|
|
@@ -633,8 +631,8 @@ def inference_modelscope(
|
|
|
|
|
|
|
|
key = keys[0]
|
|
key = keys[0]
|
|
|
result = result_segments[0]
|
|
result = result_segments[0]
|
|
|
- text, token, token_int = result[0], result[1], result[2]
|
|
|
|
|
- time_stamp = None if len(result) < 4 else result[3]
|
|
|
|
|
|
|
+ text, token, token_int, hyp = result[0], result[1], result[2], result[3]
|
|
|
|
|
+ time_stamp = None if len(result) < 5 else result[4]
|
|
|
|
|
|
|
|
|
|
|
|
|
if use_timestamp and time_stamp is not None:
|
|
if use_timestamp and time_stamp is not None:
|