游雁 2 éve
szülő
commit
865ae89f0a

+ 123 - 0
fbank.py

@@ -0,0 +1,123 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+
+from typing import Tuple
+
+import numpy as np
+import torch
+import torchaudio.compliance.kaldi as kaldi
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from typeguard import check_argument_types
+from torch.nn.utils.rnn import pad_sequence
+import kaldi_native_fbank as knf
+
+class WavFrontend(AbsFrontend):
+	"""Conventional frontend structure for ASR.
+
+	Computes filterbank features for every utterance of a padded batch with
+	``torchaudio.compliance.kaldi.fbank`` and pads the per-utterance results
+	to a common number of frames.
+
+	``cmvn_file``, ``filter_length_min``, ``filter_length_max``, ``lfr_m`` and
+	``lfr_n`` are stored on the instance but are not applied by ``forward``;
+	``lfr_m`` is only used by ``output_size``.
+	"""
+
+	def __init__(
+		self,
+		cmvn_file: str = None,
+		fs: int = 16000,
+		window: str = 'hamming',
+		n_mels: int = 80,
+		frame_length: int = 25,
+		frame_shift: int = 10,
+		filter_length_min: int = -1,
+		filter_length_max: int = -1,
+		lfr_m: int = 1,
+		lfr_n: int = 1,
+		dither: float = 1.0,
+		snip_edges: bool = True,
+		upsacle_samples: bool = True,
+	):
+		"""Store the feature-extraction settings.
+
+		Args:
+			cmvn_file: Stored only; not read by this class.
+			fs: Passed to fbank as ``sample_frequency``.
+			window: Passed to fbank as ``window_type``.
+			n_mels: Passed to fbank as ``num_mel_bins``.
+			frame_length: Passed to fbank as ``frame_length``
+				(milliseconds according to the torchaudio documentation).
+			frame_shift: Passed to fbank as ``frame_shift``
+				(milliseconds according to the torchaudio documentation).
+			filter_length_min: Stored only.
+			filter_length_max: Stored only.
+			lfr_m: Multiplier used by ``output_size``.
+			lfr_n: Stored only.
+			dither: Passed to fbank as ``dither``.
+			snip_edges: Passed to fbank as ``snip_edges``.
+			upsacle_samples: When true, samples are multiplied by
+				``1 << 15`` before feature extraction. The misspelled
+				name is kept because it is part of the public signature.
+		"""
+		assert check_argument_types()
+		super().__init__()
+		self.fs = fs
+		self.window = window
+		self.n_mels = n_mels
+		self.frame_length = frame_length
+		self.frame_shift = frame_shift
+		self.filter_length_min = filter_length_min
+		self.filter_length_max = filter_length_max
+		self.lfr_m = lfr_m
+		self.lfr_n = lfr_n
+		self.cmvn_file = cmvn_file
+		self.dither = dither
+		self.snip_edges = snip_edges
+		self.upsacle_samples = upsacle_samples
+
+	def output_size(self) -> int:
+		"""Return ``n_mels * lfr_m``."""
+		return self.n_mels * self.lfr_m
+
+	def forward(
+		self,
+		input: torch.Tensor,
+		input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+		"""Extract fbank features for each utterance in the batch.
+
+		Args:
+			input: Batch of waveforms; utterance ``i`` is read as
+				``input[i][:input_lengths[i]]``
+				(assumes a (batch, samples) layout - TODO confirm).
+			input_lengths: Number of valid samples of each utterance.
+
+		Returns:
+			A tuple ``(feats_pad, feats_lens)``: the per-utterance feature
+			matrices zero-padded along the frame axis and stacked
+			batch-first, and a tensor with the unpadded frame counts.
+		"""
+		feats = []
+		feats_lens = []
+		for i in range(input.size(0)):
+			# Drop the padding samples of this utterance.
+			waveform = input[i][:input_lengths[i]]
+			# Previously the scaling was unconditional and the constructor
+			# flag was ignored; the default (True) keeps the old behaviour.
+			if self.upsacle_samples:
+				waveform = waveform * (1 << 15)
+			waveform = waveform.unsqueeze(0)
+			mat = kaldi.fbank(
+				waveform,
+				num_mel_bins=self.n_mels,
+				frame_length=self.frame_length,
+				frame_shift=self.frame_shift,
+				dither=self.dither,
+				energy_floor=0.0,
+				window_type=self.window,
+				sample_frequency=self.fs,
+				# Previously not forwarded, so snip_edges=False had no effect.
+				snip_edges=self.snip_edges)
+
+			feats.append(mat)
+			feats_lens.append(mat.size(0))
+
+		feats_lens = torch.as_tensor(feats_lens)
+		feats_pad = pad_sequence(feats,
+		                         batch_first=True,
+		                         padding_value=0.0)
+		return feats_pad, feats_lens
+
+import kaldi_native_fbank as knf
+
+def fbank_knf(waveform):
+	"""Compute filterbank features with ``kaldi_native_fbank``.
+
+	The extractor is configured with fixed options: 16000 Hz sampling
+	rate, 25 ms Hamming window, 10 ms shift, 80 mel bins, no dither and
+	``snip_edges`` enabled. The samples are multiplied by ``1 << 15``
+	before they are handed to the extractor.
+
+	Args:
+		waveform: Sample array supporting multiplication by an int and
+			``.tolist()``.
+
+	Returns:
+		A numpy array with one row per ready frame and one column per
+		mel bin.
+	"""
+	opts = knf.FbankOptions()
+	opts.energy_floor = 1
+	opts.mel_opts.num_bins = 80
+	opts.mel_opts.debug_mel = False
+	opts.frame_opts.samp_freq = 16000
+	opts.frame_opts.frame_length_ms = 25.0
+	opts.frame_opts.frame_shift_ms = 10.0
+	opts.frame_opts.window_type = "hamming"
+	opts.frame_opts.dither = 0.0
+	opts.frame_opts.snip_edges = True
+
+	extractor = knf.OnlineFbank(opts)
+	scaled = waveform * (1 << 15)
+	extractor.accept_waveform(opts.frame_opts.samp_freq, scaled.tolist())
+
+	num_frames = extractor.num_frames_ready
+	features = np.empty([num_frames, opts.mel_opts.num_bins])
+	# Copy the frames out one at a time into the preallocated matrix.
+	for row in range(num_frames):
+		features[row, :] = extractor.get_frame(row)
+	return features
+
+if __name__ == '__main__':
+	# Ad-hoc consistency check: compute features for one example file with
+	# both implementations and compare them element-wise.
+	import librosa
+
+	path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
+	# sr=None keeps the file's native rate; both extractors are configured
+	# for 16000 Hz (NOTE(review): presumably the example is 16 kHz - confirm).
+	waveform, fs = librosa.load(path, sr=None)
+	fbank = fbank_knf(waveform)
+	frontend = WavFrontend(dither=0.0)
+	waveform_tensor = torch.from_numpy(waveform)[None, :]
+	fbank_torch, _ = frontend.forward(waveform_tensor, [waveform_tensor.size(1)])
+	fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
+	diff = fbank - fbank_torch
+	diff_max = diff.max()
+	# numpy arrays have no .abs() method (that is the torch.Tensor API), so
+	# the former diff.abs() raised AttributeError; use np.abs instead.
+	diff_sum = np.abs(diff).sum()
+	pass

+ 1 - 4
funasr/models/frontend/wav_frontend.py

@@ -171,10 +171,7 @@ class WavFrontend(AbsFrontend):
                               window_type=self.window,
                               sample_frequency=self.fs)
 
