Просмотр исходного кода

Merge pull request #57 from alibaba-damo-academy/dev_cmz

update punc and asr_inference_paraformer_vad_punc
zhifu gao 3 года назад
Родитель
Commit
cc7020e078

+ 7 - 100
funasr/bin/asr_inference_paraformer_vad_punc.py

@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
+
+import json
 import argparse
 import argparse
 import logging
 import logging
 import sys
 import sys
 import time
 import time
-import json
 from pathlib import Path
 from pathlib import Path
 from typing import Optional
 from typing import Optional
 from typing import Sequence
 from typing import Sequence
@@ -38,10 +39,10 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
 from funasr.tasks.vad import VADTask
 from funasr.utils.timestamp_tools import time_stamp_lfr6
 from funasr.utils.timestamp_tools import time_stamp_lfr6
-from funasr.tasks.punctuation import PunctuationTask
+from funasr.bin.punctuation_infer import Text2Punc
 from funasr.torch_utils.forward_adaptor import ForwardAdaptor
 from funasr.torch_utils.forward_adaptor import ForwardAdaptor
 from funasr.datasets.preprocessor import CommonPreprocessor
 from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
+from funasr.punctuation.text_preprocessor import split_to_mini_sentence
 
 
 header_colors = '\033[95m'
 header_colors = '\033[95m'
 end_colors = '\033[0m'
 end_colors = '\033[0m'
@@ -235,9 +236,9 @@ class Speech2Text:
 
 
         predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
         predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
+        pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
         if torch.max(pre_token_length) < 1:
             return []
             return []
-        pre_token_length = pre_token_length.round().long()
         decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
 
@@ -481,6 +482,7 @@ def inference_modelscope(
     punc_infer_config: Optional[str] = None,
     punc_infer_config: Optional[str] = None,
     punc_model_file: Optional[str] = None,
     punc_model_file: Optional[str] = None,
     outputs_dict: Optional[bool] = True,
     outputs_dict: Optional[bool] = True,
+    param_dict: dict = None,
     **kwargs,
     **kwargs,
 ):
 ):
     assert check_argument_types()
     assert check_argument_types()
@@ -546,6 +548,7 @@ def inference_modelscope(
     def _forward(data_path_and_name_and_type,
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                  output_dir_v2: Optional[str] = None,
                  output_dir_v2: Optional[str] = None,
+                 param_dict: dict = None,
                  ):
                  ):
         # 3. Build data-iterator
         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
         if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -680,102 +683,6 @@ def inference_modelscope(
         return asr_result_list
         return asr_result_list
     return _forward
     return _forward
 
 
-def Text2Punc(
-    train_config: Optional[str],
-    model_file: Optional[str],
-    device: str = "cpu",
-    dtype: str = "float32",
-):
-   
-    # 2. Build Model
-    model, train_args = PunctuationTask.build_model_from_file(
-        train_config, model_file, device)
-    # Wrape model to make model.nll() data-parallel
-    wrapped_model = ForwardAdaptor(model, "inference")
-    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
-    # logging.info(f"Model:\n{model}")
-    punc_list = train_args.punc_list
-    period = 0
-    for i in range(len(punc_list)):
-        if punc_list[i] == ",":
-            punc_list[i] = ","
-        elif punc_list[i] == "?":
-            punc_list[i] = "?"
-        elif punc_list[i] == "。":
-            period = i
-    preprocessor = CommonPreprocessor(
-        train=False,
-        token_type="word",
-        token_list=train_args.token_list,
-        bpemodel=train_args.bpemodel,
-        text_cleaner=train_args.cleaner,
-        g2p_type=train_args.g2p,
-        text_name="text",
-        non_linguistic_symbols=train_args.non_linguistic_symbols,
-    )
-
-    print("start decoding!!!")
-    
-    def _forward(words, split_size = 20):
-        cache_sent = []
-        mini_sentences = split_to_mini_sentence(words, split_size)
-        new_mini_sentence = ""
-        new_mini_sentence_punc = []
-        cache_pop_trigger_limit = 200
-        for mini_sentence_i in range(len(mini_sentences)):
-            mini_sentence = mini_sentences[mini_sentence_i]
-            mini_sentence = cache_sent + mini_sentence
-            data = {"text": " ".join(mini_sentence)}
-            batch = preprocessor(data=data, uid="12938712838719")
-            batch["text_lengths"] = torch.from_numpy(np.array([len(batch["text"])], dtype='int32'))
-            batch["text"] = torch.from_numpy(batch["text"])
-            # Extend one dimension to fake a batch dim.
-            batch["text"] = torch.unsqueeze(batch["text"], 0)
-            batch = to_device(batch, device)
-            y, _ = wrapped_model(**batch)
-            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-            punctuations = indices
-            if indices.size()[0] != 1:
-                punctuations = torch.squeeze(indices)
-            assert punctuations.size()[0] == len(mini_sentence)
-
-            # Search for the last Period/QuestionMark as cache
-            if mini_sentence_i < len(mini_sentences) - 1:
-                sentenceEnd = -1
-                last_comma_index = -1
-                for i in range(len(punctuations) - 2, 1, -1):
-                    if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "?":
-                        sentenceEnd = i
-                        break
-                    if last_comma_index < 0 and punc_list[punctuations[i]] == ",":
-                        last_comma_index = i
-
-                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
-                    # The sentence it too long, cut off at a comma.
-                    sentenceEnd = last_comma_index
-                    punctuations[sentenceEnd] = period
-                cache_sent = mini_sentence[sentenceEnd + 1:]
-                mini_sentence = mini_sentence[0:sentenceEnd + 1]
-                punctuations = punctuations[0:sentenceEnd + 1]
-
-            # if len(punctuations) == 0:
-            #    continue
-
-            punctuations_np = punctuations.cpu().numpy()
-            new_mini_sentence_punc += [int(x) for x in punctuations_np]
-            words_with_punc = []
-            for i in range(len(mini_sentence)):
-                if i > 0:
-                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
-                        mini_sentence[i] = " " + mini_sentence[i]
-                words_with_punc.append(mini_sentence[i])
-                if punc_list[punctuations[i]] != "_":
-                    words_with_punc.append(punc_list[punctuations[i]])
-            new_mini_sentence += "".join(words_with_punc)
-
-        return new_mini_sentence, new_mini_sentence_punc
-    return _forward
-
 def get_parser():
 def get_parser():
     parser = config_argparse.ArgumentParser(
     parser = config_argparse.ArgumentParser(
         description="ASR Decoding",
         description="ASR Decoding",

+ 5 - 13
funasr/bin/punc_inference_launch.py

@@ -59,26 +59,18 @@ def get_parser():
     )
     )
 
 
     group = parser.add_argument_group("Input data related")
     group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        required=False
-    )
-    group.add_argument(
-        "--raw_inputs",
-        type=str,
-        required=False
-    )
+    group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
+    group.add_argument("--raw_inputs", type=str, required=False)
     group.add_argument("--key_file", type=str_or_none)
     group.add_argument("--key_file", type=str_or_none)
-
-
+    group.add_argument("--cache", type=list, required=False)
+    group.add_argument("--param_dict", type=dict, required=False)
     group = parser.add_argument_group("The model configuration related")
     group = parser.add_argument_group("The model configuration related")
     group.add_argument("--train_config", type=str)
     group.add_argument("--train_config", type=str)
     group.add_argument("--model_file", type=str)
     group.add_argument("--model_file", type=str)
     group.add_argument("--mode", type=str, default="punc")
     group.add_argument("--mode", type=str, default="punc")
     return parser
     return parser
 
 
+
 def inference_launch(mode, **kwargs):
 def inference_launch(mode, **kwargs):
     if mode == "punc":
     if mode == "punc":
         from funasr.bin.punctuation_infer import inference_modelscope
         from funasr.bin.punctuation_infer import inference_modelscope

+ 138 - 190
funasr/bin/punctuation_infer.py

@@ -3,33 +3,141 @@ import argparse
 import logging
 import logging
 from pathlib import Path
 from pathlib import Path
 import sys
 import sys
-import os
 from typing import Optional
 from typing import Optional
 from typing import Sequence
 from typing import Sequence
 from typing import Tuple
 from typing import Tuple
 from typing import Union
 from typing import Union
-from typing import Dict
 from typing import Any
 from typing import Any
 from typing import List
 from typing import List
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
-from torch.nn.parallel import data_parallel
 from typeguard import check_argument_types
 from typeguard import check_argument_types
 
 
-from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
 from funasr.utils.cli_utils import get_commandline_args
 from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.tasks.punctuation import PunctuationTask
 from funasr.tasks.punctuation import PunctuationTask
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.forward_adaptor import ForwardAdaptor
 from funasr.torch_utils.forward_adaptor import ForwardAdaptor
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
+from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+
+
+class Text2Punc:
+
+    def __init__(
+        self,
+        train_config: Optional[str],
+        model_file: Optional[str],
+        device: str = "cpu",
+        dtype: str = "float32",
+    ):
+        #  Build Model
+        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+        self.device = device
+        # Wrape model to make model.nll() data-parallel
+        self.wrapped_model = ForwardAdaptor(model, "inference")
+        self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+        # logging.info(f"Model:\n{model}")
+        self.punc_list = train_args.punc_list
+        self.period = 0
+        for i in range(len(self.punc_list)):
+            if self.punc_list[i] == ",":
+                self.punc_list[i] = ","
+            elif self.punc_list[i] == "?":
+                self.punc_list[i] = "?"
+            elif self.punc_list[i] == "。":
+                self.period = i
+        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
+            train=False,
+            token_type=train_args.token_type,
+            token_list=train_args.token_list,
+            bpemodel=train_args.bpemodel,
+            text_cleaner=train_args.cleaner,
+            g2p_type=train_args.g2p,
+            text_name="text",
+            non_linguistic_symbols=train_args.non_linguistic_symbols,
+        )
+        print("start decoding!!!")
+
+    @torch.no_grad()
+    def __call__(self, text: Union[list, str], split_size=20):
+        data = {"text": text}
+        result = self.preprocessor(data=data, uid="12938712838719")
+        split_text = self.preprocessor.pop_split_text_data(result)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+        cache_sent = []
+        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+        new_mini_sentence = ""
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            data = {
+                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+            }
+            data = to_device(data, self.device)
+            y, _ = self.wrapped_model(**data)
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            punctuations = indices
+            if indices.size()[0] != 1:
+                punctuations = torch.squeeze(indices)
+            assert punctuations.size()[0] == len(mini_sentence)
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            # if len(punctuations) == 0:
+            #    continue
+
+            punctuations_np = punctuations.cpu().numpy()
+            new_mini_sentence_punc += [int(x) for x in punctuations_np]
+            words_with_punc = []
+            for i in range(len(mini_sentence)):
+                if i > 0:
+                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                words_with_punc.append(mini_sentence[i])
+                if self.punc_list[punctuations[i]] != "_":
+                    words_with_punc.append(self.punc_list[punctuations[i]])
+            new_mini_sentence += "".join(words_with_punc)
+            # Add Period for the end of the sentence
+            new_mini_sentence_out = new_mini_sentence
+            new_mini_sentence_punc_out = new_mini_sentence_punc
+            if mini_sentence_i == len(mini_sentences) - 1:
+                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
+                    new_mini_sentence_out = new_mini_sentence + "。"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+        return new_mini_sentence_out, new_mini_sentence_punc_out
 
 
 
 
 def inference(
 def inference(
@@ -45,12 +153,12 @@ def inference(
     key_file: Optional[str] = None,
     key_file: Optional[str] = None,
     data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
     data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
     raw_inputs: Union[List[Any], bytes, str] = None,
     raw_inputs: Union[List[Any], bytes, str] = None,
-    
+    cache: List[Any] = None,
+    param_dict: dict = None,
     **kwargs,
     **kwargs,
 ):
 ):
     inference_pipeline = inference_modelscope(
     inference_pipeline = inference_modelscope(
         output_dir=output_dir,
         output_dir=output_dir,
-        raw_inputs=raw_inputs,
         batch_size=batch_size,
         batch_size=batch_size,
         dtype=dtype,
         dtype=dtype,
         ngpu=ngpu,
         ngpu=ngpu,
@@ -60,6 +168,7 @@ def inference(
         key_file=key_file,
         key_file=key_file,
         train_config=train_config,
         train_config=train_config,
         model_file=model_file,
         model_file=model_file,
+        param_dict=param_dict,
         **kwargs,
         **kwargs,
     )
     )
     return inference_pipeline(data_path_and_name_and_type, raw_inputs)
     return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -76,6 +185,7 @@ def inference_modelscope(
     train_config: Optional[str],
     train_config: Optional[str],
     model_file: Optional[str],
     model_file: Optional[str],
     output_dir: Optional[str] = None,
     output_dir: Optional[str] = None,
+    param_dict: dict = None,
     **kwargs,
     **kwargs,
 ):
 ):
     assert check_argument_types()
     assert check_argument_types()
@@ -91,41 +201,14 @@ def inference_modelscope(
 
 
     # 1. Set random-seed
     # 1. Set random-seed
     set_all_random_seed(seed)
     set_all_random_seed(seed)
-
-    # 2. Build Model
-    model, train_args = PunctuationTask.build_model_from_file(
-        train_config, model_file, device)
-    # Wrape model to make model.nll() data-parallel
-    wrapped_model = ForwardAdaptor(model, "inference")
-    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
-    logging.info(f"Model:\n{model}")
-    punc_list = train_args.punc_list
-    period = 0
-    for i in range(len(punc_list)):
-        if punc_list[i] == ",":
-            punc_list[i] = ","
-        elif punc_list[i] == "?":
-            punc_list[i] = "?"
-        elif punc_list[i] == "。":
-            period = i
-
-    preprocessor = CommonPreprocessor(
-        train=False,
-        token_type="word",
-        token_list=train_args.token_list,
-        bpemodel=train_args.bpemodel,
-        text_cleaner=train_args.cleaner,
-        g2p_type=train_args.g2p,
-        text_name="text",
-        non_linguistic_symbols=train_args.non_linguistic_symbols,
-    )
-
-    print("start decoding!!!")
+    text2punc = Text2Punc(train_config, model_file, device)
 
 
     def _forward(
     def _forward(
         data_path_and_name_and_type,
         data_path_and_name_and_type,
         raw_inputs: Union[List[Any], bytes, str] = None,
         raw_inputs: Union[List[Any], bytes, str] = None,
         output_dir_v2: Optional[str] = None,
         output_dir_v2: Optional[str] = None,
+        cache: List[Any] = None,
+        param_dict: dict = None,
     ):
     ):
         results = []
         results = []
         split_size = 20
         split_size = 20
@@ -133,77 +216,14 @@ def inference_modelscope(
         if raw_inputs != None:
         if raw_inputs != None:
             line = raw_inputs.strip()
             line = raw_inputs.strip()
             key = "demo"
             key = "demo"
-            if line=="":
+            if line == "":
                 item = {'key': key, 'value': ""}
                 item = {'key': key, 'value': ""}
                 results.append(item)
                 results.append(item)
                 return results
                 return results
-            cache_sent = []
-            words = split_words(line)
-            new_mini_sentence = ""
-            new_mini_sentence_punc = ""
-            cache_pop_trigger_limit = 200
-            mini_sentences = split_to_mini_sentence(words, split_size)
-            for mini_sentence_i in range(len(mini_sentences)):
-                mini_sentence = mini_sentences[mini_sentence_i]
-                mini_sentence = cache_sent + mini_sentence
-                data = {"text": " ".join(mini_sentence)}
-                batch = preprocessor(data=data, uid="12938712838719")
-                batch["text_lengths"] = torch.from_numpy(
-                    np.array([len(batch["text"])], dtype='int32'))
-                batch["text"] = torch.from_numpy(batch["text"])
-                # Extend one dimension to fake a batch dim.
-                batch["text"] = torch.unsqueeze(batch["text"], 0)
-                batch = to_device(batch, device)
-                y, _ = wrapped_model(**batch)
-                _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-                punctuations = indices
-                if indices.size()[0] != 1:
-                    punctuations = torch.squeeze(indices)
-                assert punctuations.size()[0] == len(mini_sentence)
-    
-                # Search for the last Period/QuestionMark as cache 
-                if mini_sentence_i < len(mini_sentences)-1:
-                    sentenceEnd = -1
-                    last_comma_index = -1
-                    for i in range(len(punctuations)-2,1,-1):
-                        if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "?":
-                            sentenceEnd = i
-                            break
-                        if last_comma_index < 0 and punc_list[punctuations[i]] == ",":
-                            last_comma_index = i
-                    if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
-                        # The sentence it too long, cut off at a comma.
-                        sentenceEnd = last_comma_index
-                        punctuations[sentenceEnd] = period
-                    cache_sent = mini_sentence[sentenceEnd+1:]
-                    mini_sentence = mini_sentence[0:sentenceEnd+1]
-                    punctuations = punctuations[0:sentenceEnd+1]
-    
-                punctuations_np = punctuations.cpu().numpy()
-                new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
-                words_with_punc = []
-                for i in range(len(mini_sentence)):
-                    if i>0:
-                        if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
-                            mini_sentence[i] = " "+ mini_sentence[i]
-                    words_with_punc.append(mini_sentence[i])
-                    if punc_list[punctuations[i]] != "_":
-                        words_with_punc.append(punc_list[punctuations[i]])
-                new_mini_sentence += "".join(words_with_punc)
-     
-                # Add Period for the end of the sentence
-                new_mini_sentence_out = new_mini_sentence
-                new_mini_sentence_punc_out = new_mini_sentence_punc
-                if mini_sentence_i == len(mini_sentences)-1:
-                    if new_mini_sentence[-1]=="," or new_mini_sentence[-1]=="、":
-                        new_mini_sentence_out = new_mini_sentence[:-1] + "。"
-                        new_mini_sentence_punc_out  = new_mini_sentence_punc[:-1] + str(period)
-                    elif new_mini_sentence[-1]!="。" and new_mini_sentence[-1]!="?":
-                        new_mini_sentence_out=new_mini_sentence+"。"
-                        new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
-                    item = {'key': key, 'value': new_mini_sentence_out}
-                    results.append(item)
-            
+            result, _ = text2punc(line)
+            item = {'key': key, 'value': result}
+            results.append(item)
+            print(results)
             return results
             return results
 
 
         for inference_text, _, _ in data_path_and_name_and_type:
         for inference_text, _, _ in data_path_and_name_and_type:
@@ -216,72 +236,9 @@ def inference_modelscope(
                     key = segs[0]
                     key = segs[0]
                     if len(segs[1]) == 0:
                     if len(segs[1]) == 0:
                         continue
                         continue
-                    cache_sent = []
-                    words = split_words(segs[1])
-                    new_mini_sentence = ""
-                    new_mini_sentence_punc = ""
-                    cache_pop_trigger_limit = 200
-                    mini_sentences = split_to_mini_sentence(words, split_size)
-                    for mini_sentence_i in range(len(mini_sentences)):
-                        mini_sentence = mini_sentences[mini_sentence_i]
-                        mini_sentence = cache_sent + mini_sentence
-                        data = {"text": " ".join(mini_sentence)}
-                        batch = preprocessor(data=data, uid="12938712838719")
-                        batch["text_lengths"] = torch.from_numpy(
-                            np.array([len(batch["text"])], dtype='int32'))
-                        batch["text"] = torch.from_numpy(batch["text"])
-                        # Extend one dimension to fake a batch dim.
-                        batch["text"] = torch.unsqueeze(batch["text"], 0)
-                        batch = to_device(batch, device)
-                        y, _ = wrapped_model(**batch)
-                        _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-                        punctuations = indices
-                        if indices.size()[0] != 1:
-                            punctuations = torch.squeeze(indices)
-                        assert punctuations.size()[0] == len(mini_sentence)
-    
-                        # Search for the last Period/QuestionMark as cache 
-                        if mini_sentence_i < len(mini_sentences)-1:
-                            sentenceEnd = -1
-                            last_comma_index = -1
-                            for i in range(len(punctuations)-2,1,-1):
-                                if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "?":
-                                    sentenceEnd = i
-                                    break
-                                if last_comma_index < 0 and punc_list[punctuations[i]] == ",":
-                                    last_comma_index = i
-                            if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
-                                # The sentence it too long, cut off at a comma.
-                                sentenceEnd = last_comma_index
-                                punctuations[sentenceEnd] = period
-                            cache_sent = mini_sentence[sentenceEnd+1:]
-                            mini_sentence = mini_sentence[0:sentenceEnd+1]
-                            punctuations = punctuations[0:sentenceEnd+1]
-    
-                        punctuations_np = punctuations.cpu().numpy()
-                        new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
-                        words_with_punc = []
-                        for i in range(len(mini_sentence)):
-                            if i>0:
-                                if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
-                                    mini_sentence[i] = " "+ mini_sentence[i]
-                            words_with_punc.append(mini_sentence[i])
-                            if punc_list[punctuations[i]] != "_":
-                                words_with_punc.append(punc_list[punctuations[i]])
-                        new_mini_sentence += "".join(words_with_punc)
-     
-                        # Add Period for the end of the sentence
-                        new_mini_sentence_out = new_mini_sentence
-                        new_mini_sentence_punc_out = new_mini_sentence_punc
-                        if mini_sentence_i == len(mini_sentences)-1:
-                            if new_mini_sentence[-1]=="," or new_mini_sentence[-1]=="、":
-                                new_mini_sentence_out = new_mini_sentence[:-1] + "。"
-                                new_mini_sentence_punc_out  = new_mini_sentence_punc[:-1] + str(period)
-                            elif new_mini_sentence[-1]!="。" and new_mini_sentence[-1]!="?":
-                                new_mini_sentence_out=new_mini_sentence+"。"
-                                new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
-                            item = {'key': key, 'value': new_mini_sentence_out}
-                            results.append(item)
+                    result, _ = text2punc(segs[1])
+                    item = {'key': key, 'value': result}
+                    results.append(item)
         output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
         output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
         if output_path != None:
         if output_path != None:
             output_file_name = "infer.out"
             output_file_name = "infer.out"
@@ -293,6 +250,7 @@ def inference_modelscope(
                     value_out = item_i["value"]
                     value_out = item_i["value"]
                     fout.write(f"{key_out}\t{value_out}\n")
                     fout.write(f"{key_out}\t{value_out}\n")
         return results
         return results
+
     return _forward
     return _forward
 
 
 
 
@@ -338,20 +296,12 @@ def get_parser():
     )
     )
 
 
     group = parser.add_argument_group("Input data related")
     group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        required=False
-    )
-    group.add_argument(
-        "--raw_inputs",
-        type=str,
-        required=False
-    )
+    group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
+    group.add_argument("--raw_inputs", type=str, required=False)
+    group.add_argument("--cache", type=list, required=False)
+    group.add_argument("--param_dict", type=dict, required=False)
     group.add_argument("--key_file", type=str_or_none)
     group.add_argument("--key_file", type=str_or_none)
 
 
-
     group = parser.add_argument_group("The model configuration related")
     group = parser.add_argument_group("The model configuration related")
     group.add_argument("--train_config", type=str)
     group.add_argument("--train_config", type=str)
     group.add_argument("--model_file", type=str)
     group.add_argument("--model_file", type=str)
@@ -364,11 +314,9 @@ def main(cmd=None):
     parser = get_parser()
     parser = get_parser()
     args = parser.parse_args(cmd)
     args = parser.parse_args(cmd)
     kwargs = vars(args)
     kwargs = vars(args)
-   # kwargs.pop("config", None)
+    # kwargs.pop("config", None)
     inference(**kwargs)
     inference(**kwargs)
 
 
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     main()
     main()
-
-
-

+ 1 - 3
funasr/punctuation/abs_model.py

@@ -23,7 +23,5 @@ class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
     """
     """
 
 
     @abstractmethod
     @abstractmethod
-    def forward(
-        self, input: torch.Tensor, hidden: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
         raise NotImplementedError

+ 19 - 22
funasr/punctuation/espnet_model.py

@@ -13,6 +13,7 @@ from funasr.train.abs_espnet_model import AbsESPnetModel
 
 
 
 
 class ESPnetPunctuationModel(AbsESPnetModel):
 class ESPnetPunctuationModel(AbsESPnetModel):
+
     def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
     def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
         assert check_argument_types()
         assert check_argument_types()
         super().__init__()
         super().__init__()
@@ -43,8 +44,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         batch_size = text.size(0)
         batch_size = text.size(0)
         # For data parallel
         # For data parallel
         if max_length is None:
         if max_length is None:
-            text = text[:, : text_lengths.max()]
-            punc = punc[:, : text_lengths.max()]
+            text = text[:, :text_lengths.max()]
+            punc = punc[:, :text_lengths.max()]
         else:
         else:
             text = text[:, :max_length]
             text = text[:, :max_length]
             punc = punc[:, :max_length]
             punc = punc[:, :max_length]
@@ -63,9 +64,11 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         # 3. Calc negative log likelihood
         # 3. Calc negative log likelihood
         # nll: (BxL,)
         # nll: (BxL,)
         if self.training == False:
         if self.training == False:
-            _, indices = y.view(-1, y.shape[-1]).topk(1,dim=1)
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
             from sklearn.metrics import f1_score
             from sklearn.metrics import f1_score
-            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(), indices.squeeze(-1).detach().cpu().numpy(), average='micro')
+            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
+                                indices.squeeze(-1).detach().cpu().numpy(),
+                                average='micro')
             nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
             nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
             return nll, text_lengths
             return nll, text_lengths
         else:
         else:
@@ -82,14 +85,12 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         nll = nll.view(batch_size, -1)
         nll = nll.view(batch_size, -1)
         return nll, text_lengths
         return nll, text_lengths
 
 
-    def batchify_nll(
-        self,
-        text: torch.Tensor,
-        punc: torch.Tensor,
-        text_lengths: torch.Tensor,
-        punc_lengths: torch.Tensor,
-        batch_size: int = 100
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def batchify_nll(self,
+                     text: torch.Tensor,
+                     punc: torch.Tensor,
+                     text_lengths: torch.Tensor,
+                     punc_lengths: torch.Tensor,
+                     batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute negative log likelihood(nll) from transformer language model
         """Compute negative log likelihood(nll) from transformer language model
 
 
         To avoid OOM, this fuction seperate the input into batches.
         To avoid OOM, this fuction seperate the input into batches.
@@ -117,9 +118,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
                 batch_punc = punc[start_idx:end_idx, :]
                 batch_punc = punc[start_idx:end_idx, :]
                 batch_text_lengths = text_lengths[start_idx:end_idx]
                 batch_text_lengths = text_lengths[start_idx:end_idx]
                 # batch_nll: [B * T]
                 # batch_nll: [B * T]
-                batch_nll, batch_x_lengths = self.nll(
-                    batch_text, batch_punc, batch_text_lengths, max_length=max_length
-                )
+                batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
                 nlls.append(batch_nll)
                 nlls.append(batch_nll)
                 x_lengths.append(batch_x_lengths)
                 x_lengths.append(batch_x_lengths)
                 start_idx = end_idx
                 start_idx = end_idx
@@ -131,21 +130,19 @@ class ESPnetPunctuationModel(AbsESPnetModel):
         assert x_lengths.size(0) == total_num
         assert x_lengths.size(0) == total_num
         return nll, x_lengths
         return nll, x_lengths
 
 
-    def forward(
-        self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor, punc_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+    def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor,
+                punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
         nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
         ntokens = y_lengths.sum()
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
         stats = dict(loss=loss.detach())
-        
+
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
         return loss, stats, weight
 
 
-    def collect_feats(
-        self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
+                      text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {}
         return {}
 
 
     def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
     def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:

+ 10 - 20
funasr/punctuation/target_delay_transformer.py

@@ -14,6 +14,7 @@ from funasr.punctuation.abs_model import AbsPunctuation
 
 
 
 
 class TargetDelayTransformer(AbsPunctuation):
 class TargetDelayTransformer(AbsPunctuation):
+
     def __init__(
     def __init__(
         self,
         self,
         vocab_size: int,
         vocab_size: int,
@@ -28,7 +29,7 @@ class TargetDelayTransformer(AbsPunctuation):
     ):
     ):
         super().__init__()
         super().__init__()
         if pos_enc == "sinusoidal":
         if pos_enc == "sinusoidal":
-#            pos_enc_class = PositionalEncoding
+            #            pos_enc_class = PositionalEncoding
             pos_enc_class = SinusoidalPositionEncoder
             pos_enc_class = SinusoidalPositionEncoder
         elif pos_enc is None:
         elif pos_enc is None:
 
 
@@ -47,17 +48,17 @@ class TargetDelayTransformer(AbsPunctuation):
             num_blocks=layer,
             num_blocks=layer,
             dropout_rate=dropout_rate,
             dropout_rate=dropout_rate,
             input_layer="pe",
             input_layer="pe",
-           # pos_enc_class=pos_enc_class,
+            # pos_enc_class=pos_enc_class,
             padding_idx=0,
             padding_idx=0,
         )
         )
         self.decoder = nn.Linear(att_unit, punc_size)
         self.decoder = nn.Linear(att_unit, punc_size)
 
 
+
 #    def _target_mask(self, ys_in_pad):
 #    def _target_mask(self, ys_in_pad):
 #        ys_mask = ys_in_pad != 0
 #        ys_mask = ys_in_pad != 0
 #        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
 #        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
 #        return ys_mask.unsqueeze(-2) & m
 #        return ys_mask.unsqueeze(-2) & m
 
 
-
     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """Compute loss value from buffer sequences.
         """Compute loss value from buffer sequences.
 
 
@@ -67,14 +68,12 @@ class TargetDelayTransformer(AbsPunctuation):
 
 
         """
         """
         x = self.embed(input)
         x = self.embed(input)
-       # mask = self._target_mask(input)
+        # mask = self._target_mask(input)
         h, _, _ = self.encoder(x, text_lengths)
         h, _, _ = self.encoder(x, text_lengths)
         y = self.decoder(h)
         y = self.decoder(h)
         return y, None
         return y, None
 
 
-    def score(
-        self, y: torch.Tensor, state: Any, x: torch.Tensor
-    ) -> Tuple[torch.Tensor, Any]:
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
         """Score new token.
         """Score new token.
 
 
         Args:
         Args:
@@ -89,16 +88,12 @@ class TargetDelayTransformer(AbsPunctuation):
 
 
         """
         """
         y = y.unsqueeze(0)
         y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(
-            self.embed(y), self._target_mask(y), cache=state
-        )
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
         h = self.decoder(h[:, -1])
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1).squeeze(0)
         logp = h.log_softmax(dim=-1).squeeze(0)
         return logp, cache
         return logp, cache
 
 
-    def batch_score(
-        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
-    ) -> Tuple[torch.Tensor, List[Any]]:
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
         """Score new token batch.
         """Score new token batch.
 
 
         Args:
         Args:
@@ -120,15 +115,10 @@ class TargetDelayTransformer(AbsPunctuation):
             batch_state = None
             batch_state = None
         else:
         else:
             # transpose state of [batch, layer] into [layer, batch]
             # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [
-                torch.stack([states[b][i] for b in range(n_batch)])
-                for i in range(n_layers)
-            ]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
 
 
         # batch decoding
         # batch decoding
-        h, _, states = self.encoder.forward_one_step(
-            self.embed(ys), self._target_mask(ys), cache=batch_state
-        )
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
         h = self.decoder(h[:, -1])
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1)
         logp = h.log_softmax(dim=-1)
 
 

+ 0 - 21
funasr/punctuation/text_preprocessor.py

@@ -1,24 +1,3 @@
-def split_words(text: str):
-    words = []
-    segs = text.split()
-    for seg in segs:
-        # There is no space in seg.
-        current_word = ""
-        for c in seg:
-            if len(c.encode()) == 1:
-                # This is an ASCII char.
-                current_word += c
-            else:
-                # This is a Chinese char.
-                if len(current_word) > 0:
-                    words.append(current_word)
-                    current_word = ""
-                words.append(c)
-        if len(current_word) > 0:
-            words.append(current_word)
-    return words
-
-
 def split_to_mini_sentence(words: list, word_limit: int = 20):
 def split_to_mini_sentence(words: list, word_limit: int = 20):
     assert word_limit > 1
     assert word_limit > 1
     if len(words) <= word_limit:
     if len(words) <= word_limit: