Просмотр исходного кода

Update asr_inference_paraformer_streaming.py

hnluo 2 года назад
Родитель
Коммит
c5992ca03e
1 изменённый файл: 4 добавлено и 4 удалено
  1. 4 4
      funasr/bin/asr_inference_paraformer_streaming.py

+ 4 - 4
funasr/bin/asr_inference_paraformer_streaming.py

@@ -203,7 +203,7 @@ class Speech2Text:
         results = []
         cache_en = cache["encoder"]
         if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
-            cache_en["last_chunk"] = True
+            cache_en["tail_chunk"] = True
             feats = cache_en["feats"]
             feats_len = torch.tensor([feats.shape[1]])
         else:
@@ -232,7 +232,7 @@ class Speech2Text:
                         feats_len = torch.tensor([feats_chunk2.shape[1]])
                         results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
 
-                        return results_chunk1 + results_chunk2
+                        return ["".join(results_chunk1 + results_chunk2)]
 
                 results = self.infer(feats, feats_len, cache)
 
@@ -466,7 +466,7 @@ def inference_modelscope(
 
         cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
                     "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
         cache["encoder"] = cache_en
 
         cache_de = {"decode_fsmn": None}
@@ -478,7 +478,7 @@ def inference_modelscope(
         if len(cache) > 0:
             cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
             cache["encoder"] = cache_en
 
             cache_de = {"decode_fsmn": None}