Răsfoiți Sursa

rnnt support wav input

aky15 2 ani în urmă
părinte
comite
71f1059af9
2 a modificat fișierele cu 18 adăugiri și 13 ștergeri
  1. 16 6
      funasr/bin/asr_infer.py
  2. 2 7
      funasr/tasks/asr.py

+ 16 - 6
funasr/bin/asr_infer.py

@@ -1510,8 +1510,13 @@ class Speech2TextTransducer:
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
         
-        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+        if self.frontend is not None:
+            speech = torch.unsqueeze(speech, axis=0)
+            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:                
+            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
         
         if self.asr_model.normalize is not None:
             feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -1536,14 +1541,19 @@ class Speech2TextTransducer:
         
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
-        
-        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+        if self.frontend is not None:
+            speech = torch.unsqueeze(speech, axis=0)
+            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:                
+            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
         
         feats = to_device(feats, device=self.device)
         feats_lengths = to_device(feats_lengths, device=self.device)
         
-        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
         
         nbest_hyps = self.beam_search(enc_out[0])
         

+ 2 - 7
funasr/tasks/asr.py

@@ -363,12 +363,6 @@ class ASRTask(AbsTask):
             default=get_default_kwargs(CTC),
             help="The keyword arguments for CTC class.",
         )
-        group.add_argument(
-            "--joint_network_conf",
-            action=NestedDictAction,
-            default=None,
-            help="The keyword arguments for joint network class.",
-        )
 
         group = parser.add_argument_group(description="Preprocess related")
         group.add_argument(
@@ -1379,6 +1373,7 @@ class ASRTransducerTask(ASRTask):
     num_optimizers: int = 1
 
     class_choices_list = [
+        model_choices,
         frontend_choices,
         specaug_choices,
         normalize_choices,
@@ -1476,7 +1471,7 @@ class ASRTransducerTask(ASRTask):
         try:
             model_class = model_choices.get_class(args.model)
         except AttributeError:
-            model_class = model_choices.get_class("asr")
+            model_class = model_choices.get_class("rnnt_unified")
 
         model = model_class(
             vocab_size=vocab_size,