punc_inference_launch.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. from pathlib import Path
  10. from typing import Any
  11. from typing import List
  12. from typing import Optional
  13. from typing import Union
  14. import torch
  15. from typeguard import check_argument_types
  16. from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
  17. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  18. from funasr.utils import config_argparse
  19. from funasr.utils.cli_utils import get_commandline_args
  20. from funasr.utils.types import str2triple_str
  21. from funasr.utils.types import str_or_none
  22. def inference_punc(
  23. batch_size: int,
  24. dtype: str,
  25. ngpu: int,
  26. seed: int,
  27. num_workers: int,
  28. log_level: Union[int, str],
  29. key_file: Optional[str],
  30. train_config: Optional[str],
  31. model_file: Optional[str],
  32. output_dir: Optional[str] = None,
  33. param_dict: dict = None,
  34. **kwargs,
  35. ):
  36. assert check_argument_types()
  37. logging.basicConfig(
  38. level=log_level,
  39. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  40. )
  41. if ngpu >= 1 and torch.cuda.is_available():
  42. device = "cuda"
  43. else:
  44. device = "cpu"
  45. # 1. Set random-seed
  46. set_all_random_seed(seed)
  47. text2punc = Text2Punc(train_config, model_file, device)
  48. def _forward(
  49. data_path_and_name_and_type,
  50. raw_inputs: Union[List[Any], bytes, str] = None,
  51. output_dir_v2: Optional[str] = None,
  52. cache: List[Any] = None,
  53. param_dict: dict = None,
  54. ):
  55. results = []
  56. split_size = 20
  57. if raw_inputs != None:
  58. line = raw_inputs.strip()
  59. key = "demo"
  60. if line == "":
  61. item = {'key': key, 'value': ""}
  62. results.append(item)
  63. return results
  64. result, _ = text2punc(line)
  65. item = {'key': key, 'value': result}
  66. results.append(item)
  67. return results
  68. for inference_text, _, _ in data_path_and_name_and_type:
  69. with open(inference_text, "r", encoding="utf-8") as fin:
  70. for line in fin:
  71. line = line.strip()
  72. segs = line.split("\t")
  73. if len(segs) != 2:
  74. continue
  75. key = segs[0]
  76. if len(segs[1]) == 0:
  77. continue
  78. result, _ = text2punc(segs[1])
  79. item = {'key': key, 'value': result}
  80. results.append(item)
  81. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  82. if output_path != None:
  83. output_file_name = "infer.out"
  84. Path(output_path).mkdir(parents=True, exist_ok=True)
  85. output_file_path = (Path(output_path) / output_file_name).absolute()
  86. with open(output_file_path, "w", encoding="utf-8") as fout:
  87. for item_i in results:
  88. key_out = item_i["key"]
  89. value_out = item_i["value"]
  90. fout.write(f"{key_out}\t{value_out}\n")
  91. return results
  92. return _forward
  93. def inference_punc_vad_realtime(
  94. batch_size: int,
  95. dtype: str,
  96. ngpu: int,
  97. seed: int,
  98. num_workers: int,
  99. log_level: Union[int, str],
  100. # cache: list,
  101. key_file: Optional[str],
  102. train_config: Optional[str],
  103. model_file: Optional[str],
  104. output_dir: Optional[str] = None,
  105. param_dict: dict = None,
  106. **kwargs,
  107. ):
  108. assert check_argument_types()
  109. ncpu = kwargs.get("ncpu", 1)
  110. torch.set_num_threads(ncpu)
  111. if ngpu >= 1 and torch.cuda.is_available():
  112. device = "cuda"
  113. else:
  114. device = "cpu"
  115. # 1. Set random-seed
  116. set_all_random_seed(seed)
  117. text2punc = Text2PuncVADRealtime(train_config, model_file, device)
  118. def _forward(
  119. data_path_and_name_and_type,
  120. raw_inputs: Union[List[Any], bytes, str] = None,
  121. output_dir_v2: Optional[str] = None,
  122. cache: List[Any] = None,
  123. param_dict: dict = None,
  124. ):
  125. results = []
  126. split_size = 10
  127. cache_in = param_dict["cache"]
  128. if raw_inputs != None:
  129. line = raw_inputs.strip()
  130. key = "demo"
  131. if line == "":
  132. item = {'key': key, 'value': ""}
  133. results.append(item)
  134. return results
  135. result, _, cache = text2punc(line, cache_in)
  136. param_dict["cache"] = cache
  137. item = {'key': key, 'value': result}
  138. results.append(item)
  139. return results
  140. return results
  141. return _forward
  142. def inference_launch(mode, **kwargs):
  143. if mode == "punc":
  144. return inference_punc(**kwargs)
  145. if mode == "punc_VadRealtime":
  146. return inference_punc_vad_realtime(**kwargs)
  147. else:
  148. logging.info("Unknown decoding mode: {}".format(mode))
  149. return None
  150. def get_parser():
  151. parser = config_argparse.ArgumentParser(
  152. description="Punctuation inference",
  153. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  154. )
  155. parser.add_argument(
  156. "--log_level",
  157. type=lambda x: x.upper(),
  158. default="INFO",
  159. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  160. help="The verbose level of logging",
  161. )
  162. parser.add_argument("--output_dir", type=str, required=True)
  163. parser.add_argument("--gpuid_list", type=str, required=True)
  164. parser.add_argument(
  165. "--ngpu",
  166. type=int,
  167. default=0,
  168. help="The number of gpus. 0 indicates CPU mode",
  169. )
  170. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  171. parser.add_argument("--njob", type=int, default=1, help="Random seed")
  172. parser.add_argument(
  173. "--dtype",
  174. default="float32",
  175. choices=["float16", "float32", "float64"],
  176. help="Data type",
  177. )
  178. parser.add_argument(
  179. "--num_workers",
  180. type=int,
  181. default=1,
  182. help="The number of workers used for DataLoader",
  183. )
  184. parser.add_argument(
  185. "--batch_size",
  186. type=int,
  187. default=1,
  188. help="The batch size for inference",
  189. )
  190. group = parser.add_argument_group("Input data related")
  191. group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
  192. group.add_argument("--raw_inputs", type=str, required=False)
  193. group.add_argument("--key_file", type=str_or_none)
  194. group.add_argument("--cache", type=list, required=False)
  195. group.add_argument("--param_dict", type=dict, required=False)
  196. group = parser.add_argument_group("The model configuration related")
  197. group.add_argument("--train_config", type=str)
  198. group.add_argument("--model_file", type=str)
  199. group.add_argument("--mode", type=str, default="punc")
  200. return parser
  201. def main(cmd=None):
  202. print(get_commandline_args(), file=sys.stderr)
  203. parser = get_parser()
  204. args = parser.parse_args(cmd)
  205. kwargs = vars(args)
  206. kwargs.pop("config", None)
  207. # set logging messages
  208. logging.basicConfig(
  209. level=args.log_level,
  210. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  211. )
  212. logging.info("Decoding args: {}".format(kwargs))
  213. # gpu setting
  214. if args.ngpu > 0:
  215. jobid = int(args.output_dir.split(".")[-1])
  216. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  217. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  218. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  219. kwargs.pop("gpuid_list", None)
  220. kwargs.pop("njob", None)
  221. inference_pipeline = inference_launch(**kwargs)
  222. return inference_pipeline(kwargs["data_path_and_name_and_type"])
# Run the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()