|
|
@@ -1,5 +1,5 @@
|
|
|
-# -*- encoding: utf-8 -*-
|
|
|
#!/usr/bin/env python3
|
|
|
+# -*- encoding: utf-8 -*-
|
|
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
|
|
# MIT License (https://opensource.org/licenses/MIT)
|
|
|
|
|
|
@@ -7,55 +7,36 @@ import argparse
|
|
|
import logging
|
|
|
import os
|
|
|
import sys
|
|
|
-from typing import Union, Dict, Any
|
|
|
-
|
|
|
-from funasr.utils import config_argparse
|
|
|
-from funasr.utils.cli_utils import get_commandline_args
|
|
|
-from funasr.utils.types import str2bool
|
|
|
-from funasr.utils.types import str2triple_str
|
|
|
-from funasr.utils.types import str_or_none
|
|
|
-from funasr.utils.types import float_or_none
|
|
|
-
|
|
|
-import argparse
|
|
|
-import logging
|
|
|
from pathlib import Path
|
|
|
-import sys
|
|
|
-from typing import Optional
|
|
|
-from typing import Sequence
|
|
|
-from typing import Tuple
|
|
|
-from typing import Union
|
|
|
from typing import Any
|
|
|
from typing import List
|
|
|
+from typing import Optional
|
|
|
+from typing import Union
|
|
|
|
|
|
-import numpy as np
|
|
|
import torch
|
|
|
from typeguard import check_argument_types
|
|
|
|
|
|
-from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
|
|
|
-from funasr.utils.cli_utils import get_commandline_args
|
|
|
-from funasr.tasks.punctuation import PunctuationTask
|
|
|
-from funasr.torch_utils.device_funcs import to_device
|
|
|
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
|
|
|
+from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
|
|
|
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
|
|
from funasr.utils import config_argparse
|
|
|
+from funasr.utils.cli_utils import get_commandline_args
|
|
|
from funasr.utils.types import str2triple_str
|
|
|
from funasr.utils.types import str_or_none
|
|
|
-from funasr.datasets.preprocessor import split_to_mini_sentence
|
|
|
-from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
|
|
|
+
|
|
|
|
|
|
def inference_punc(
|
|
|
- batch_size: int,
|
|
|
- dtype: str,
|
|
|
- ngpu: int,
|
|
|
- seed: int,
|
|
|
- num_workers: int,
|
|
|
- log_level: Union[int, str],
|
|
|
- key_file: Optional[str],
|
|
|
- train_config: Optional[str],
|
|
|
- model_file: Optional[str],
|
|
|
- output_dir: Optional[str] = None,
|
|
|
- param_dict: dict = None,
|
|
|
- **kwargs,
|
|
|
+ batch_size: int,
|
|
|
+ dtype: str,
|
|
|
+ ngpu: int,
|
|
|
+ seed: int,
|
|
|
+ num_workers: int,
|
|
|
+ log_level: Union[int, str],
|
|
|
+ key_file: Optional[str],
|
|
|
+ train_config: Optional[str],
|
|
|
+ model_file: Optional[str],
|
|
|
+ output_dir: Optional[str] = None,
|
|
|
+ param_dict: dict = None,
|
|
|
+ **kwargs,
|
|
|
):
|
|
|
assert check_argument_types()
|
|
|
logging.basicConfig(
|
|
|
@@ -73,11 +54,11 @@ def inference_punc(
|
|
|
text2punc = Text2Punc(train_config, model_file, device)
|
|
|
|
|
|
def _forward(
|
|
|
- data_path_and_name_and_type,
|
|
|
- raw_inputs: Union[List[Any], bytes, str] = None,
|
|
|
- output_dir_v2: Optional[str] = None,
|
|
|
- cache: List[Any] = None,
|
|
|
- param_dict: dict = None,
|
|
|
+ data_path_and_name_and_type,
|
|
|
+ raw_inputs: Union[List[Any], bytes, str] = None,
|
|
|
+ output_dir_v2: Optional[str] = None,
|
|
|
+ cache: List[Any] = None,
|
|
|
+ param_dict: dict = None,
|
|
|
):
|
|
|
results = []
|
|
|
split_size = 20
|
|
|
@@ -121,20 +102,21 @@ def inference_punc(
|
|
|
|
|
|
return _forward
|
|
|
|
|
|
+
|
|
|
def inference_punc_vad_realtime(
|
|
|
- batch_size: int,
|
|
|
- dtype: str,
|
|
|
- ngpu: int,
|
|
|
- seed: int,
|
|
|
- num_workers: int,
|
|
|
- log_level: Union[int, str],
|
|
|
- #cache: list,
|
|
|
- key_file: Optional[str],
|
|
|
- train_config: Optional[str],
|
|
|
- model_file: Optional[str],
|
|
|
- output_dir: Optional[str] = None,
|
|
|
- param_dict: dict = None,
|
|
|
- **kwargs,
|
|
|
+ batch_size: int,
|
|
|
+ dtype: str,
|
|
|
+ ngpu: int,
|
|
|
+ seed: int,
|
|
|
+ num_workers: int,
|
|
|
+ log_level: Union[int, str],
|
|
|
+ # cache: list,
|
|
|
+ key_file: Optional[str],
|
|
|
+ train_config: Optional[str],
|
|
|
+ model_file: Optional[str],
|
|
|
+ output_dir: Optional[str] = None,
|
|
|
+ param_dict: dict = None,
|
|
|
+ **kwargs,
|
|
|
):
|
|
|
assert check_argument_types()
|
|
|
ncpu = kwargs.get("ncpu", 1)
|
|
|
@@ -150,11 +132,11 @@ def inference_punc_vad_realtime(
|
|
|
text2punc = Text2PuncVADRealtime(train_config, model_file, device)
|
|
|
|
|
|
def _forward(
|
|
|
- data_path_and_name_and_type,
|
|
|
- raw_inputs: Union[List[Any], bytes, str] = None,
|
|
|
- output_dir_v2: Optional[str] = None,
|
|
|
- cache: List[Any] = None,
|
|
|
- param_dict: dict = None,
|
|
|
+ data_path_and_name_and_type,
|
|
|
+ raw_inputs: Union[List[Any], bytes, str] = None,
|
|
|
+ output_dir_v2: Optional[str] = None,
|
|
|
+ cache: List[Any] = None,
|
|
|
+ param_dict: dict = None,
|
|
|
):
|
|
|
results = []
|
|
|
split_size = 10
|
|
|
@@ -177,7 +159,6 @@ def inference_punc_vad_realtime(
|
|
|
return _forward
|
|
|
|
|
|
|
|
|
-
|
|
|
def inference_launch(mode, **kwargs):
|
|
|
if mode == "punc":
|
|
|
return inference_punc(**kwargs)
|
|
|
@@ -187,6 +168,7 @@ def inference_launch(mode, **kwargs):
|
|
|
logging.info("Unknown decoding mode: {}".format(mode))
|
|
|
return None
|
|
|
|
|
|
+
|
|
|
def get_parser():
|
|
|
parser = config_argparse.ArgumentParser(
|
|
|
description="Punctuation inference",
|
|
|
@@ -269,6 +251,5 @@ def main(cmd=None):
|
|
|
return inference_pipeline(kwargs["data_path_and_name_and_type"])
|
|
|
|
|
|
|
|
|
-
|
|
|
if __name__ == "__main__":
|
|
|
main()
|