|
|
@@ -1,5 +1,6 @@
|
|
|
import argparse
|
|
|
import logging
|
|
|
+from optparse import Option
|
|
|
import sys
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
@@ -11,15 +12,12 @@ from typing import Tuple
|
|
|
from typing import Union
|
|
|
from typing import Dict
|
|
|
|
|
|
-import math
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from typeguard import check_argument_types
|
|
|
-from typeguard import check_return_type
|
|
|
|
|
|
from funasr.fileio.datadir_writer import DatadirWriter
|
|
|
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
|
|
-from funasr.modules.subsampling import TooShortUttError
|
|
|
+from funasr.datasets.preprocessor import LMPreprocessor
|
|
|
from funasr.tasks.asr import ASRTaskAligner as ASRTask
|
|
|
from funasr.torch_utils.device_funcs import to_device
|
|
|
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
|
|
@@ -28,7 +26,6 @@ from funasr.utils.cli_utils import get_commandline_args
|
|
|
from funasr.utils.types import str2bool
|
|
|
from funasr.utils.types import str2triple_str
|
|
|
from funasr.utils.types import str_or_none
|
|
|
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
|
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
|
|
from funasr.text.token_id_converter import TokenIDConverter
|
|
|
|
|
|
@@ -191,6 +188,8 @@ def inference(
|
|
|
dtype: str = "float32",
|
|
|
seed: int = 0,
|
|
|
num_workers: int = 1,
|
|
|
+ split_with_space: bool = True,
|
|
|
+ seg_dict_file: Optional[str] = None,
|
|
|
**kwargs,
|
|
|
):
|
|
|
inference_pipeline = inference_modelscope(
|
|
|
@@ -206,6 +205,8 @@ def inference(
|
|
|
dtype=dtype,
|
|
|
seed=seed,
|
|
|
num_workers=num_workers,
|
|
|
+ split_with_space=split_with_space,
|
|
|
+ seg_dict_file=seg_dict_file,
|
|
|
**kwargs,
|
|
|
)
|
|
|
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
|
|
|
@@ -226,6 +227,8 @@ def inference_modelscope(
|
|
|
dtype: str = "float32",
|
|
|
seed: int = 0,
|
|
|
num_workers: int = 1,
|
|
|
+ split_with_space: bool = True,
|
|
|
+ seg_dict_file: Optional[str] = None,
|
|
|
**kwargs,
|
|
|
):
|
|
|
assert check_argument_types()
|
|
|
@@ -256,6 +259,19 @@ def inference_modelscope(
|
|
|
)
|
|
|
logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
|
|
|
speechtext2timestamp = SpeechText2Timestamp(**speechtext2timestamp_kwargs)
|
|
|
+
|
|
|
+ preprocessor = LMPreprocessor(
|
|
|
+ train=False,
|
|
|
+ token_type=speechtext2timestamp.tp_train_args.token_type,
|
|
|
+ token_list=speechtext2timestamp.tp_train_args,
|
|
|
+ bpemodel=None,
|
|
|
+ text_cleaner=None,
|
|
|
+ g2p_type=None,
|
|
|
+ text_name="text",
|
|
|
+ non_linguistic_symbols=speechtext2timestamp.tp_train_args.non_linguistic_symbols,
|
|
|
+ split_with_space=split_with_space,
|
|
|
+ seg_dict_file=seg_dict_file,
|
|
|
+ )
|
|
|
|
|
|
def _forward(
|
|
|
data_path_and_name_and_type,
|
|
|
@@ -277,7 +293,7 @@ def inference_modelscope(
|
|
|
batch_size=batch_size,
|
|
|
key_file=key_file,
|
|
|
num_workers=num_workers,
|
|
|
- preprocess_fn=ASRTask.build_preprocess_fn(speechtext2timestamp.tp_train_args, False),
|
|
|
+ preprocess_fn=LMPreprocessor,
|
|
|
collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
|
|
|
allow_variable_data_keys=allow_variable_data_keys,
|
|
|
inference=True,
|