@@ -31,7 +31,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.types import str2bool, str2triple_str, str_or_none
 from funasr.utils.cli_utils import get_commandline_args
-
+from funasr.models.frontend.wav_frontend import WavFrontend

 class Speech2Text:
     """Speech2Text class for Transducer models.
@@ -62,6 +62,7 @@ class Speech2Text:
         self,
         asr_train_config: Union[Path, str] = None,
         asr_model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
         beam_search_config: Dict[str, Any] = None,
         lm_train_config: Union[Path, str] = None,
         lm_file: Union[Path, str] = None,
@@ -86,11 +87,14 @@ class Speech2Text:
         super().__init__()

         assert check_argument_types()
-
         asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
-            asr_train_config, asr_model_file, device
+            asr_train_config, asr_model_file, cmvn_file, device
         )

+        frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
         if quantize_asr_model:
             if quantize_modules is not None:
                 if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
@@ -156,7 +160,7 @@ class Speech2Text:
         tokenizer = build_tokenizer(token_type=token_type)
         converter = TokenIDConverter(token_list=token_list)
         logging.info(f"Text tokenizer: {tokenizer}")
-
+
         self.asr_model = asr_model
         self.asr_train_args = asr_train_args
         self.device = device
@@ -181,23 +185,13 @@ class Speech2Text:
             self.simu_streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False

-        self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
-        self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)
-
-        if asr_train_args.frontend_conf.get("win_length", None) is not None:
-            self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
-        else:
-            self.frontend_window_size = self.n_fft
-
+        self.frontend = frontend
         self.window_size = self.chunk_size + self.right_context
-        self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
-            self.window_size, self.hop_length
-        )
+
         self._ctx = self.asr_model.encoder.get_encoder_input_size(
             self.window_size
         )

-
         #self.last_chunk_length = (
         #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
         #) * self.hop_length
@@ -218,112 +212,6 @@ class Speech2Text:

         self.num_processed_frames = torch.tensor([[0]], device=self.device)

-    def apply_frontend(
-        self, speech: torch.Tensor, is_final: bool = False
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward frontend.
-        Args:
-            speech: Speech data. (S)
-            is_final: Whether speech corresponds to the final (or only) chunk of data.
-        Returns:
-            feats: Features sequence. (1, T_in, F)
-            feats_lengths: Features sequence length. (1, T_in, F)
-        """
-        if self.frontend_cache is not None:
-            speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)
-
-        if is_final:
-            if self.streaming and speech.size(0) < self.last_chunk_length:
-                pad = torch.zeros(
-                    self.last_chunk_length - speech.size(0), dtype=speech.dtype
-                )
-                speech = torch.cat([speech, pad], dim=0)
-
-            speech_to_process = speech
-            waveform_buffer = None
-        else:
-            n_frames = (
-                speech.size(0) - (self.frontend_window_size - self.hop_length)
-            ) // self.hop_length
-
-            n_residual = (
-                speech.size(0) - (self.frontend_window_size - self.hop_length)
-            ) % self.hop_length
-
-            speech_to_process = speech.narrow(
-                0,
-                0,
-                (self.frontend_window_size - self.hop_length)
-                + n_frames * self.hop_length,
-            )
-
-            waveform_buffer = speech.narrow(
-                0,
-                speech.size(0)
-                - (self.frontend_window_size - self.hop_length)
-                - n_residual,
-                (self.frontend_window_size - self.hop_length) + n_residual,
-            ).clone()
-
-        speech_to_process = speech_to_process.unsqueeze(0).to(
-            getattr(torch, self.dtype)
-        )
-        lengths = speech_to_process.new_full(
-            [1], dtype=torch.long, fill_value=speech_to_process.size(1)
-        )
-        batch = {"speech": speech_to_process, "speech_lengths": lengths}
-        batch = to_device(batch, device=self.device)
-
-        feats, feats_lengths = self.asr_model._extract_feats(**batch)
-        if self.asr_model.normalize is not None:
-            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
-        if is_final:
-            if self.frontend_cache is None:
-                pass
-            else:
-                feats = feats.narrow(
-                    1,
-                    math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                    feats.size(1)
-                    - math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-        else:
-            if self.frontend_cache is None:
-                feats = feats.narrow(
-                    1,
-                    0,
-                    feats.size(1)
-                    - math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-            else:
-                feats = feats.narrow(
-                    1,
-                    math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                    feats.size(1)
-                    - 2
-                    * math.ceil(
-                        math.ceil(self.frontend_window_size / self.hop_length) / 2
-                    ),
-                )
-
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
-        if is_final:
-            self.frontend_cache = None
-        else:
-            self.frontend_cache = {"waveform_buffer": waveform_buffer}
-
-        return feats, feats_lengths
-
     @torch.no_grad()
     def streaming_decode(
         self,
@@ -410,14 +298,9 @@ class Speech2Text:
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)

-        # lengths: (1,)
-        # feats, feats_length = self.apply_frontend(speech)
         feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        # lengths: (1,)
         feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

-        # print(feats.shape)
-        # print(feats_lengths)
         if self.asr_model.normalize is not None:
             feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)

@@ -495,6 +378,7 @@ def inference(
     data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
     asr_train_config: Optional[str],
     asr_model_file: Optional[str],
+    cmvn_file: Optional[str],
     beam_search_config: Optional[dict],
     lm_train_config: Optional[str],
     lm_file: Optional[str],
@@ -562,7 +446,6 @@ def inference(
         device = "cuda"
     else:
         device = "cpu"
-
     # 1. Set random-seed
     set_all_random_seed(seed)

@@ -570,6 +453,7 @@ def inference(
     speech2text_kwargs = dict(
         asr_train_config=asr_train_config,
         asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
         beam_search_config=beam_search_config,
         lm_train_config=lm_train_config,
         lm_file=lm_file,
@@ -719,6 +603,11 @@ def get_parser():
         type=str,
         help="ASR model parameter file",
     )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
+    )
     group.add_argument(
         "--lm_train_config",
         type=str,