-            # if self.lfr_m != 1 or self.lfr_n != 1:
-            #     mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-            # if self.cmvn_file is not None:
-            #     mat = apply_cmvn(mat, self.cmvn_file)
+
             feat_length = mat.size(0)
             feats.append(mat)
             feats_lens.append(feat_length)

+ 0 - 0
funasr/runtime/__init__.py


+ 0 - 0
funasr/runtime/python/__init__.py


+ 0 - 0
funasr/runtime/python/onnxruntime/__init__.py


+ 0 - 0
funasr/runtime/python/onnxruntime/.gitignore → funasr/runtime/python/onnxruntime/paraformer/.gitignore


+ 4 - 10
funasr/runtime/python/onnxruntime/README.md → funasr/runtime/python/onnxruntime/paraformer/README.md

@@ -29,12 +29,6 @@
         │   └── utils.py
         ├── README.md
         ├── requirements.txt
-        ├── resources
-        │   ├── config.yaml
-        │   └── models
-        │       ├── am.mvn
-        │       ├── model.onnx  # Put it here.
-        │       └── token_list.pkl
         ├── test_onnx.py
         ├── tests
         │   ├── __pycache__
@@ -48,15 +42,15 @@
    - Output: `List[str]`: recognition result.
    - Example:
         ```python
-        from rapid_paraformer import RapidParaformer
+        from paraformer_onnx import Paraformer
 
 
         config_path = 'resources/config.yaml'
-        paraformer = RapidParaformer(config_path)
+        model = Paraformer(config_path)
 
-        wav_path = ['test_wavs/0478_00017.wav']
+        wav_path = ['example/asr_example.wav']
 
-        result = paraformer(wav_path)
+        result = model(wav_path)
         print(result)
         ```
 

