九耳 3 лет назад
Родитель
Сommit
0fd9640ced

+ 7 - 18
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py

@@ -8,8 +8,7 @@ import numpy as np
 from .utils.utils import (ONNXRuntimeError,
 from .utils.utils import (ONNXRuntimeError,
                           OrtInferSession, get_logger,
                           OrtInferSession, get_logger,
                           read_yaml)
                           read_yaml)
-from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
-from .utils.utils import split_to_mini_sentence
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
 logging = get_logger()
 logging = get_logger()
 
 
 
 
@@ -30,6 +29,7 @@ class TargetDelayTransformer():
         config_file = os.path.join(model_dir, 'punc.yaml')
         config_file = os.path.join(model_dir, 'punc.yaml')
         config = read_yaml(config_file)
         config = read_yaml(config_file)
 
 
+        self.converter = TokenIDConverter(config['token_list'])
         self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
         self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
         self.batch_size = 1
         self.batch_size = 1
         self.punc_list = config['punc_list']
         self.punc_list = config['punc_list']
@@ -41,23 +41,12 @@ class TargetDelayTransformer():
                 self.punc_list[i] = "?"
                 self.punc_list[i] = "?"
             elif self.punc_list[i] == "。":
             elif self.punc_list[i] == "。":
                 self.period = i
                 self.period = i
-        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
-            train=False,
-            token_type=config['token_type'],
-            token_list=config['token_list'],
-            bpemodel=config['bpemodel'],
-            text_cleaner=config['cleaner'],
-            g2p_type=config['g2p'],
-            text_name="text",
-            non_linguistic_symbols=config['non_linguistic_symbols'],
-        )
 
 
     def __call__(self, text: Union[list, str], split_size=20):
     def __call__(self, text: Union[list, str], split_size=20):
-        data = {"text": text}
-        result = self.preprocessor(data=data, uid="12938712838719")
-        split_text = self.preprocessor.pop_split_text_data(result)
+        split_text = code_mix_split_words(text)
+        split_text_id = self.converter.tokens2ids(split_text)
         mini_sentences = split_to_mini_sentence(split_text, split_size)
         mini_sentences = split_to_mini_sentence(split_text, split_size)
-        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
         assert len(mini_sentences) == len(mini_sentences_id)
         assert len(mini_sentences) == len(mini_sentences_id)
         cache_sent = []
         cache_sent = []
         cache_sent_id = []
         cache_sent_id = []
@@ -68,9 +57,9 @@ class TargetDelayTransformer():
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
             mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
             data = {
             data = {
-                "text": mini_sentence_id[None,:].astype(np.int64),
+                "text": mini_sentence_id[None,:],
                 "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
                 "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
             }
             }
             try:
             try:

+ 19 - 0
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py

@@ -228,6 +228,25 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
         sentences.append(words[sentence_len * word_limit:])
         sentences.append(words[sentence_len * word_limit:])
     return sentences
     return sentences
 
 
+def code_mix_split_words(text: str):
+    words = []
+    segs = text.split()
+    for seg in segs:
+        # There is no space in seg.
+        current_word = ""
+        for c in seg:
+            if len(c.encode()) == 1:
+                # This is an ASCII char.
+                current_word += c
+            else:
+                # This is a Chinese char.
+                if len(current_word) > 0:
+                    words.append(current_word)
+                    current_word = ""
+                words.append(c)
+        if len(current_word) > 0:
+            words.append(current_word)
+    return words
 
 
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
     if not Path(yaml_path).exists():
     if not Path(yaml_path).exists():