Просмотр исходного кода

Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add

游雁 3 лет назад
Родитель
Сommit
de264be093

+ 19 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md

@@ -0,0 +1,19 @@
+# ModelScope Model
+
+## How to infer using a pretrained Paraformer-large Model
+
+### Inference
+
+You can use the pretrain model for inference directly.
+
+- Setting parameters in `infer.py`
+    - <strong>audio_in:</strong> # Support wav, url, bytes, and parsed audio format.
+    - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
+    - <strong>batch_size:</strong> # Set batch size in inference.
+    - <strong>param_dict:</strong> # Set the hotword list in inference.
+
+- Then you can run the pipeline to infer with:
+```python
+    python infer.py
+```
+

+ 21 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py

@@ -0,0 +1,21 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+
+if __name__ == '__main__':
+    param_dict = dict()
+    param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
+
+    audio_in = "//isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav"
+    output_dir = None
+    batch_size = 1
+
+    inference_pipeline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
+        output_dir=output_dir,
+        batch_size=batch_size,
+        param_dict=param_dict)
+
+    rec_result = inference_pipeline(audio_in=audio_in)
+    print(rec_result)

+ 36 - 9
funasr/bin/asr_inference_paraformer.py

@@ -6,6 +6,8 @@ import time
 import copy
 import os
 import codecs
+import tempfile
+import requests
 from pathlib import Path
 from typing import Optional
 from typing import Sequence
@@ -175,10 +177,24 @@ class Speech2Text:
         self.converter = converter
         self.tokenizer = tokenizer
 
-        # 6. [Optional] Build hotword list from file or str
+        # 6. [Optional] Build hotword list from str, local file or url
+        # for None 
         if hotword_list_or_file is None:
             self.hotword_list = None
+        # for text str input
+        elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords as str...")
+            self.hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            self.hotword_list.append([self.asr_model.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        # for local txt inputs
         elif os.path.exists(hotword_list_or_file):
+            logging.info("Attempting to parse hotwords from local txt...")
             self.hotword_list = []
             hotword_str_list = []
             with codecs.open(hotword_list_or_file, 'r') as fin:
@@ -186,20 +202,31 @@ class Speech2Text:
                     hw = line.strip()
                     hotword_str_list.append(hw)
                     self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-                self.hotword_list.append([1])
+                self.hotword_list.append([self.asr_model.sos])
                 hotword_str_list.append('<s>')
             logging.info("Initialized hotword list from file: {}, hotword list: {}."
                 .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
         else:
-            logging.info("Attempting to parse hotwords as str...")
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
             self.hotword_list = []
             hotword_str_list = []
-            for hw in hotword_list_or_file.strip().split():
-                hotword_str_list.append(hw)
-                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            self.hotword_list.append([1])
-            hotword_str_list.append('<s>')
-            logging.info("Hotword list: {}.".format(hotword_str_list))
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                self.hotword_list.append([self.asr_model.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                .format(hotword_list_or_file, hotword_str_list))
 
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None