+ 0 - 0
funasr/runtime/python/onnxruntime/paraformer/__init__.py


+ 0 - 0
funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py → funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py


+ 0 - 0
funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/LICENSE → funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/LICENSE


+ 0 - 0
funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/__init__.py → funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/__init__.py


+ 0 - 0
funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/feature.py → funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/feature.py


+ 0 - 0
funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/ivector.py → funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/ivector.py


+ 30 - 18
funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py → funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py

@@ -1,6 +1,7 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
+import os.path
 import traceback
 from pathlib import Path
 from typing import List, Union, Tuple
@@ -11,25 +12,33 @@ import numpy as np
 from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                     OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
                     read_yaml)
+from .postprocess_utils import sentence_postprocess
 
 logging = get_logger()
 
 
-class RapidParaformer():
-    def __init__(self, config_path: Union[str, Path]) -> None:
-        if not Path(config_path).exists():
-            raise FileNotFoundError(f'{config_path} does not exist.')
+class Paraformer():
+    def __init__(self, model_dir: Union[str, Path]=None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int]="-1",
+                 ):
+        
+        if not Path(model_dir).exists():
+            raise FileNotFoundError(f'{model_dir} does not exist.')
 
-        config = read_yaml(config_path)
+        model_file = os.path.join(model_dir, 'model.onnx')
+        config_file = os.path.join(model_dir, 'config.yaml')
+        cmvn_file = os.path.join(model_dir, 'am.mvn')
+        config = read_yaml(config_file)
 
-        self.converter = TokenIDConverter(**config['TokenIDConverter'])
-        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
+        self.converter = TokenIDConverter(config['token_list'])
+        self.tokenizer = CharTokenizer()
         self.frontend = WavFrontend(
-            cmvn_file=config['WavFrontend']['cmvn_file'],
-            **config['WavFrontend']['frontend_conf']
+            cmvn_file=cmvn_file,
+            **config['frontend_conf']
         )
-        self.ort_infer = OrtInferSession(config['Model'])
-        self.batch_size = config['Model']['batch_size']
+        self.ort_infer = OrtInferSession(model_file, device_id)
+        self.batch_size = batch_size
 
     def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
         waveform_list = self.load_data(wav_content)
@@ -124,16 +133,19 @@ class RapidParaformer():
 
         # Change integer-ids to tokens
         token = self.converter.ids2tokens(token_int)
