|
|
@@ -32,8 +32,7 @@ class TargetDelayTransformer():
|
|
|
|
|
|
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
|
|
|
self.batch_size = 1
|
|
|
- self.encoder_conf = config["encoder_conf"]
|
|
|
- self.punc_list = config.punc_list
|
|
|
+ self.punc_list = config['punc_list']
|
|
|
self.period = 0
|
|
|
for i in range(len(self.punc_list)):
|
|
|
if self.punc_list[i] == ",":
|
|
|
@@ -44,13 +43,13 @@ class TargetDelayTransformer():
|
|
|
self.period = i
|
|
|
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
|
|
|
train=False,
|
|
|
- token_type=config.token_type,
|
|
|
- token_list=config.token_list,
|
|
|
- bpemodel=config.bpemodel,
|
|
|
- text_cleaner=config.cleaner,
|
|
|
- g2p_type=config.g2p,
|
|
|
+ token_type=config['token_type'],
|
|
|
+ token_list=config['token_list'],
|
|
|
+ bpemodel=config['bpemodel'],
|
|
|
+ text_cleaner=config['cleaner'],
|
|
|
+ g2p_type=config['g2p'],
|
|
|
text_name="text",
|
|
|
- non_linguistic_symbols=config.non_linguistic_symbols,
|
|
|
+ non_linguistic_symbols=config['non_linguistic_symbols'],
|
|
|
)
|
|
|
|
|
|
def __call__(self, text: Union[list, str], split_size=20):
|