# punctuation_infer.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. from pathlib import Path
  5. import sys
  6. from typing import Optional
  7. from typing import Sequence
  8. from typing import Tuple
  9. from typing import Union
  10. from typing import Any
  11. from typing import List
  12. import numpy as np
  13. import torch
  14. from typeguard import check_argument_types
  15. from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
  16. from funasr.utils.cli_utils import get_commandline_args
  17. from funasr.tasks.punctuation import PunctuationTask
  18. from funasr.torch_utils.device_funcs import to_device
  19. from funasr.torch_utils.forward_adaptor import ForwardAdaptor
  20. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  21. from funasr.utils import config_argparse
  22. from funasr.utils.types import str2triple_str
  23. from funasr.utils.types import str_or_none
  24. from funasr.datasets.preprocessor import split_to_mini_sentence
  25. class Text2Punc:
  26. def __init__(
  27. self,
  28. train_config: Optional[str],
  29. model_file: Optional[str],
  30. device: str = "cpu",
  31. dtype: str = "float32",
  32. ):
  33. # Build Model
  34. model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
  35. self.device = device
  36. # Wrape model to make model.nll() data-parallel
  37. self.wrapped_model = ForwardAdaptor(model, "inference")
  38. self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
  39. # logging.info(f"Model:\n{model}")
  40. self.punc_list = train_args.punc_list
  41. self.period = 0
  42. for i in range(len(self.punc_list)):
  43. if self.punc_list[i] == ",":
  44. self.punc_list[i] = ","
  45. elif self.punc_list[i] == "?":
  46. self.punc_list[i] = "?"
  47. elif self.punc_list[i] == "。":
  48. self.period = i
  49. self.preprocessor = CodeMixTokenizerCommonPreprocessor(
  50. train=False,
  51. token_type=train_args.token_type,
  52. token_list=train_args.token_list,
  53. bpemodel=train_args.bpemodel,
  54. text_cleaner=train_args.cleaner,
  55. g2p_type=train_args.g2p,
  56. text_name="text",
  57. non_linguistic_symbols=train_args.non_linguistic_symbols,
  58. )
  59. print("start decoding!!!")
  60. @torch.no_grad()
  61. def __call__(self, text: Union[list, str], split_size=20):
  62. data = {"text": text}
  63. result = self.preprocessor(data=data, uid="12938712838719")
  64. split_text = self.preprocessor.pop_split_text_data(result)
  65. mini_sentences = split_to_mini_sentence(split_text, split_size)
  66. mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
  67. assert len(mini_sentences) == len(mini_sentences_id)
  68. cache_sent = []
  69. cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
  70. new_mini_sentence = ""
  71. new_mini_sentence_punc = []
  72. cache_pop_trigger_limit = 200
  73. for mini_sentence_i in range(len(mini_sentences)):
  74. mini_sentence = mini_sentences[mini_sentence_i]
  75. mini_sentence_id = mini_sentences_id[mini_sentence_i]
  76. mini_sentence = cache_sent + mini_sentence
  77. mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
  78. data = {
  79. "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
  80. "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
  81. }
  82. data = to_device(data, self.device)
  83. y, _ = self.wrapped_model(**data)
  84. _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
  85. punctuations = indices
  86. if indices.size()[0] != 1:
  87. punctuations = torch.squeeze(indices)
  88. assert punctuations.size()[0] == len(mini_sentence)
  89. # Search for the last Period/QuestionMark as cache
  90. if mini_sentence_i < len(mini_sentences) - 1:
  91. sentenceEnd = -1
  92. last_comma_index = -1
  93. for i in range(len(punctuations) - 2, 1, -1):
  94. if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
  95. sentenceEnd = i
  96. break
  97. if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
  98. last_comma_index = i
  99. if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
  100. # The sentence it too long, cut off at a comma.
  101. sentenceEnd = last_comma_index
  102. punctuations[sentenceEnd] = self.period
  103. cache_sent = mini_sentence[sentenceEnd + 1:]
  104. cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
  105. mini_sentence = mini_sentence[0:sentenceEnd + 1]
  106. punctuations = punctuations[0:sentenceEnd + 1]
  107. # if len(punctuations) == 0:
  108. # continue
  109. punctuations_np = punctuations.cpu().numpy()
  110. new_mini_sentence_punc += [int(x) for x in punctuations_np]
  111. words_with_punc = []
  112. for i in range(len(mini_sentence)):
  113. if i > 0:
  114. if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
  115. mini_sentence[i] = " " + mini_sentence[i]
  116. words_with_punc.append(mini_sentence[i])
  117. if self.punc_list[punctuations[i]] != "_":
  118. words_with_punc.append(self.punc_list[punctuations[i]])
  119. new_mini_sentence += "".join(words_with_punc)
  120. # Add Period for the end of the sentence
  121. new_mini_sentence_out = new_mini_sentence
  122. new_mini_sentence_punc_out = new_mini_sentence_punc
  123. if mini_sentence_i == len(mini_sentences) - 1:
  124. if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
  125. new_mini_sentence_out = new_mini_sentence[:-1] + "。"
  126. new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
  127. elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
  128. new_mini_sentence_out = new_mini_sentence + "。"
  129. new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
  130. return new_mini_sentence_out, new_mini_sentence_punc_out
  131. def inference(
  132. batch_size: int,
  133. dtype: str,
  134. ngpu: int,
  135. seed: int,
  136. num_workers: int,
  137. output_dir: str,
  138. log_level: Union[int, str],
  139. train_config: Optional[str],
  140. model_file: Optional[str],
  141. key_file: Optional[str] = None,
  142. data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
  143. raw_inputs: Union[List[Any], bytes, str] = None,
  144. cache: List[Any] = None,
  145. param_dict: dict = None,
  146. **kwargs,
  147. ):
  148. inference_pipeline = inference_modelscope(
  149. output_dir=output_dir,
  150. batch_size=batch_size,
  151. dtype=dtype,
  152. ngpu=ngpu,
  153. seed=seed,
  154. num_workers=num_workers,
  155. log_level=log_level,
  156. key_file=key_file,
  157. train_config=train_config,
  158. model_file=model_file,
  159. param_dict=param_dict,
  160. **kwargs,
  161. )
  162. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
  163. def inference_modelscope(
  164. batch_size: int,
  165. dtype: str,
  166. ngpu: int,
  167. seed: int,
  168. num_workers: int,
  169. log_level: Union[int, str],
  170. key_file: Optional[str],
  171. train_config: Optional[str],
  172. model_file: Optional[str],
  173. output_dir: Optional[str] = None,
  174. param_dict: dict = None,
  175. **kwargs,
  176. ):
  177. assert check_argument_types()
  178. logging.basicConfig(
  179. level=log_level,
  180. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  181. )
  182. if ngpu >= 1 and torch.cuda.is_available():
  183. device = "cuda"
  184. else:
  185. device = "cpu"
  186. # 1. Set random-seed
  187. set_all_random_seed(seed)
  188. text2punc = Text2Punc(train_config, model_file, device)
  189. def _forward(
  190. data_path_and_name_and_type,
  191. raw_inputs: Union[List[Any], bytes, str] = None,
  192. output_dir_v2: Optional[str] = None,
  193. cache: List[Any] = None,
  194. param_dict: dict = None,
  195. ):
  196. results = []
  197. split_size = 20
  198. if raw_inputs != None:
  199. line = raw_inputs.strip()
  200. key = "demo"
  201. if line == "":
  202. item = {'key': key, 'value': ""}
  203. results.append(item)
  204. return results
  205. result, _ = text2punc(line)
  206. item = {'key': key, 'value': result}
  207. results.append(item)
  208. print(results)
  209. return results
  210. for inference_text, _, _ in data_path_and_name_and_type:
  211. with open(inference_text, "r", encoding="utf-8") as fin:
  212. for line in fin:
  213. line = line.strip()
  214. segs = line.split("\t")
  215. if len(segs) != 2:
  216. continue
  217. key = segs[0]
  218. if len(segs[1]) == 0:
  219. continue
  220. result, _ = text2punc(segs[1])
  221. item = {'key': key, 'value': result}
  222. results.append(item)
  223. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  224. if output_path != None:
  225. output_file_name = "infer.out"
  226. Path(output_path).mkdir(parents=True, exist_ok=True)
  227. output_file_path = (Path(output_path) / output_file_name).absolute()
  228. with open(output_file_path, "w", encoding="utf-8") as fout:
  229. for item_i in results:
  230. key_out = item_i["key"]
  231. value_out = item_i["value"]
  232. fout.write(f"{key_out}\t{value_out}\n")
  233. return results
  234. return _forward
  235. def get_parser():
  236. parser = config_argparse.ArgumentParser(
  237. description="Punctuation inference",
  238. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  239. )
  240. parser.add_argument(
  241. "--log_level",
  242. type=lambda x: x.upper(),
  243. default="INFO",
  244. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  245. help="The verbose level of logging",
  246. )
  247. parser.add_argument("--output_dir", type=str, required=False)
  248. parser.add_argument(
  249. "--ngpu",
  250. type=int,
  251. default=0,
  252. help="The number of gpus. 0 indicates CPU mode",
  253. )
  254. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  255. parser.add_argument(
  256. "--dtype",
  257. default="float32",
  258. choices=["float16", "float32", "float64"],
  259. help="Data type",
  260. )
  261. parser.add_argument(
  262. "--num_workers",
  263. type=int,
  264. default=1,
  265. help="The number of workers used for DataLoader",
  266. )
  267. parser.add_argument(
  268. "--batch_size",
  269. type=int,
  270. default=1,
  271. help="The batch size for inference",
  272. )
  273. group = parser.add_argument_group("Input data related")
  274. group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
  275. group.add_argument("--raw_inputs", type=str, required=False)
  276. group.add_argument("--cache", type=list, required=False)
  277. group.add_argument("--param_dict", type=dict, required=False)
  278. group.add_argument("--key_file", type=str_or_none)
  279. group = parser.add_argument_group("The model configuration related")
  280. group.add_argument("--train_config", type=str)
  281. group.add_argument("--model_file", type=str)
  282. return parser
  283. def main(cmd=None):
  284. print(get_commandline_args(), file=sys.stderr)
  285. parser = get_parser()
  286. args = parser.parse_args(cmd)
  287. kwargs = vars(args)
  288. # kwargs.pop("config", None)
  289. inference(**kwargs)
  290. if __name__ == "__main__":
  291. main()