-        text = self.tokenizer.tokens2text(token)
+        token = token[:valid_token_num-1]
+        texts = sentence_postprocess(token)
+        text = texts[0]
+        # text = self.tokenizer.tokens2text(token)
         return text[:valid_token_num-1]
 
 
 if __name__ == '__main__':
     project_dir = Path(__file__).resolve().parent.parent
-    cfg_path = project_dir / 'resources' / 'config.yaml'
-    paraformer = RapidParaformer(cfg_path)
+    model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+    model = Paraformer(model_dir)
+
+    wav_file = os.path.join(model_dir, 'example/asr_example.wav')
+    result = model(wav_file)
+    print(result)
 
-    wav_file = '0478_00017.wav'
-    for i in range(1000):
-        result = paraformer(wav_file)
-        print(result)

+ 240 - 0
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py

@@ -0,0 +1,240 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import string
+import logging
+from typing import Any, List, Union
+
+
+def isChinese(ch: str):
+    """Return True when *ch* compares inside U+4E00..U+9FFF or '0'..'9'.
+
+    The check is a plain string comparison, so despite the name ASCII
+    digits are accepted as well, and a multi-character string is compared
+    lexicographically.
+    """
+    in_cjk_range = '\u4e00' <= ch <= '\u9fff'
+    in_digit_range = '\u0030' <= ch <= '\u0039'
+    return in_cjk_range or in_digit_range
+
+
+def isAllChinese(word: Union[List[Any], str]):
+    """Return True when every item of *word* passes ``isChinese``.
+
+    Spaces and the markers ``</s>`` / ``<s>`` are removed from each item
+    before it is tested. An empty *word* yields False, and so does any
+    item that is empty after the removal.
+    """
+    cleaned = [
+        item.replace(' ', '').replace('</s>', '').replace('<s>', '')
+        for item in word
+    ]
+
+    if not cleaned:
+        return False
+
+    return all(isChinese(item) for item in cleaned)
+
+
+def isAllAlpha(word: Union[List[Any], str]):
+    """Return True when every item of *word* is alphabetic but not Chinese.
+
+    Spaces and the markers ``</s>`` / ``<s>`` are removed from each item
+    first. An item is accepted when it is a lone apostrophe, or when
+    ``str.isalpha()`` holds and ``isChinese`` rejects it. An empty *word*
+    yields False.
+    """
+    cleaned = [
+        item.replace(' ', '').replace('</s>', '').replace('<s>', '')
+        for item in word
+    ]
+
+    if not cleaned:
+        return False
+
+    for item in cleaned:
+        if item.isalpha():
+            # Chinese characters are alphabetic too, so rule them out here.
+            if isChinese(item):
+                return False
+        elif item != "'":
+            return False
+
+    return True
+
+
+# def abbr_dispose(words: List[Any]) -> List[Any]:
+def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
+    """Upper-case spelled-out abbreviations and drop the spaces inside them.
+
+    An abbreviation is a run of at least two single ASCII-letter tokens in
+    which consecutive letters are separated by exactly one ``' '`` token,
+    e.g. ``['a', ' ', 'b', ' ', 'c']`` becomes ``['A', 'B', 'C']``. All
+    other tokens are copied unchanged.
+
+    Args:
+        words: Token list in which spaces appear as separate ``' '`` tokens.
+        time_stamp: Optional list of ``[begin, end]`` pairs, indexed by the
+            position of each token among the non-space tokens of *words*.
+
+    Returns:
+        The processed token list when *time_stamp* is None; otherwise the
+        tuple ``(word_lists, ts_lists)``, where every abbreviation
+        contributes one merged ``[begin, end]`` pair and every other
+        non-space token contributes its own pair.
+    """
+    words_size = len(words)
+    word_lists = []
+    abbr_begin = []
+    abbr_end = []
+    last_num = -1
+    ts_lists = []
+    ts_nums = []
+    ts_index = 0
+    # Pass 1: record the first and last index of every abbreviation run.
+    for num in range(words_size):
+        # Skip indices already consumed by an extended run.
+        if num <= last_num:
+            continue
+
+        # bytes.isalpha() is ASCII-only, so this matches a single ASCII letter.
+        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
+            if num + 1 < words_size and words[
+                    num + 1] == ' ' and num + 2 < words_size and len(
+                        words[num +
+                              2]) == 1 and words[num +
+                                                 2].encode('utf-8').isalpha():
+                # found the begin of abbr
+                abbr_begin.append(num)
+                # Reassigning num only affects this iteration; the for loop
+                # restarts from its own counter, hence the last_num guard.
+                num += 2
+                abbr_end.append(num)
+                # to find the end of abbr
+                while True:
+                    num += 1
+                    if num < words_size and words[num] == ' ':
+                        num += 1
+                        if num < words_size and len(
+                                words[num]) == 1 and words[num].encode(
+                                    'utf-8').isalpha():
+                            # Another "<space><letter>" pair: move the end.
+                            abbr_end.pop()
+                            abbr_end.append(num)
+                            last_num = num
+                        else:
+                            break
+                    else:
+                        break
+
+    # Pass 2: map every token index to a timestamp index. A space gets the
+    # index of the next non-space token and does not advance the counter.
+    for num in range(words_size):
+        if words[num] == ' ':
+            ts_nums.append(ts_index)
+        else:
+            ts_nums.append(ts_index)
+            ts_index += 1 
+    last_num = -1
+    # Pass 3: build the output tokens (and timestamps, if requested).
+    for num in range(words_size):
+        if num <= last_num:
+            continue
+
+        if num in abbr_begin:
+            if time_stamp is not None:
+                begin = time_stamp[ts_nums[num]][0]
+            word_lists.append(words[num].upper())
+            num += 1
+            # Walk to the recorded end, keeping letters and dropping spaces.
+            while num < words_size:
+                if num in abbr_end:
+                    word_lists.append(words[num].upper())
+                    last_num = num
+                    break
+                else:
+                    if words[num].encode('utf-8').isalpha():
+                        word_lists.append(words[num].upper())
+                num += 1
+            if time_stamp is not None:
+                # One span from the first letter's begin to the last letter's end.
+                end = time_stamp[ts_nums[num]][1]
+                ts_lists.append([begin, end])
+        else:
+            word_lists.append(words[num])
+            if time_stamp is not None and words[num] != ' ':
+                begin = time_stamp[ts_nums[num]][0]
+                end = time_stamp[ts_nums[num]][1]
+                ts_lists.append([begin, end])
+                begin = end
+
+    if time_stamp is not None:
+        return word_lists, ts_lists
+    else:
+        return word_lists
+
+
+def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
+    """Convert a decoded token sequence into a sentence.
+
+    The tokens ``<s>``, ``</s>`` and ``<unk>`` are dropped and non-``str``
+    tokens are decoded as UTF-8. The remaining tokens are then handled in
+    one of three ways:
+
+    * all tokens pass ``isAllChinese``: each token is copied with its
+      spaces removed;
+    * all tokens pass ``isAllAlpha``: tokens containing ``@@`` are word
+      pieces that are glued to the following token, and every completed
+      word is followed by a ``' '`` token;
+    * otherwise each token is classified on its own, and a token that is
+      neither Chinese, a ``@@`` piece nor alphabetic raises ``ValueError``.
+
+    The result is finally passed through ``abbr_dispose``.
+
+    Args:
+        words: Decoded tokens (``str``, or objects with ``decode``).
+        time_stamp: Optional list of ``[begin, end]`` pairs indexed like the
+            tokens that remain after the special tokens are dropped.
+
+    Returns:
+        ``(sentence, real_word_lists)`` when *time_stamp* is None, where
+        ``sentence`` is the stripped concatenation of all output tokens;
+        otherwise ``(sentence, ts_lists, real_word_lists)``, where
+        ``sentence`` joins the non-space tokens with single spaces.
+
+    Raises:
+        ValueError: In the mixed case, for a token that cannot be classified.
+    """
+    middle_lists = []
+    word_lists = []
+    word_item = ''
+    ts_lists = []
+
+    # wash words lists
+    for i in words:
+        word = ''
+        if isinstance(i, str):
+            word = i
+        else:
+            word = i.decode('utf-8')
+
+        if word in ['<s>', '</s>', '<unk>']:
+            continue
+        else:
+            middle_lists.append(word)
+
+    # all chinese characters
+    if isAllChinese(middle_lists):
+        for i, ch in enumerate(middle_lists):
+            word_lists.append(ch.replace(' ', ''))
+        if time_stamp is not None:
+            ts_lists = time_stamp
+
+    # all alpha characters
+    elif isAllAlpha(middle_lists):
+        # ts_flag is True while the next token starts a new word, i.e. while
+        # `begin` should be taken from that token's timestamp.
+        ts_flag = True
+        for i, ch in enumerate(middle_lists):
+            if ts_flag and time_stamp is not None:
+                begin = time_stamp[i][0]
+                end = time_stamp[i][1]
+            word = ''
+            if '@@' in ch:
+                # Word piece: accumulate it and keep the word open.
+                word = ch.replace('@@', '')
+                word_item += word
+                if time_stamp is not None:
+                    ts_flag = False
+                    end = time_stamp[i][1]
+            else:
+                # Final piece: emit the word followed by a space token.
+                word_item += ch
+                word_lists.append(word_item)
+                word_lists.append(' ')
+                word_item = ''
+                if time_stamp is not None:
+                    ts_flag = True
+                    end = time_stamp[i][1]
+                    ts_lists.append([begin, end])
+                    begin = end
+
+    # mix characters
+    else:
+        # alpha_blank is True right after an alphabetic word was emitted, i.e.
+        # while the last element of word_lists is its trailing ' '.
+        alpha_blank = False
+        ts_flag = True
+        begin = -1
+        end = -1
+        for i, ch in enumerate(middle_lists):
+            if ts_flag and time_stamp is not None:
+                begin = time_stamp[i][0]
+                end = time_stamp[i][1]
+            word = ''
+            if isAllChinese(ch):
+                # Remove the space that followed the preceding alphabetic word.
+                if alpha_blank is True:
+                    word_lists.pop()
+                word_lists.append(ch)
+                alpha_blank = False
+                if time_stamp is not None:
+                    ts_flag = True
+                    ts_lists.append([begin, end])
+                    begin = end
+            elif '@@' in ch:
+                word = ch.replace('@@', '')
+                word_item += word
+                alpha_blank = False
+                if time_stamp is not None:
+                    ts_flag = False
+                    end = time_stamp[i][1]
+            elif isAllAlpha(ch):
+                word_item += ch
+                word_lists.append(word_item)
+                word_lists.append(' ')
+                word_item = ''
+                alpha_blank = True
+                if time_stamp is not None:
+                    ts_flag = True
+                    end = time_stamp[i][1] 
+                    ts_lists.append([begin, end])
+                    begin = end
+            else:
+                raise ValueError('invalid character: {}'.format(ch))
+
+    if time_stamp is not None: 
+        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
+        real_word_lists = []
+        for ch in word_lists:
+            if ch != ' ':
+                real_word_lists.append(ch)
+        # With timestamps the sentence is built from the non-space tokens.
+        sentence = ' '.join(real_word_lists).strip()
+        return sentence, ts_lists, real_word_lists
+    else:
+        word_lists = abbr_dispose(word_lists)
+        real_word_lists = []
+        for ch in word_lists:
+            if ch != ' ':
+                real_word_lists.append(ch)
+        # Without timestamps the space tokens themselves separate the words.
+        sentence = ''.join(word_lists).strip()
+        return sentence, real_word_lists

