游雁 před 2 roky
rodič
revize
f0fdc051fb

+ 2 - 2
funasr/export/models/CT_Transformer.py

@@ -10,7 +10,7 @@ from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadE
 
 class CT_Transformer(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
@@ -81,7 +81,7 @@ class CT_Transformer(nn.Module):
 
 class CT_Transformer_VadRealtime(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """

+ 2 - 2
funasr/export/models/e2e_asr_paraformer.py

@@ -19,7 +19,7 @@ from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSA
 
 class Paraformer(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2206.08317
     """
@@ -112,7 +112,7 @@ class Paraformer(nn.Module):
 
 class BiCifParaformer(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2206.08317
     """

+ 1 - 1
funasr/models/decoder/contextual_decoder.py

@@ -102,7 +102,7 @@ class ContextualBiasDecoder(nn.Module):
 
 class ContextualParaformerDecoder(ParaformerSANMDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """

+ 2 - 2
funasr/models/decoder/sanm_decoder.py

@@ -151,7 +151,7 @@ class DecoderLayerSANM(nn.Module):
 
 class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -812,7 +812,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
 
 class ParaformerSANMDecoder(BaseTransformerDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """

+ 1 - 1
funasr/models/decoder/transformer_decoder.py

@@ -405,7 +405,7 @@ class TransformerDecoder(BaseTransformerDecoder):
 
 class ParaformerDecoderSAN(BaseTransformerDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """

+ 2 - 2
funasr/models/e2e_asr_paraformer.py

@@ -44,7 +44,7 @@ else:
 
 class Paraformer(AbsESPnetModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2206.08317
     """
@@ -612,7 +612,7 @@ class Paraformer(AbsESPnetModel):
 
 class ParaformerBert(Paraformer):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
     """
 

+ 1 - 1
funasr/models/e2e_tp.py

@@ -32,7 +32,7 @@ else:
 
 class TimestampPredictor(AbsESPnetModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     """
 
     def __init__(

+ 1 - 1
funasr/models/e2e_uni_asr.py

@@ -40,7 +40,7 @@ else:
 
 class UniASR(AbsESPnetModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     """
 
     def __init__(

+ 25 - 0
funasr/models/e2e_vad.py

@@ -35,6 +35,11 @@ class VadDetectMode(Enum):
 
 
 class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(
             self,
             sample_rate: int = 16000,
@@ -99,6 +104,11 @@ class VADXOptions:
 
 
 class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self):
         self.start_ms = 0
         self.end_ms = 0
@@ -117,6 +127,11 @@ class E2EVadSpeechBufWithDoa(object):
 
 
 class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self):
         self.noise_prob = 0.0
         self.speech_prob = 0.0
@@ -126,6 +141,11 @@ class E2EVadFrameProb(object):
 
 
 class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                  speech_to_sil_time: int, frame_size_ms: int):
         self.window_size_ms = window_size_ms
@@ -192,6 +212,11 @@ class WindowDetector(object):
 
 
 class E2EVadModel(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)

+ 1 - 1
funasr/models/encoder/opennmt_encoders/conv_encoder.py

@@ -67,7 +67,7 @@ class EncoderLayer(nn.Module):
 
 class ConvEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Convolution encoder in OpenNMT framework
     """
 

+ 1 - 1
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py

@@ -117,7 +117,7 @@ class EncoderLayer(nn.Module):
 
 class SelfAttentionEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Self attention encoder in OpenNMT framework
     """
 

+ 3 - 3
funasr/models/encoder/sanm_encoder.py

@@ -117,7 +117,7 @@ class EncoderLayerSANM(nn.Module):
 
 class SANMEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -549,7 +549,7 @@ class SANMEncoder(AbsEncoder):
 
 class SANMEncoderChunkOpt(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -962,7 +962,7 @@ class SANMEncoderChunkOpt(AbsEncoder):
 
 class SANMVadEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
 
     """
 

+ 1 - 1
funasr/models/target_delay_transformer.py

@@ -14,7 +14,7 @@ from funasr.train.abs_model import AbsPunctuation
 
 class TargetDelayTransformer(AbsPunctuation):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """

+ 1 - 1
funasr/models/vad_realtime_transformer.py

@@ -12,7 +12,7 @@ from funasr.train.abs_model import AbsPunctuation
 
 class VadRealtimeTransformer(AbsPunctuation):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """

+ 1 - 1
funasr/modules/streaming_utils/chunk_utilis.py

@@ -11,7 +11,7 @@ from funasr.modules.streaming_utils.utils import sequence_mask
 
 class overlap_chunk():
 	"""
-	author: Speech Lab, Alibaba Group, China
+	Author: Speech Lab of DAMO Academy, Alibaba Group
 	San-m: Memory equipped self-attention for end-to-end speech recognition
 	https://arxiv.org/abs/2006.01713
 

+ 1 - 1
funasr/runtime/python/onnxruntime/demo_vad_offline.py

@@ -1,5 +1,5 @@
 import soundfile
-from funasr_onnx.vad_bin import Fsmn_vad
+from funasr_onnx import Fsmn_vad
 
 
 model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"

+ 2 - 2
funasr/runtime/python/onnxruntime/demo_vad_online.py

@@ -1,10 +1,10 @@
 import soundfile
-from funasr_onnx.vad_online_bin import Fsmn_vad
+from funasr_onnx import Fsmn_vad_online
 
 
 model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
-model = Fsmn_vad(model_dir)
+model = Fsmn_vad_online(model_dir)
 
 
 ##online vad

+ 1 - 0
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py

@@ -1,5 +1,6 @@
 # -*- encoding: utf-8 -*-
 from .paraformer_bin import Paraformer
 from .vad_bin import Fsmn_vad
+from .vad_bin import Fsmn_vad_online
 from .punc_bin import CT_Transformer
 from .punc_bin import CT_Transformer_VadRealtime

+ 2 - 2
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py

@@ -14,7 +14,7 @@ logging = get_logger()
 
 class CT_Transformer():
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
@@ -125,7 +125,7 @@ class CT_Transformer():
 
 class CT_Transformer_VadRealtime(CT_Transformer):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """

+ 127 - 1
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py

@@ -11,13 +11,18 @@ import numpy as np
 from .utils.utils import (ONNXRuntimeError,
                           OrtInferSession, get_logger,
                           read_yaml)
-from .utils.frontend import WavFrontend
+from .utils.frontend import WavFrontend, WavFrontendOnline
 from .utils.e2e_vad import E2EVadModel
 
 logging = get_logger()
 
 
 class Fsmn_vad():
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
 	def __init__(self, model_dir: Union[str, Path] = None,
 	             batch_size: int = 1,
 	             device_id: Union[str, int] = "-1",
@@ -151,4 +156,125 @@ class Fsmn_vad():
 		outputs = self.ort_infer(feats)
 		scores, out_caches = outputs[0], outputs[1:]
 		return scores, out_caches
+
+
+class Fsmn_vad_online():
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self, model_dir: Union[str, Path] = None,
+	             batch_size: int = 1,
+	             device_id: Union[str, int] = "-1",
+	             quantize: bool = False,
+	             intra_op_num_threads: int = 4,
+	             max_end_sil: int = None,
+	             ):
+		
+		if not Path(model_dir).exists():
+			raise FileNotFoundError(f'{model_dir} does not exist.')
+		
+		model_file = os.path.join(model_dir, 'model.onnx')
+		if quantize:
+			model_file = os.path.join(model_dir, 'model_quant.onnx')
+		config_file = os.path.join(model_dir, 'vad.yaml')
+		cmvn_file = os.path.join(model_dir, 'vad.mvn')
+		config = read_yaml(config_file)
+		
+		self.frontend = WavFrontendOnline(
+			cmvn_file=cmvn_file,
+			**config['frontend_conf']
+		)
+		self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+		self.batch_size = batch_size
+		self.vad_scorer = E2EVadModel(config["vad_post_conf"])
+		self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+		self.encoder_conf = config["encoder_conf"]
+	
+	def prepare_cache(self, in_cache: list = []):
+		if len(in_cache) > 0:
+			return in_cache
+		fsmn_layers = self.encoder_conf["fsmn_layers"]
+		proj_dim = self.encoder_conf["proj_dim"]
+		lorder = self.encoder_conf["lorder"]
+		for i in range(fsmn_layers):
+			cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
+			in_cache.append(cache)
+		return in_cache
+	
+	def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
+		waveforms = np.expand_dims(audio_in, axis=0)
+		
+		param_dict = kwargs.get('param_dict', dict())
+		is_final = param_dict.get('is_final', False)
+		feats, feats_len = self.extract_feat(waveforms, is_final)
+		segments = []
+		if feats.size != 0:
+			in_cache = param_dict.get('in_cache', list())
+			in_cache = self.prepare_cache(in_cache)
+			try:
+				inputs = [feats]
+				inputs.extend(in_cache)
+				scores, out_caches = self.infer(inputs)
+				param_dict['in_cache'] = out_caches
+				waveforms = self.frontend.get_waveforms()
+				segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil,
+				                           online=True)
+			
+			
+			except ONNXRuntimeError:
+				# logging.warning(traceback.format_exc())
+				logging.warning("input wav is silence or noise")
+				segments = []
+		return segments
+	
+	def load_data(self,
+	              wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+		def load_wav(path: str) -> np.ndarray:
+			waveform, _ = librosa.load(path, sr=fs)
+			return waveform
+		
+		if isinstance(wav_content, np.ndarray):
+			return [wav_content]
+		
+		if isinstance(wav_content, str):
+			return [load_wav(wav_content)]
+		
+		if isinstance(wav_content, list):
+			return [load_wav(path) for path in wav_content]
+		
+		raise TypeError(
+			f'The type of {wav_content} is not in [str, np.ndarray, list]')
+	
+	def extract_feat(self,
+	                 waveforms: np.ndarray, is_final: bool = False
+	                 ) -> Tuple[np.ndarray, np.ndarray]:
+		waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
+		for idx, waveform in enumerate(waveforms):
+			waveforms_lens[idx] = waveform.shape[-1]
+		
+		feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
+		# feats.append(feat)
+		# feats_len.append(feat_len)
+		
+		# feats = self.pad_feats(feats, np.max(feats_len))
+		# feats_len = np.array(feats_len).astype(np.int32)
+		return feats.astype(np.float32), feats_len.astype(np.int32)
+	
+	@staticmethod
+	def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+		def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+			pad_width = ((0, max_feat_len - cur_len), (0, 0))
+			return np.pad(feat, pad_width, 'constant', constant_values=0)
+		
+		feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+		feats = np.array(feat_res).astype(np.float32)
+		return feats
 	
+	def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
+		
+		outputs = self.ort_infer(feats)
+		scores, out_caches = outputs[0], outputs[1:]
+		return scores, out_caches
+

+ 0 - 134
funasr/runtime/python/onnxruntime/funasr_onnx/vad_online_bin.py

@@ -1,134 +0,0 @@
-# -*- encoding: utf-8 -*-
-
-import os.path
-from pathlib import Path
-from typing import List, Union, Tuple
-
-import copy
-import librosa
-import numpy as np
-
-from .utils.utils import (ONNXRuntimeError,
-                          OrtInferSession, get_logger,
-                          read_yaml)
-from .utils.frontend import WavFrontendOnline
-from .utils.e2e_vad import E2EVadModel
-
-logging = get_logger()
-
-
-class Fsmn_vad():
-	def __init__(self, model_dir: Union[str, Path] = None,
-	             batch_size: int = 1,
-	             device_id: Union[str, int] = "-1",
-	             quantize: bool = False,
-	             intra_op_num_threads: int = 4,
-	             max_end_sil: int = None,
-	             ):
-		
-		if not Path(model_dir).exists():
-			raise FileNotFoundError(f'{model_dir} does not exist.')
-		
-		model_file = os.path.join(model_dir, 'model.onnx')
-		if quantize:
-			model_file = os.path.join(model_dir, 'model_quant.onnx')
-		config_file = os.path.join(model_dir, 'vad.yaml')
-		cmvn_file = os.path.join(model_dir, 'vad.mvn')
-		config = read_yaml(config_file)
-		
-		self.frontend = WavFrontendOnline(
-			cmvn_file=cmvn_file,
-			**config['frontend_conf']
-		)
-		self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
-		self.batch_size = batch_size
-		self.vad_scorer = E2EVadModel(config["vad_post_conf"])
-		self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
-		self.encoder_conf = config["encoder_conf"]
-	
-	def prepare_cache(self, in_cache: list = []):
-		if len(in_cache) > 0:
-			return in_cache
-		fsmn_layers = self.encoder_conf["fsmn_layers"]
-		proj_dim = self.encoder_conf["proj_dim"]
-		lorder = self.encoder_conf["lorder"]
-		for i in range(fsmn_layers):
-			cache = np.zeros((1, proj_dim, lorder-1, 1)).astype(np.float32)
-			in_cache.append(cache)
-		return in_cache
-		
-	
-	def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
-		waveforms = np.expand_dims(audio_in, axis=0)
-		
-		param_dict = kwargs.get('param_dict', dict())
-		is_final = param_dict.get('is_final', False)
-		feats, feats_len = self.extract_feat(waveforms, is_final)
-		segments = []
-		if feats.size != 0:
-			in_cache = param_dict.get('in_cache', list())
-			in_cache = self.prepare_cache(in_cache)
-			try:
-				inputs = [feats]
-				inputs.extend(in_cache)
-				scores, out_caches = self.infer(inputs)
-				param_dict['in_cache'] = out_caches
-				waveforms = self.frontend.get_waveforms()
-				segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil, online=True)
-
-
-			except ONNXRuntimeError:
-				logging.warning(traceback.format_exc())
-				logging.warning("input wav is silence or noise")
-				segments = []
-		return segments
-
-	def load_data(self,
-	              wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
-		def load_wav(path: str) -> np.ndarray:
-			waveform, _ = librosa.load(path, sr=fs)
-			return waveform
-		
-		if isinstance(wav_content, np.ndarray):
-			return [wav_content]
-		
-		if isinstance(wav_content, str):
-			return [load_wav(wav_content)]
-		
-		if isinstance(wav_content, list):
-			return [load_wav(path) for path in wav_content]
-		
-		raise TypeError(
-			f'The type of {wav_content} is not in [str, np.ndarray, list]')
-	
-	def extract_feat(self,
-	                 waveforms: np.ndarray, is_final: bool = False
-	                 ) -> Tuple[np.ndarray, np.ndarray]:
-		waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
-		for idx, waveform in enumerate(waveforms):
-			waveforms_lens[idx] = waveform.shape[-1]
-
-		feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
-		# feats.append(feat)
-		# feats_len.append(feat_len)
-
-		# feats = self.pad_feats(feats, np.max(feats_len))
-		# feats_len = np.array(feats_len).astype(np.int32)
-		return feats.astype(np.float32), feats_len.astype(np.int32)
-
-	@staticmethod
-	def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
-		def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
-			pad_width = ((0, max_feat_len - cur_len), (0, 0))
-			return np.pad(feat, pad_width, 'constant', constant_values=0)
-		
-		feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
-		feats = np.array(feat_res).astype(np.float32)
-		return feats
-	
-	def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
-		
-		outputs = self.ort_infer(feats)
-		scores, out_caches = outputs[0], outputs[1:]
-		return scores, out_caches
-	

+ 1 - 1
funasr/runtime/python/onnxruntime/setup.py

@@ -13,7 +13,7 @@ def get_readme():
 
 
 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.0.3'
+VERSION_NUM = '0.0.4'
 
 setuptools.setup(
     name=MODULE_NAME,