
Funasr1.0 (#1277)

* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* update with main (#1264)

* Funasr1.0 (#1261)

* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* bug fix

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* funasr1.0 sanm scama

* funasr1.0 infer_after_finetune

* funasr1.0 fsmn-vad bug fix

* funasr1.0 fsmn-vad bug fix

* funasr1.0 fsmn-vad bug fix

* funasr1.0 finetune

* funasr1.0 finetune

* funasr1.0 finetune

* funasr1.0 finetune

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
zhifu gao 2 years ago
parent
commit
37d7764ecf

+ 2 - 1
funasr/auto/auto_model.py

@@ -132,7 +132,8 @@ class AutoModel:
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs.get("model_path", "./")
+        self.model_path = kwargs.get("model_path")
+
   
         
     def build_model(self, **kwargs):
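
Note: dropping the "./" default means self.model_path is now None when the caller never sets it, so downstream code can tell "unset" apart from "current directory". A minimal sketch of that distinction (the resolver body is hypothetical):

    def resolve_model_path(**kwargs) -> str:
        # Old behavior: kwargs.get("model_path", "./") silently pointed at the
        # working directory. Returning None lets callers branch explicitly.
        path = kwargs.get("model_path")
        if path is None:
            path = "/path/to/downloaded/model"  # stand-in for a real download/lookup
        return path

    print(resolve_model_path())                  # /path/to/downloaded/model
    print(resolve_model_path(model_path="./m"))  # ./m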

+ 1 - 1
funasr/datasets/audio_datasets/datasets.py

@@ -58,7 +58,7 @@ class AudioDataset(torch.utils.data.Dataset):
         data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src)
-        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
 
         target = item["target"]
         if self.preprocessor_text:
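
Note: passing is_final=True fits this call site because AudioDataset always sees a complete utterance; a streaming-capable frontend may otherwise hold back tail samples until told the input is finished. A toy illustration of that contract (frame size and names are made up):

    import torch

    def toy_extract(wav: torch.Tensor, cache: dict, is_final: bool) -> torch.Tensor:
        # Buffer samples across calls and emit whole frames only, unless
        # is_final says to flush the remaining tail as well.
        buf = torch.cat((cache.get("buf", wav.new_empty(0)), wav))
        frame = 400  # hypothetical samples per frame
        n = buf.numel() if is_final else (buf.numel() // frame) * frame
        cache["buf"] = buf[n:]
        return buf[:n]

    cache = {}
    print(toy_extract(torch.randn(1000), cache, is_final=False).numel())  # 800
    print(toy_extract(torch.randn(0), cache, is_final=True).numel())      # 200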

+ 2 - 1
funasr/frontends/wav_frontend.py

@@ -399,9 +399,10 @@ class WavFrontendOnline(nn.Module):
         return feats_pad, feats_lens, lfr_splice_frame_idxs
 
     def forward(
-        self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs
+        self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs
     ):
         is_final = kwargs.get("is_final", False)
+        cache = kwargs.get("cache", {})
         if len(cache) == 0:
             self.init_cache(cache)
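
Note: besides making the signature uniform, reading cache from kwargs sidesteps Python's shared-mutable-default pitfall: a dict = {} default is created once at definition time, so state written by one call leaks into the next. A minimal sketch of the difference:

    def forward_bad(x, cache: dict = {}):    # one dict shared by every call
        cache["n"] = cache.get("n", 0) + 1
        return cache["n"]

    def forward_good(x, **kwargs):
        cache = kwargs.get("cache", {})      # fresh dict unless the caller passes one
        cache["n"] = cache.get("n", 0) + 1
        return cache["n"]

    print(forward_bad(0), forward_bad(0))    # 1 2 -- state leaked across calls
    print(forward_good(0), forward_good(0))  # 1 1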
         

+ 46 - 30
funasr/models/fsmn_vad_streaming/model.py

@@ -15,7 +15,7 @@ from funasr.register import tables
 from typing import List, Tuple, Dict, Any, Optional
 
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
 class VadStateMachine(Enum):
@@ -23,11 +23,13 @@ class VadStateMachine(Enum):
 	kVadInStateInSpeechSegment = 2
 	kVadInStateEndPointDetected = 3
 
+
 class FrameState(Enum):
 	kFrameStateInvalid = -1
 	kFrameStateSpeech = 1
 	kFrameStateSil = 0
 
+
 # final voice/unvoice state per frame
 class AudioChangeState(Enum):
 	kChangeStateSpeech2Speech = 0
@@ -37,16 +39,19 @@ class AudioChangeState(Enum):
 	kChangeStateNoBegin = 4
 	kChangeStateInvalid = 5
 
+
 class VadDetectMode(Enum):
 	kVadSingleUtteranceDetectMode = 0
 	kVadMutipleUtteranceDetectMode = 1
 
+
 class VADXOptions:
 	"""
 	Author: Speech Lab of DAMO Academy, Alibaba Group
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(
 		self,
 		sample_rate: int = 16000,
@@ -117,6 +122,7 @@ class E2EVadSpeechBufWithDoa(object):
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self):
 		self.start_ms = 0
 		self.end_ms = 0
@@ -140,6 +146,7 @@ class E2EVadFrameProb(object):
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self):
 		self.noise_prob = 0.0
 		self.speech_prob = 0.0
@@ -154,6 +161,7 @@ class WindowDetector(object):
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self, window_size_ms: int,
 	             sil_to_speech_time: int,
 	             speech_to_sil_time: int,
@@ -190,7 +198,7 @@ class WindowDetector(object):
 	def GetWinSize(self) -> int:
 		return int(self.win_size_frame)
 	
-	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
+	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
 		cur_frame_state = FrameState.kFrameStateSil
 		if frameState == FrameState.kFrameStateSpeech:
 			cur_frame_state = 1
@@ -220,13 +228,13 @@ class WindowDetector(object):
 	def FrameSizeMs(self) -> int:
 		return int(self.frame_size_ms)
 
+
 class Stats(object):
 	def __init__(self,
 	             sil_pdf_ids,
 	             max_end_sil_frame_cnt_thresh,
 	             speech_noise_thres,
 	             ):
-		
 		self.data_buf_start_frame = 0
 		self.frm_cnt = 0
 		self.latest_confirmed_speech_frame = 0
@@ -255,6 +263,7 @@ class Stats(object):
 		self.waveform = None
 		self.last_drop_frames = 0
 
+
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
 	"""
@@ -262,6 +271,7 @@ class FsmnVADStreaming(nn.Module):
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self,
 	             encoder: str = None,
 	             encoder_conf: Optional[Dict] = None,
@@ -275,7 +285,6 @@ class FsmnVADStreaming(nn.Module):
 		encoder = encoder_class(**encoder_conf)
 		self.encoder = encoder
 	
-	
 	def ResetDetection(self, cache: dict = {}):
 		cache["stats"].continous_silence_frame_count = 0
 		cache["stats"].latest_confirmed_speech_frame = 0
@@ -292,7 +301,8 @@ class FsmnVADStreaming(nn.Module):
 			drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
 			real_drop_frames = drop_frames - cache["stats"].last_drop_frames
 			cache["stats"].last_drop_frames = drop_frames
-			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
+				self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 			cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
 			cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
 	
@@ -300,7 +310,8 @@ class FsmnVADStreaming(nn.Module):
 		frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
 		frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
 		if cache["stats"].data_buf_all is None:
-			cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
+			cache["stats"].data_buf_all = cache["stats"].waveform[
+				0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
 			cache["stats"].data_buf = cache["stats"].data_buf_all
 		else:
 			cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
@@ -319,15 +330,16 @@ class FsmnVADStreaming(nn.Module):
 		else:
 			cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
 	
-	def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
+	def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
 		while cache["stats"].data_buf_start_frame < frame_idx:
 			if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
 				cache["stats"].data_buf_start_frame += 1
-				cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-					self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+				cache["stats"].data_buf = cache["stats"].data_buf_all[
+				                          (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
+					                          self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 	
 	def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
+	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
 		self.PopDataBufTillFrame(start_frm, cache=cache)
 		expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
 		if last_frm_is_end_point:
@@ -379,14 +391,15 @@ class FsmnVADStreaming(nn.Module):
 		cache["stats"].lastest_confirmed_silence_frame = valid_frame
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 			self.PopDataBufTillFrame(valid_frame, cache=cache)
-		# silence_detected_callback_
-		# pass
 	
-	def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
+	# silence_detected_callback_
+	# pass
+	
+	def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
 		cache["stats"].latest_confirmed_speech_frame = valid_frame
 		self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
 	
-	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
+	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
 		if self.vad_opts.do_start_point_detection:
 			pass
 		if cache["stats"].confirmed_start_frame != -1:
@@ -397,7 +410,7 @@ class FsmnVADStreaming(nn.Module):
 		if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 			self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
 	
-	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
+	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
 		for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
 			self.OnVoiceDetected(t, cache=cache)
 		if self.vad_opts.do_end_point_detection:
@@ -487,7 +500,8 @@ class FsmnVADStreaming(nn.Module):
 			segment_batch = []
 			if len(cache["stats"].output_data_buf) > 0:
 				for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-					if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
+					if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
+					cache["stats"].output_data_buf[
 						i].contain_seg_end_point):
 						continue
 					segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
@@ -499,9 +513,9 @@ class FsmnVADStreaming(nn.Module):
 		#     # reset class variables and clear the dict for the next query
 		#     self.AllResetDetection()
 		return segments
-
+	
 	def init_cache(self, cache: dict = {}, **kwargs):
-    
+		
 		cache["frontend"] = {}
 		cache["prev_samples"] = torch.empty(0)
 		cache["encoder"] = {}
@@ -528,12 +542,12 @@ class FsmnVADStreaming(nn.Module):
 	              cache: dict = {},
 	              **kwargs,
 	              ):
-
+		
 		if len(cache) == 0:
 			self.init_cache(cache, **kwargs)
 		
 		meta_data = {}
-		chunk_size = kwargs.get("chunk_size", 60000) # 50ms
+		chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
 		chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
 		
 		time1 = time.perf_counter()
@@ -580,7 +594,6 @@ class FsmnVADStreaming(nn.Module):
 			if len(segments_i) > 0:
 				segments.extend(*segments_i)
 		
-		
 		cache["prev_samples"] = audio_sample[:-m]
 		if _is_final:
 			self.init_cache(cache)
@@ -600,16 +613,15 @@ class FsmnVADStreaming(nn.Module):
 		if ibest_writer is not None:
 			ibest_writer["text"][key[0]] = segments
 		
-		
 		return results, meta_data
 	
-	
 	def DetectCommonFrames(self, cache: dict = {}) -> int:
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
 			return 0
 		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
 			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+			                                 cache=cache)
 			self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
 		
 		return 0
@@ -619,7 +631,8 @@ class FsmnVADStreaming(nn.Module):
 			return 0
 		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
 			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+			                                 cache=cache)
 			if i != 0:
 				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
 			else:
@@ -627,7 +640,8 @@ class FsmnVADStreaming(nn.Module):
 		
 		return 0
 	
-	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
+	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
+	                   cache: dict = {}) -> None:
 		tmp_cur_frm_state = FrameState.kFrameStateInvalid
 		if cur_frm_state == FrameState.kFrameStateSpeech:
 			if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +658,8 @@ class FsmnVADStreaming(nn.Module):
 			cache["stats"].pre_end_silence_detected = False
 			start_frame = 0
 			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-				start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+				start_frame = max(cache["stats"].data_buf_start_frame,
+				                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
 				self.OnVoiceStart(start_frame, cache=cache)
 				cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
 				for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +711,8 @@ class FsmnVADStreaming(nn.Module):
 			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 				# silence timeout, return zero length decision
 				if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-					cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+					cache[
+						"stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
 					or (is_final_frame and cache["stats"].number_end_time_detected == 0):
 					for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
 						self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +723,8 @@ class FsmnVADStreaming(nn.Module):
 					if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
 						self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
 			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
+					"stats"].max_end_sil_frame_cnt_thresh:
 					lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
 					if self.vad_opts.do_extend:
 						lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -733,4 +750,3 @@ class FsmnVADStreaming(nn.Module):
 			self.ResetDetection(cache=cache)
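
Note: for context, a hedged sketch of how this streaming VAD is typically driven from AutoModel: a caller-owned cache dict threads state through inference, and is_final triggers the reset seen above (model name and chunk settings are assumptions based on the FunASR examples):

    import torch
    from funasr import AutoModel

    model = AutoModel(model="fsmn-vad")
    cache = {}                       # carried across chunks; reset on is_final
    chunk_size = 200                 # ms
    wav = torch.randn(16000 * 2)     # 2 s of fake 16 kHz audio
    stride = int(chunk_size * 16000 / 1000)
    for i in range(0, wav.numel(), stride):
        is_final = i + stride >= wav.numel()
        res = model.generate(input=wav[i:i + stride], cache=cache,
                             is_final=is_final, chunk_size=chunk_size)
        if res[0]["value"]:
            print(res[0]["value"])   # [[start_ms, end_ms], ...] segments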
 
 
-

+ 466 - 1
funasr/models/scama/beam_search.py

@@ -11,7 +11,7 @@ from typing import Union
 
 import torch
 
-from funasr.metrics import end_detect
+from funasr.metrics.common import end_detect
 from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
 from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
 
@@ -494,3 +494,468 @@ class BeamSearchScama(torch.nn.Module):
             else:
                 remained_hyps.append(hyp)
         return remained_hyps
+
+class BeamSearchScamaStreaming(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos: int,
+        eos: int,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM
+                The scorer will be ignored if it is `None`
+            weights (dict[str, float]): Dict of weights for each scorers
+                The scorer will be ignored if its weight is 0
+            beam_size (int): The number of hypotheses kept during search
+            vocab_size (int): The number of vocabulary
+            sos (int): Start of sequence id
+            eos (int): End of sequence id
+            token_list (list[str]): List of tokens for debug log
+            pre_beam_score_key (str): key of scores to perform pre-beam search
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`
+
+        """
+        super().__init__()
+        # set scorers
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+            pre_beam_score_key is not None
+            and pre_beam_score_key != "full"
+            and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+            self.pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+        """
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+    def score_full(
+        self, hyp: Hypothesis,
+        x: torch.Tensor,
+        x_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        cache: dict={},
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+        return scores, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+            Its shape is `(self.n_vocab,)`.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+        self, running_hyps: List[Hypothesis],
+        x: torch.Tensor,
+        x_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        cache: dict={},
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            List[Hypotheses]: Best sorted hypotheses
+
+        """
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, x)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(
+                            hyp.scores, scores, j, part_scores, part_j
+                        ),
+                        states=self.merge_states(states, part_states, part_j),
+                    )
+                )
+
+            # sort and prune 2 x beam -> beam
+            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+                : min(len(best_hyps), self.beam_size)
+            ]
+        return best_hyps
+
+    def forward(
+        self, x: torch.Tensor,
+        scama_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        maxlenratio: float = 0.0,
+        minlenratio: float = 0.0,
+        maxlen: int = None,
+        minlen: int = 0,
+        cache:dict={},
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses a end-detect function
+                to automatically find maximum hypothesis lengths
+                If maxlenratio<0.0, its absolute value is interpreted
+                as a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        if maxlen is None:
+            # set length bounds
+            if maxlenratio == 0:
+                maxlen = x.shape[0]
+            elif maxlenratio < 0:
+                maxlen = -1 * int(maxlenratio)
+            else:
+                maxlen = max(1, int(maxlenratio * x.size(0)))
+            minlen = int(minlenratio * x.size(0))
+
+        logging.info("decoder input length: " + str(x.shape[0]))
+        logging.info("max output length: " + str(maxlen))
+        logging.info("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        # running_hyps = self.init_hyp(x)
+        running_hyps = cache["running_hyps"]
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            mask_enc = None
+            # if scama_mask is not None:
+            #     token_num_predictor = scama_mask.size(1)
+            #     token_id_slice = min(i, token_num_predictor-1)
+            #     mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+            #     # if mask_enc.size(1) == 0:
+            #     #     mask_enc = scama_mask[:, -2:-1, :]
+            #     #     # mask_enc = torch.zeros_like(mask_enc)
+            pre_acoustic_embeds_cur = None
+            if pre_acoustic_embeds is not None:
+                b, t, d = pre_acoustic_embeds.size()
+                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
+                token_id_slice = min(i, t)
+                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+
+            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur, cache=cache["decoder"])
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching to eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there is no N-best results, perform recognition "
+                "again with smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        for x in nbest_hyps:
+            yseq = "".join([self.token_list[x] for x in x.yseq])
+            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (int): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos))
+                for h in running_hyps
+            ]
+
+        # add ended hypotheses to a final list, and removed them from current hypotheses
+        # (this will be a problem, number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps
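
Note: the essential difference from BeamSearchScama is the incremental contract: hypotheses are read from and written back to a caller-owned cache, so decoding resumes across chunks instead of restarting per utterance. A hedged sketch of the driver loop (helper name is illustrative; the real call site is in funasr/models/scama/model.py below):

    def decode_chunk(beam_search, encoder_out, pre_acoustic_embeds, cache):
        # First chunk: seed hypotheses once; later chunks resume from the cache.
        if "running_hyps" not in cache:
            cache["running_hyps"] = beam_search.init_hyp(encoder_out)
        nbest = beam_search(x=encoder_out[0], scama_mask=None,
                            pre_acoustic_embeds=pre_acoustic_embeds,
                            maxlen=int(pre_acoustic_embeds.size(1)),
                            cache=cache)
        cache["running_hyps"] = nbest  # the next chunk continues from here
        return nbest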

+ 48 - 49
funasr/models/scama/model.py

@@ -436,7 +436,10 @@ class SCAMA(nn.Module):
     def init_beam_search(self,
                          **kwargs,
                          ):
-        from funasr.models.scama.beam_search import BeamSearchScama
+
+        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
+
+
         from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
         from funasr.models.transformer.scorers.length_bonus import LengthBonus
     
@@ -460,13 +463,14 @@ class SCAMA(nn.Module):
         scorers["ngram"] = ngram
     
         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
             ctc=kwargs.get("decoding_ctc_weight", 0.0),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),
             length_bonus=kwargs.get("penalty", 0.0),
         )
-        beam_search = BeamSearchScama(
+        
+        beam_search = BeamSearchScamaStreaming(
             beam_size=kwargs.get("beam_size", 2),
             weights=weights,
             scorers=scorers,
@@ -499,7 +503,11 @@ class SCAMA(nn.Module):
                                                           is_final=kwargs.get("is_final", False))
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
-
+        if "running_hyps" not in cache:
+            running_hyps = self.beam_search.init_hyp(encoder_out)
+            cache["running_hyps"] = running_hyps
+       
+       
         # predictor
         predictor_outs = self.calc_predictor_chunk(encoder_out,
                                                    encoder_out_lens,
@@ -513,47 +521,30 @@ class SCAMA(nn.Module):
 
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
-                                                             encoder_out_lens,
-                                                             pre_acoustic_embeds,
-                                                             pre_token_length,
-                                                             cache=cache
-                                                             )
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-    
+        maxlen = minlen = pre_token_length
+        if kwargs.get("is_final", False):
+            maxlen += kwargs.get("token_num_relax", 5)
+            minlen = max(0, minlen - kwargs.get("token_num_relax", 5))
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds, maxlen=int(maxlen), minlen=int(minlen), cache=cache,
+        )
+
+        cache["running_hyps"] = nbest_hyps
+        nbest_hyps = nbest_hyps[: self.nbest]
+
         results = []
-        b, n, d = decoder_out.size()
-        if isinstance(key[0], (list, tuple)):
-            key = key[0]
-        for i in range(b):
-            x = encoder_out[i, :encoder_out_lens[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
-                    minlenratio=kwargs.get("minlenratio", 0.0)
-                )
-            
-                nbest_hyps = nbest_hyps[: self.nbest]
+        for hyp in nbest_hyps:
+            # assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
             else:
-            
-                yseq = am_scores.argmax(dim=-1)
-                score = am_scores.max(dim=-1)[0]
-                score = torch.sum(score, dim=-1)
-                # pad with mask tokens to ensure compatibility with sos/eos tokens
-                yseq = torch.tensor(
-                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-                )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-            for nbest_idx, hyp in enumerate(nbest_hyps):
-            
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-            
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+
                 # remove blank symbol id, which is assumed to be 0
                 token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
             
@@ -568,6 +559,8 @@ class SCAMA(nn.Module):
         return results
 
     def init_cache(self, cache: dict = {}, **kwargs):
+        device = kwargs.get("device", "cuda")
+        
         chunk_size = kwargs.get("chunk_size", [0, 10, 5])
         encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
         decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
@@ -575,10 +568,11 @@ class SCAMA(nn.Module):
     
         enc_output_size = kwargs["encoder_conf"]["output_size"]
         feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
-        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+
+        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
+                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device), "chunk_size": chunk_size,
                          "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
-                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
                          "tail_chunk": False}
         cache["encoder"] = cache_encoder
     
@@ -586,8 +580,10 @@ class SCAMA(nn.Module):
                          "chunk_size": chunk_size}
         cache["decoder"] = cache_decoder
         cache["frontend"] = {}
-        cache["prev_samples"] = torch.empty(0)
-    
+
+
+        cache["prev_samples"] = torch.empty(0).to(device=device)
+
         return cache
 
     def inference(self,
@@ -603,7 +599,10 @@ class SCAMA(nn.Module):
         # init beamsearch
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+
+        if self.beam_search is None:
+
+
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
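
Note: one consequence of the new device argument is that the streaming caches are allocated on the target device up front, avoiding implicit CPU-to-GPU copies on every chunk. A reduced sketch of the pattern (keys and shapes follow the diff; the concrete values are assumptions):

    import torch

    def make_encoder_cache(batch_size, enc_output_size, chunk_size, feats_dims, device):
        return {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size), device=device),
            "cif_alphas": torch.zeros((batch_size, 1), device=device),
            "chunk_size": chunk_size,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims), device=device),
            "last_chunk": False,
        }

    cache = {"encoder": make_encoder_cache(1, 512, [0, 10, 5], 560, "cpu")}  # diff default: "cuda"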

+ 5 - 3
funasr/train_utils/trainer.py

@@ -148,6 +148,7 @@ class Trainer:
             
             self._train_epoch(epoch)
 
+
             
             if self.use_ddp or self.use_fsdp:
                 dist.barrier()
@@ -156,8 +157,8 @@ class Trainer:
 
             if self.use_ddp or self.use_fsdp:
                 dist.barrier()
-                
-
+           
+           
             if self.rank == 0:
                 self._save_checkpoint(epoch)
             
@@ -172,7 +173,8 @@ class Trainer:
             
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-      
+
+
         if self.writer:
             self.writer.close()