|
|
@@ -48,20 +48,29 @@ class Paraformer():
|
|
|
|
|
|
asr_res = []
|
|
|
for beg_idx in range(0, waveform_nums, self.batch_size):
|
|
|
+ res = {}
|
|
|
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
|
|
-
|
|
|
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
|
|
|
|
|
try:
|
|
|
- am_scores, valid_token_lens = self.infer(feats, feats_len)
|
|
|
+ outputs = self.infer(feats, feats_len)
|
|
|
+ am_scores, valid_token_lens = outputs[0], outputs[1]
|
|
|
+ if len(outputs) == 4:
|
|
|
+ # for BiCifParaformer Inference
|
|
|
+ us_alphas, us_cif_peak = outputs[2], outputs[3]
|
|
|
+ else:
|
|
|
+ us_alphas, us_cif_peak = None, None
|
|
|
except ONNXRuntimeError:
|
|
|
#logging.warning(traceback.format_exc())
|
|
|
logging.warning("input wav is silence or noise")
|
|
|
preds = ['']
|
|
|
else:
|
|
|
- preds = self.decode(am_scores, valid_token_lens)
|
|
|
-
|
|
|
- asr_res.extend(preds)
|
|
|
+ preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
|
|
|
+ res['preds'] = preds
|
|
|
+ if us_cif_peak is not None:
|
|
|
+ timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(raw_token), log=False)
|
|
|
+ res['timestamp'] = timestamp
|
|
|
+ asr_res.append(res)
|
|
|
return asr_res
|
|
|
|
|
|
def load_data(self,
|
|
|
@@ -108,8 +117,8 @@ class Paraformer():
|
|
|
|
|
|
def infer(self, feats: np.ndarray,
|
|
|
feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
|
- am_scores, token_nums = self.ort_infer([feats, feats_len])
|
|
|
- return am_scores, token_nums
|
|
|
+ outputs = self.ort_infer([feats, feats_len])
|
|
|
+ return outputs
|
|
|
|
|
|
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
|
|
|
return [self.decode_one(am_score, token_num)
|
|
|
@@ -140,63 +149,5 @@ class Paraformer():
|
|
|
texts = sentence_postprocess(token)
|
|
|
text = texts[0]
|
|
|
# text = self.tokenizer.tokens2text(token)
|
|
|
- return text
|
|
|
-
|
|
|
-
|
|
|
-class BiCifParaformer(Paraformer):
|
|
|
- def infer(self, feats: np.ndarray,
|
|
|
- feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
|
- am_scores, token_nums, us_alphas, us_cif_peak = self.ort_infer([feats, feats_len])
|
|
|
- return am_scores, token_nums, us_alphas, us_cif_peak
|
|
|
- def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
|
|
|
- waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
|
|
|
- waveform_nums = len(waveform_list)
|
|
|
-
|
|
|
- asr_res = []
|
|
|
- for beg_idx in range(0, waveform_nums, self.batch_size):
|
|
|
- res = {}
|
|
|
- end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
|
|
- feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
|
|
- am_scores, valid_token_lens, us_alphas, us_cif_peak = self.infer(feats, feats_len)
|
|
|
-
|
|
|
- try:
|
|
|
- am_scores, valid_token_lens, us_alphas, us_cif_peak = self.infer(feats, feats_len)
|
|
|
- except ONNXRuntimeError:
|
|
|
- #logging.warning(traceback.format_exc())
|
|
|
- logging.warning("input wav is silence or noise")
|
|
|
- preds = ['']
|
|
|
- else:
|
|
|
- token = self.decode(am_scores, valid_token_lens)
|
|
|
- timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(token[0]), log=False)
|
|
|
- texts = sentence_postprocess(token[0], timestamp)
|
|
|
- # texts = sentence_postprocess(token[0])
|
|
|
- text = texts[0]
|
|
|
- res['text'] = text
|
|
|
- res['timestamp'] = timestamp
|
|
|
- asr_res.append(res)
|
|
|
+ return text, token
|
|
|
|
|
|
- return asr_res
|
|
|
-
|
|
|
- def decode_one(self,
|
|
|
- am_score: np.ndarray,
|
|
|
- valid_token_num: int) -> List[str]:
|
|
|
- yseq = am_score.argmax(axis=-1)
|
|
|
- score = am_score.max(axis=-1)
|
|
|
- score = np.sum(score, axis=-1)
|
|
|
-
|
|
|
- # pad with mask tokens to ensure compatibility with sos/eos tokens
|
|
|
- # asr_model.sos:1 asr_model.eos:2
|
|
|
- yseq = np.array([1] + yseq.tolist() + [2])
|
|
|
- hyp = Hypothesis(yseq=yseq, score=score)
|
|
|
-
|
|
|
- # remove sos/eos and get results
|
|
|
- last_pos = -1
|
|
|
- token_int = hyp.yseq[1:last_pos].tolist()
|
|
|
-
|
|
|
- # remove blank symbol id, which is assumed to be 0
|
|
|
- token_int = list(filter(lambda x: x not in (0, 2), token_int))
|
|
|
-
|
|
|
- # Change integer-ids to tokens
|
|
|
- token = self.converter.ids2tokens(token_int)
|
|
|
- # token = token[:valid_token_num-1]
|
|
|
- return token
|