|
|
@@ -8,8 +8,7 @@ import numpy as np
|
|
|
from .utils.utils import (ONNXRuntimeError,
|
|
|
OrtInferSession, get_logger,
|
|
|
read_yaml)
|
|
|
-from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
|
|
|
-from .utils.utils import split_to_mini_sentence
|
|
|
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
|
|
|
logging = get_logger()
|
|
|
|
|
|
|
|
|
@@ -30,6 +29,7 @@ class TargetDelayTransformer():
|
|
|
config_file = os.path.join(model_dir, 'punc.yaml')
|
|
|
config = read_yaml(config_file)
|
|
|
|
|
|
+ self.converter = TokenIDConverter(config['token_list'])
|
|
|
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
|
|
|
self.batch_size = 1
|
|
|
self.punc_list = config['punc_list']
|
|
|
@@ -41,23 +41,12 @@ class TargetDelayTransformer():
|
|
|
self.punc_list[i] = "?"
|
|
|
elif self.punc_list[i] == "。":
|
|
|
self.period = i
|
|
|
- self.preprocessor = CodeMixTokenizerCommonPreprocessor(
|
|
|
- train=False,
|
|
|
- token_type=config['token_type'],
|
|
|
- token_list=config['token_list'],
|
|
|
- bpemodel=config['bpemodel'],
|
|
|
- text_cleaner=config['cleaner'],
|
|
|
- g2p_type=config['g2p'],
|
|
|
- text_name="text",
|
|
|
- non_linguistic_symbols=config['non_linguistic_symbols'],
|
|
|
- )
|
|
|
|
|
|
def __call__(self, text: Union[list, str], split_size=20):
|
|
|
- data = {"text": text}
|
|
|
- result = self.preprocessor(data=data, uid="12938712838719")
|
|
|
- split_text = self.preprocessor.pop_split_text_data(result)
|
|
|
+ split_text = code_mix_split_words(text)
|
|
|
+ split_text_id = self.converter.tokens2ids(split_text)
|
|
|
mini_sentences = split_to_mini_sentence(split_text, split_size)
|
|
|
- mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
|
|
|
+ mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
|
|
|
assert len(mini_sentences) == len(mini_sentences_id)
|
|
|
cache_sent = []
|
|
|
cache_sent_id = []
|
|
|
@@ -68,9 +57,9 @@ class TargetDelayTransformer():
|
|
|
mini_sentence = mini_sentences[mini_sentence_i]
|
|
|
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
|
|
mini_sentence = cache_sent + mini_sentence
|
|
|
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
|
|
|
+ mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
|
|
|
data = {
|
|
|
- "text": mini_sentence_id[None,:].astype(np.int64),
|
|
|
+ "text": mini_sentence_id[None,:],
|
|
|
"text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
|
|
|
}
|
|
|
try:
|