@@ -280,6 +280,7 @@ class Speech2TextParaformer:
         nbest: int = 1,
         frontend_conf: dict = None,
         hotword_list_or_file: str = None,
+        clas_scale: float = 1.0,
         decoding_ind: int = 0,
         **kwargs,
     ):
@@ -376,6 +377,7 @@ class Speech2TextParaformer:
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
         self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+        self.clas_scale = clas_scale
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -439,16 +441,20 @@ class Speech2TextParaformer:
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
-                                                                                   NeatContextualParaformer):
+        if not isinstance(self.asr_model, ContextualParaformer) and \
+                not isinstance(self.asr_model, NeatContextualParaformer):
             if self.hotword_list:
                 logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                      pre_token_length)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
-                                                                     pre_token_length, hw_list=self.hotword_list)
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
+                                                                     enc_len,
+                                                                     pre_acoustic_embeds,
+                                                                     pre_token_length,
+                                                                     hw_list=self.hotword_list,
+                                                                     clas_scale=self.clas_scale)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
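
Note: taken together, the hunks thread a new clas_scale parameter from the Speech2TextParaformer constructor through to cal_decoder_with_predictor, where it is forwarded alongside the hotword list for the contextual (CLAS-style) models. The sketch below illustrates what such a scale typically controls in CLAS-style contextual biasing; the helper name apply_contextual_bias and the tensor shapes are illustrative assumptions, not FunASR's actual decoder internals.

    import torch

    # Minimal, self-contained sketch (an assumption about intent, not FunASR's
    # implementation): in CLAS-style hotword biasing, attention over hotword
    # embeddings produces a bias for each decoder step, and a scale like
    # clas_scale weights that bias before it is merged with the decoder output.
    # 0.0 disables biasing; values above 1.0 push decoding towards hotwords.
    def apply_contextual_bias(decoder_out: torch.Tensor,
                              hotword_bias: torch.Tensor,
                              clas_scale: float = 1.0) -> torch.Tensor:
        return decoder_out + clas_scale * hotword_bias

    decoder_out = torch.randn(2, 10, 512)   # (batch, predicted tokens, hidden)
    hotword_bias = torch.randn(2, 10, 512)  # hypothetical hotword-attention bias
    merged = apply_contextual_bias(decoder_out, hotword_bias, clas_scale=1.0)
    print(merged.shape)  # torch.Size([2, 10, 512])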