+ 29 - 22
funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py → funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py

@@ -14,6 +14,7 @@ from onnxruntime import (GraphOptimizationLevel, InferenceSession,
 from typeguard import check_argument_types
 
 from .kaldifeat import compute_fbank_feats
+import warnings
 
 root_dir = Path(__file__).resolve().parent
 
@@ -21,24 +22,25 @@ logger_initialized = {}
 
 
 class TokenIDConverter():
-    def __init__(self, token_path: Union[Path, str],
+    def __init__(self, token_list: Union[Path, str],
                  unk_symbol: str = "<unk>",):
         check_argument_types()
 
-        self.token_list = self.load_token(token_path)
-        self.unk_symbol = unk_symbol
-
-    @staticmethod
-    def load_token(file_path: Union[Path, str]) -> List:
-        if not Path(file_path).exists():
-            raise TokenIDConverterError(f'The {file_path} does not exist.')
-
-        with open(str(file_path), 'rb') as f:
-            token_list = pickle.load(f)
-
-        if len(token_list) != len(set(token_list)):
-            raise TokenIDConverterError('The Token exists duplicated symbol.')
-        return token_list
+        # self.token_list = self.load_token(token_path)
+        self.token_list = token_list
+        self.unk_symbol = token_list[-1]
+
+    # @staticmethod
+    # def load_token(file_path: Union[Path, str]) -> List:
+    #     if not Path(file_path).exists():
+    #         raise TokenIDConverterError(f'The {file_path} does not exist.')
+    #
+    #     with open(str(file_path), 'rb') as f:
+    #         token_list = pickle.load(f)
+    #
+    #     if len(token_list) != len(set(token_list)):
+    #         raise TokenIDConverterError('The Token exists duplicated symbol.')
+    #     return token_list
 
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
@@ -268,31 +270,36 @@ class ONNXRuntimeError(Exception):
 
 
 class OrtInferSession():
-    def __init__(self, config):
+    def __init__(self, model_file, device_id=-1):
         sess_opt = SessionOptions()
         sess_opt.log_severity_level = 4
         sess_opt.enable_cpu_mem_arena = False
         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
 
         cuda_ep = 'CUDAExecutionProvider'
+        cuda_provider_options = {
+            "device_id": device_id,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": "true",
+        }
         cpu_ep = 'CPUExecutionProvider'
         cpu_provider_options = {
             "arena_extend_strategy": "kSameAsRequested",
         }
 
         EP_list = []
-        if config['use_cuda'] and get_device() == 'GPU' \
+        if device_id != -1 and get_device() == 'GPU' \
                 and cuda_ep in get_available_providers():
-            EP_list = [(cuda_ep, config[cuda_ep])]
+            EP_list = [(cuda_ep, cuda_provider_options)]
         EP_list.append((cpu_ep, cpu_provider_options))
 
-        config['model_path'] = config['model_path']
-        self._verify_model(config['model_path'])
-        self.session = InferenceSession(config['model_path'],
+        self._verify_model(model_file)
+        self.session = InferenceSession(model_file,
                                         sess_options=sess_opt,
                                         providers=EP_list)
 
-        if config['use_cuda'] and cuda_ep not in self.session.get_providers():
+        if device_id != -1 and cuda_ep not in self.session.get_providers():
             warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
                           'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
                           'you can check their relations from the offical web site: '

+ 0 - 0
funasr/runtime/python/onnxruntime/requirements.txt → funasr/runtime/python/onnxruntime/paraformer/requirements.txt


+ 1 - 0
funasr/runtime/python/onnxruntime/resources/config.yaml → funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml

@@ -18,6 +18,7 @@ WavFrontend:
     lfr_m: 7
     lfr_n: 6
     filter_length_max: -.inf
+    dither: 0.0
 
 Model:
   model_path: resources/models/model.onnx

+ 0 - 0
funasr/runtime/python/onnxruntime/resources/models/am.mvn → funasr/runtime/python/onnxruntime/paraformer/resources/models/am.mvn


+ 0 - 0
funasr/runtime/python/onnxruntime/resources/models/token_list.pkl → funasr/runtime/python/onnxruntime/paraformer/resources/models/token_list.pkl