shixian.shi 2 роки тому
батько
коміт
571dc8b55a

+ 8 - 5
examples/industrial_data_pretraining/paraformer/demo.py

@@ -5,11 +5,14 @@
 
 from funasr import AutoModel
 
-model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # punc_model_revision="v2.0.4",
+model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", 
+                  model_revision="v2.0.4",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.4",
+                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.4",
+                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model_revision="v2.0.2",
                   )
 
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")

+ 18 - 7
funasr/auto/auto_model.py

@@ -1,14 +1,13 @@
 import json
 import time
+import copy
 import torch
-import hydra
 import random
 import string
 import logging
 import os.path
 import numpy as np
 from tqdm import tqdm
-from omegaconf import DictConfig, OmegaConf, ListConfig
 
 from funasr.register import tables
 from funasr.utils.load_utils import load_bytes
@@ -17,7 +16,7 @@ from funasr.download.download_from_hub import download_model
 from funasr.utils.vad_utils import slice_padding_audio_samples
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.load_utils import load_audio_text_image_video
 from funasr.utils.timestamp_tools import timestamp_sentence
 from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
 try:
@@ -385,11 +384,15 @@ class AutoModel:
             if self.punc_model is not None:
                 self.punc_kwargs.update(cfg)
                 punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
-                import copy; raw_text = copy.copy(result["text"])
+                raw_text = copy.copy(result["text"])
                 result["text"] = punc_res[0]["text"]
+            else:
+                raw_text = None
                 
             # speaker embedding cluster after resorted
             if self.spk_model is not None and kwargs.get('return_spk_res', True):
+                if raw_text is None:
+                    logging.error("Missing punc_model, which is required by spk_model.")
                 all_segments = sorted(all_segments, key=lambda x: x[0])
                 spk_embedding = result['spk_embedding']
                 labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None))
@@ -398,20 +401,28 @@ class AutoModel:
                 if self.spk_mode == 'vad_segment':  # recover sentence_list
                     sentence_list = []
                     for res, vadsegment in zip(restored_data, vadsegments):
+                        if 'timestamp' not in res:
+                            logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
+                                and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                                can predict timestamp, and speaker diarization relies on timestamps.")
                         sentence_list.append({"start": vadsegment[0],\
                                                 "end": vadsegment[1],
-                                                "sentence": res['raw_text'],
+                                                "sentence": res['text'],
                                                 "timestamp": res['timestamp']})
                 elif self.spk_mode == 'punc_segment':
+                    if 'timestamp' not in result:
+                        logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
+                            and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                            can predict timestamp, and speaker diarization relies on timestamps.")
                     sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                         result['timestamp'], \
-                                                        result['raw_text'])
+                                                        raw_text)
                 distribute_spk(sentence_list, sv_output)
                 result['sentence_info'] = sentence_list
             elif kwargs.get("sentence_timestamp", False):
                 sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                         result['timestamp'], \
-                                                        result['raw_text'])
+                                                        raw_text)
                 result['sentence_info'] = sentence_list
             if "spk_embedding" in result: del result['spk_embedding']
                     

+ 1 - 2
funasr/models/seaco_paraformer/model.py

@@ -415,12 +415,11 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
                         token, timestamp)
 
                     result_i = {"key": key[i], "text": text_postprocessed,
-                                "timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed)
+                                "timestamp": time_stamp_postprocessed
                                 }
                     
                     if ibest_writer is not None:
                         ibest_writer["token"][key[i]] = " ".join(token)
-                        # ibest_writer["raw_text"][key[i]] = text
                         ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                         ibest_writer["text"][key[i]] = text_postprocessed
                 else: