punctuation_infer_vadrealtime.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. from pathlib import Path
  5. import sys
  6. from typing import Optional
  7. from typing import Sequence
  8. from typing import Tuple
  9. from typing import Union
  10. from typing import Any
  11. from typing import List
  12. import numpy as np
  13. import torch
  14. from typeguard import check_argument_types
  15. from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
  16. from funasr.utils.cli_utils import get_commandline_args
  17. from funasr.tasks.punctuation import PunctuationTask
  18. from funasr.torch_utils.device_funcs import to_device
  19. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  20. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  21. from funasr.utils import config_argparse
  22. from funasr.utils.types import str2triple_str
  23. from funasr.utils.types import str_or_none
  24. from funasr.datasets.preprocessor import split_to_mini_sentence
  25. class Text2Punc:
  26. def __init__(
  27. self,
  28. train_config: Optional[str],
  29. model_file: Optional[str],
  30. device: str = "cpu",
  31. dtype: str = "float32",
  32. ):
  33. # Build Model
  34. model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
  35. self.device = device
  36. # Wrape model to make model.nll() data-parallel
  37. self.wrapped_model = ForwardAdaptor(model, "inference")
  38. self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
  39. # logging.info(f"Model:\n{model}")
  40. self.punc_list = train_args.punc_list
  41. self.period = 0
  42. for i in range(len(self.punc_list)):
  43. if self.punc_list[i] == ",":
  44. self.punc_list[i] = ","
  45. elif self.punc_list[i] == "?":
  46. self.punc_list[i] = "?"
  47. elif self.punc_list[i] == "。":
  48. self.period = i
  49. self.preprocessor = CodeMixTokenizerCommonPreprocessor(
  50. train=False,
  51. token_type=train_args.token_type,
  52. token_list=train_args.token_list,
  53. bpemodel=train_args.bpemodel,
  54. text_cleaner=train_args.cleaner,
  55. g2p_type=train_args.g2p,
  56. text_name="text",
  57. non_linguistic_symbols=train_args.non_linguistic_symbols,
  58. )
  59. print("start decoding!!!")
  60. @torch.no_grad()
  61. def __call__(self, text: Union[list, str], cache: list, split_size=20):
  62. if cache is not None and len(cache) > 0:
  63. precache = "".join(cache)
  64. else:
  65. precache = ""
  66. cache = []
  67. data = {"text": precache + text}
  68. result = self.preprocessor(data=data, uid="12938712838719")
  69. split_text = self.preprocessor.pop_split_text_data(result)
  70. mini_sentences = split_to_mini_sentence(split_text, split_size)
  71. mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
  72. assert len(mini_sentences) == len(mini_sentences_id)
  73. cache_sent = []
  74. cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
  75. sentence_punc_list = []
  76. sentence_words_list= []
  77. cache_pop_trigger_limit = 200
  78. skip_num = 0
  79. for mini_sentence_i in range(len(mini_sentences)):
  80. mini_sentence = mini_sentences[mini_sentence_i]
  81. mini_sentence_id = mini_sentences_id[mini_sentence_i]
  82. mini_sentence = cache_sent + mini_sentence
  83. mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
  84. data = {
  85. "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
  86. "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
  87. "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
  88. }
  89. data = to_device(data, self.device)
  90. y, _ = self.wrapped_model(**data)
  91. _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
  92. punctuations = indices
  93. if indices.size()[0] != 1:
  94. punctuations = torch.squeeze(indices)
  95. assert punctuations.size()[0] == len(mini_sentence)
  96. # Search for the last Period/QuestionMark as cache
  97. if mini_sentence_i < len(mini_sentences) - 1:
  98. sentenceEnd = -1
  99. last_comma_index = -1
  100. for i in range(len(punctuations) - 2, 1, -1):
  101. if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
  102. sentenceEnd = i
  103. break
  104. if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
  105. last_comma_index = i
  106. if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
  107. # The sentence it too long, cut off at a comma.
  108. sentenceEnd = last_comma_index
  109. punctuations[sentenceEnd] = self.period
  110. cache_sent = mini_sentence[sentenceEnd + 1:]
  111. cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
  112. mini_sentence = mini_sentence[0:sentenceEnd + 1]
  113. punctuations = punctuations[0:sentenceEnd + 1]
  114. punctuations_np = punctuations.cpu().numpy()
  115. sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
  116. sentence_words_list += mini_sentence
  117. assert len(sentence_punc_list) == len(sentence_words_list)
  118. words_with_punc = []
  119. sentence_punc_list_out = []
  120. for i in range(0, len(sentence_words_list)):
  121. if i > 0:
  122. if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
  123. sentence_words_list[i] = " " + sentence_words_list[i]
  124. if skip_num < len(cache):
  125. skip_num += 1
  126. else:
  127. words_with_punc.append(sentence_words_list[i])
  128. if skip_num >= len(cache):
  129. sentence_punc_list_out.append(sentence_punc_list[i])
  130. if sentence_punc_list[i] != "_":
  131. words_with_punc.append(sentence_punc_list[i])
  132. sentence_out = "".join(words_with_punc)
  133. sentenceEnd = -1
  134. for i in range(len(sentence_punc_list) - 2, 1, -1):
  135. if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
  136. sentenceEnd = i
  137. break
  138. cache_out = sentence_words_list[sentenceEnd + 1 :]
  139. if sentence_out[-1] in self.punc_list:
  140. sentence_out = sentence_out[:-1]
  141. sentence_punc_list_out[-1] = "_"
  142. return sentence_out, sentence_punc_list_out, cache_out
  143. def inference(
  144. batch_size: int,
  145. dtype: str,
  146. ngpu: int,
  147. seed: int,
  148. num_workers: int,
  149. output_dir: str,
  150. log_level: Union[int, str],
  151. train_config: Optional[str],
  152. model_file: Optional[str],
  153. key_file: Optional[str] = None,
  154. data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
  155. raw_inputs: Union[List[Any], bytes, str] = None,
  156. cache: List[Any] = None,
  157. param_dict: dict = None,
  158. **kwargs,
  159. ):
  160. inference_pipeline = inference_modelscope(
  161. output_dir=output_dir,
  162. batch_size=batch_size,
  163. dtype=dtype,
  164. ngpu=ngpu,
  165. seed=seed,
  166. num_workers=num_workers,
  167. log_level=log_level,
  168. key_file=key_file,
  169. train_config=train_config,
  170. model_file=model_file,
  171. param_dict=param_dict,
  172. **kwargs,
  173. )
  174. return inference_pipeline(data_path_and_name_and_type, raw_inputs, cache)
  175. def inference_modelscope(
  176. batch_size: int,
  177. dtype: str,
  178. ngpu: int,
  179. seed: int,
  180. num_workers: int,
  181. log_level: Union[int, str],
  182. #cache: list,
  183. key_file: Optional[str],
  184. train_config: Optional[str],
  185. model_file: Optional[str],
  186. output_dir: Optional[str] = None,
  187. param_dict: dict = None,
  188. **kwargs,
  189. ):
  190. assert check_argument_types()
  191. logging.basicConfig(
  192. level=log_level,
  193. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  194. )
  195. if ngpu >= 1 and torch.cuda.is_available():
  196. device = "cuda"
  197. else:
  198. device = "cpu"
  199. # 1. Set random-seed
  200. set_all_random_seed(seed)
  201. text2punc = Text2Punc(train_config, model_file, device)
  202. def _forward(
  203. data_path_and_name_and_type,
  204. raw_inputs: Union[List[Any], bytes, str] = None,
  205. output_dir_v2: Optional[str] = None,
  206. cache: List[Any] = None,
  207. param_dict: dict = None,
  208. ):
  209. results = []
  210. split_size = 10
  211. cache_in = param_dict["cache"]
  212. if raw_inputs != None:
  213. line = raw_inputs.strip()
  214. key = "demo"
  215. if line == "":
  216. item = {'key': key, 'value': ""}
  217. results.append(item)
  218. return results
  219. result, _, cache = text2punc(line, cache_in)
  220. param_dict["cache"] = cache
  221. item = {'key': key, 'value': result}
  222. results.append(item)
  223. return results
  224. return results
  225. return _forward
  226. def get_parser():
  227. parser = config_argparse.ArgumentParser(
  228. description="Punctuation inference",
  229. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  230. )
  231. parser.add_argument(
  232. "--log_level",
  233. type=lambda x: x.upper(),
  234. default="INFO",
  235. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  236. help="The verbose level of logging",
  237. )
  238. parser.add_argument("--output_dir", type=str, required=False)
  239. parser.add_argument(
  240. "--ngpu",
  241. type=int,
  242. default=0,
  243. help="The number of gpus. 0 indicates CPU mode",
  244. )
  245. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  246. parser.add_argument(
  247. "--dtype",
  248. default="float32",
  249. choices=["float16", "float32", "float64"],
  250. help="Data type",
  251. )
  252. parser.add_argument(
  253. "--num_workers",
  254. type=int,
  255. default=1,
  256. help="The number of workers used for DataLoader",
  257. )
  258. parser.add_argument(
  259. "--batch_size",
  260. type=int,
  261. default=1,
  262. help="The batch size for inference",
  263. )
  264. group = parser.add_argument_group("Input data related")
  265. group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
  266. group.add_argument("--raw_inputs", type=str, required=False)
  267. group.add_argument("--cache", type=list, required=False)
  268. group.add_argument("--param_dict", type=dict, required=False)
  269. group.add_argument("--key_file", type=str_or_none)
  270. group = parser.add_argument_group("The model configuration related")
  271. group.add_argument("--train_config", type=str)
  272. group.add_argument("--model_file", type=str)
  273. return parser
  274. def main(cmd=None):
  275. print(get_commandline_args(), file=sys.stderr)
  276. parser = get_parser()
  277. args = parser.parse_args(cmd)
  278. kwargs = vars(args)
  279. # kwargs.pop("config", None)
  280. inference(**kwargs)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()