|
|
@@ -32,6 +32,7 @@ from funasr.modules.beam_search.beam_search import BeamSearch
|
|
|
from funasr.modules.beam_search.beam_search import Hypothesis
|
|
|
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
|
|
|
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
|
|
|
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
|
|
|
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
|
|
from funasr.modules.scorers.length_bonus import LengthBonus
|
|
|
from funasr.modules.subsampling import TooShortUttError
|
|
|
@@ -58,7 +59,7 @@ from funasr.bin.punc_infer import Text2Punc
|
|
|
from funasr.utils.vad_utils import slice_padding_fbank
|
|
|
from funasr.tasks.vad import VADTask
|
|
|
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
|
|
|
-
|
|
|
+from funasr.tasks.asr import frontend_choices
|
|
|
|
|
|
class Speech2Text:
|
|
|
"""Speech2Text class
|
|
|
@@ -1599,3 +1600,251 @@ class Speech2TextTransducer:
|
|
|
|
|
|
return Speech2Text(**kwargs)
|
|
|
|
|
|
+
|
|
|
+class Speech2TextSAASR:
|
|
|
+ """Speech2Text class
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import soundfile
|
|
|
+ >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
|
|
|
+ >>> audio, rate = soundfile.read("speech.wav")
|
|
|
+ >>> speech2text(audio)
|
|
|
+ [(text, token, token_int, hypothesis object), ...]
|
|
|
+
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ asr_train_config: Union[Path, str] = None,
|
|
|
+ asr_model_file: Union[Path, str] = None,
|
|
|
+ cmvn_file: Union[Path, str] = None,
|
|
|
+ lm_train_config: Union[Path, str] = None,
|
|
|
+ lm_file: Union[Path, str] = None,
|
|
|
+ token_type: str = None,
|
|
|
+ bpemodel: str = None,
|
|
|
+ device: str = "cpu",
|
|
|
+ maxlenratio: float = 0.0,
|
|
|
+ minlenratio: float = 0.0,
|
|
|
+ batch_size: int = 1,
|
|
|
+ dtype: str = "float32",
|
|
|
+ beam_size: int = 20,
|
|
|
+ ctc_weight: float = 0.5,
|
|
|
+ lm_weight: float = 1.0,
|
|
|
+ ngram_weight: float = 0.9,
|
|
|
+ penalty: float = 0.0,
|
|
|
+ nbest: int = 1,
|
|
|
+ streaming: bool = False,
|
|
|
+ frontend_conf: dict = None,
|
|
|
+ **kwargs,
|
|
|
+ ):
|
|
|
+ assert check_argument_types()
|
|
|
+
|
|
|
+ # 1. Build ASR model
|
|
|
+ from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
|
|
|
+ scorers = {}
|
|
|
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
+ asr_train_config, asr_model_file, cmvn_file, device
|
|
|
+ )
|
|
|
+ frontend = None
|
|
|
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
|
+ if asr_train_args.frontend == 'wav_frontend':
|
|
|
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
|
|
+ else:
|
|
|
+ frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
|
|
+ frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
|
|
+
|
|
|
+ logging.info("asr_model: {}".format(asr_model))
|
|
|
+ logging.info("asr_train_args: {}".format(asr_train_args))
|
|
|
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
|
|
|
+
|
|
|
+ decoder = asr_model.decoder
|
|
|
+
|
|
|
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
|
|
+ token_list = asr_model.token_list
|
|
|
+ scorers.update(
|
|
|
+ decoder=decoder,
|
|
|
+ ctc=ctc,
|
|
|
+ length_bonus=LengthBonus(len(token_list)),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 2. Build Language model
|
|
|
+ if lm_train_config is not None:
|
|
|
+ lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
+ lm_train_config, lm_file, None, device
|
|
|
+ )
|
|
|
+ scorers["lm"] = lm.lm
|
|
|
+
|
|
|
+ # 3. Build ngram model
|
|
|
+ # ngram is not supported now
|
|
|
+ ngram = None
|
|
|
+ scorers["ngram"] = ngram
|
|
|
+
|
|
|
+ # 4. Build BeamSearch object
|
|
|
+ # transducer is not supported now
|
|
|
+ beam_search_transducer = None
|
|
|
+
|
|
|
+ weights = dict(
|
|
|
+ decoder=1.0 - ctc_weight,
|
|
|
+ ctc=ctc_weight,
|
|
|
+ lm=lm_weight,
|
|
|
+ ngram=ngram_weight,
|
|
|
+ length_bonus=penalty,
|
|
|
+ )
|
|
|
+ beam_search = BeamSearch(
|
|
|
+ beam_size=beam_size,
|
|
|
+ weights=weights,
|
|
|
+ scorers=scorers,
|
|
|
+ sos=asr_model.sos,
|
|
|
+ eos=asr_model.eos,
|
|
|
+ vocab_size=len(token_list),
|
|
|
+ token_list=token_list,
|
|
|
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
|
|
|
+ )
|
|
|
+
|
|
|
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
|
|
+ if token_type is None:
|
|
|
+ token_type = asr_train_args.token_type
|
|
|
+ if bpemodel is None:
|
|
|
+ bpemodel = asr_train_args.bpemodel
|
|
|
+
|
|
|
+ if token_type is None:
|
|
|
+ tokenizer = None
|
|
|
+ elif token_type == "bpe":
|
|
|
+ if bpemodel is not None:
|
|
|
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
|
|
|
+ else:
|
|
|
+ tokenizer = None
|
|
|
+ else:
|
|
|
+ tokenizer = build_tokenizer(token_type=token_type)
|
|
|
+ converter = TokenIDConverter(token_list=token_list)
|
|
|
+ logging.info(f"Text tokenizer: {tokenizer}")
|
|
|
+
|
|
|
+ self.asr_model = asr_model
|
|
|
+ self.asr_train_args = asr_train_args
|
|
|
+ self.converter = converter
|
|
|
+ self.tokenizer = tokenizer
|
|
|
+ self.beam_search = beam_search
|
|
|
+ self.beam_search_transducer = beam_search_transducer
|
|
|
+ self.maxlenratio = maxlenratio
|
|
|
+ self.minlenratio = minlenratio
|
|
|
+ self.device = device
|
|
|
+ self.dtype = dtype
|
|
|
+ self.nbest = nbest
|
|
|
+ self.frontend = frontend
|
|
|
+
|
|
|
+ @torch.no_grad()
|
|
|
+ def __call__(
|
|
|
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
|
|
|
+ profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
|
|
|
+ ) -> List[
|
|
|
+ Tuple[
|
|
|
+ Optional[str],
|
|
|
+ Optional[str],
|
|
|
+ List[str],
|
|
|
+ List[int],
|
|
|
+ Union[HypothesisSAASR],
|
|
|
+ ]
|
|
|
+ ]:
|
|
|
+ """Inference
|
|
|
+
|
|
|
+ Args:
|
|
|
+ speech: Input speech data
|
|
|
+ Returns:
|
|
|
+ text, text_id, token, token_int, hyp
|
|
|
+
|
|
|
+ """
|
|
|
+ assert check_argument_types()
|
|
|
+
|
|
|
+ # Input as audio signal
|
|
|
+ if isinstance(speech, np.ndarray):
|
|
|
+ speech = torch.tensor(speech)
|
|
|
+
|
|
|
+ if isinstance(profile, np.ndarray):
|
|
|
+ profile = torch.tensor(profile)
|
|
|
+
|
|
|
+ if self.frontend is not None:
|
|
|
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
|
|
+ feats = to_device(feats, device=self.device)
|
|
|
+ feats_len = feats_len.int()
|
|
|
+ self.asr_model.frontend = None
|
|
|
+ else:
|
|
|
+ feats = speech
|
|
|
+ feats_len = speech_lengths
|
|
|
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
|
|
+ batch = {"speech": feats, "speech_lengths": feats_len}
|
|
|
+
|
|
|
+ # a. To device
|
|
|
+ batch = to_device(batch, device=self.device)
|
|
|
+
|
|
|
+ # b. Forward Encoder
|
|
|
+ asr_enc, _, spk_enc = self.asr_model.encode(**batch)
|
|
|
+ if isinstance(asr_enc, tuple):
|
|
|
+ asr_enc = asr_enc[0]
|
|
|
+ if isinstance(spk_enc, tuple):
|
|
|
+ spk_enc = spk_enc[0]
|
|
|
+ assert len(asr_enc) == 1, len(asr_enc)
|
|
|
+ assert len(spk_enc) == 1, len(spk_enc)
|
|
|
+
|
|
|
+ # c. Passed the encoder result and the beam search
|
|
|
+ nbest_hyps = self.beam_search(
|
|
|
+ asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
|
|
|
+ )
|
|
|
+
|
|
|
+ nbest_hyps = nbest_hyps[: self.nbest]
|
|
|
+
|
|
|
+ results = []
|
|
|
+ for hyp in nbest_hyps:
|
|
|
+ assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
|
|
|
+
|
|
|
+ # remove sos/eos and get results
|
|
|
+ last_pos = -1
|
|
|
+ if isinstance(hyp.yseq, list):
|
|
|
+ token_int = hyp.yseq[1: last_pos]
|
|
|
+ else:
|
|
|
+ token_int = hyp.yseq[1: last_pos].tolist()
|
|
|
+
|
|
|
+ spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
|
|
|
+
|
|
|
+ token_ori = self.converter.ids2tokens(token_int)
|
|
|
+ text_ori = self.tokenizer.tokens2text(token_ori)
|
|
|
+
|
|
|
+ text_ori_spklist = text_ori.split('$')
|
|
|
+ cur_index = 0
|
|
|
+ spk_choose = []
|
|
|
+ for i in range(len(text_ori_spklist)):
|
|
|
+ text_ori_split = text_ori_spklist[i]
|
|
|
+ n = len(text_ori_split)
|
|
|
+ spk_weights_local = spk_weigths[cur_index: cur_index + n]
|
|
|
+ cur_index = cur_index + n + 1
|
|
|
+ spk_weights_local = spk_weights_local.mean(dim=0)
|
|
|
+ spk_choose_local = spk_weights_local.argmax(-1)
|
|
|
+ spk_choose.append(spk_choose_local.item() + 1)
|
|
|
+
|
|
|
+ # remove blank symbol id, which is assumed to be 0
|
|
|
+ token_int = list(filter(lambda x: x != 0, token_int))
|
|
|
+
|
|
|
+ # Change integer-ids to tokens
|
|
|
+ token = self.converter.ids2tokens(token_int)
|
|
|
+
|
|
|
+ if self.tokenizer is not None:
|
|
|
+ text = self.tokenizer.tokens2text(token)
|
|
|
+ else:
|
|
|
+ text = None
|
|
|
+
|
|
|
+ text_spklist = text.split('$')
|
|
|
+ assert len(spk_choose) == len(text_spklist)
|
|
|
+
|
|
|
+ spk_list = []
|
|
|
+ for i in range(len(text_spklist)):
|
|
|
+ text_split = text_spklist[i]
|
|
|
+ n = len(text_split)
|
|
|
+ spk_list.append(str(spk_choose[i]) * n)
|
|
|
+
|
|
|
+ text_id = '$'.join(spk_list)
|
|
|
+
|
|
|
+ assert len(text) == len(text_id)
|
|
|
+
|
|
|
+ results.append((text, text_id, token, token_int, hyp))
|
|
|
+
|
|
|
+ assert check_return_type(results)
|
|
|
+ return results
|