# punc_inference_launch.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. from pathlib import Path
  10. from typing import Any
  11. from typing import List
  12. from typing import Optional
  13. from typing import Union
  14. import torch
  15. from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
  16. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  17. from funasr.utils import config_argparse
  18. from funasr.utils.cli_utils import get_commandline_args
  19. from funasr.utils.types import str2triple_str
  20. from funasr.utils.types import str_or_none
  21. def inference_punc(
  22. batch_size: int,
  23. dtype: str,
  24. ngpu: int,
  25. seed: int,
  26. num_workers: int,
  27. log_level: Union[int, str],
  28. key_file: Optional[str],
  29. train_config: Optional[str],
  30. model_file: Optional[str],
  31. output_dir: Optional[str] = None,
  32. param_dict: dict = None,
  33. **kwargs,
  34. ):
  35. logging.basicConfig(
  36. level=log_level,
  37. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  38. )
  39. if ngpu >= 1 and torch.cuda.is_available():
  40. device = "cuda"
  41. else:
  42. device = "cpu"
  43. # 1. Set random-seed
  44. set_all_random_seed(seed)
  45. text2punc = Text2Punc(train_config, model_file, device)
  46. def _forward(
  47. data_path_and_name_and_type,
  48. raw_inputs: Union[List[Any], bytes, str] = None,
  49. output_dir_v2: Optional[str] = None,
  50. cache: List[Any] = None,
  51. param_dict: dict = None,
  52. ):
  53. results = []
  54. split_size = 20
  55. if raw_inputs != None:
  56. line = raw_inputs.strip()
  57. key = "demo"
  58. if line == "":
  59. item = {'key': key, 'value': ""}
  60. results.append(item)
  61. return results
  62. result, _ = text2punc(line)
  63. item = {'key': key, 'value': result}
  64. results.append(item)
  65. return results
  66. for inference_text, _, _ in data_path_and_name_and_type:
  67. with open(inference_text, "r", encoding="utf-8") as fin:
  68. for line in fin:
  69. line = line.strip()
  70. segs = line.split("\t")
  71. if len(segs) != 2:
  72. continue
  73. key = segs[0]
  74. if len(segs[1]) == 0:
  75. continue
  76. result, _ = text2punc(segs[1])
  77. item = {'key': key, 'value': result}
  78. results.append(item)
  79. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  80. if output_path != None:
  81. output_file_name = "infer.out"
  82. Path(output_path).mkdir(parents=True, exist_ok=True)
  83. output_file_path = (Path(output_path) / output_file_name).absolute()
  84. with open(output_file_path, "w", encoding="utf-8") as fout:
  85. for item_i in results:
  86. key_out = item_i["key"]
  87. value_out = item_i["value"]
  88. fout.write(f"{key_out}\t{value_out}\n")
  89. return results
  90. return _forward
  91. def inference_punc_vad_realtime(
  92. batch_size: int,
  93. dtype: str,
  94. ngpu: int,
  95. seed: int,
  96. num_workers: int,
  97. log_level: Union[int, str],
  98. # cache: list,
  99. key_file: Optional[str],
  100. train_config: Optional[str],
  101. model_file: Optional[str],
  102. output_dir: Optional[str] = None,
  103. param_dict: dict = None,
  104. **kwargs,
  105. ):
  106. ncpu = kwargs.get("ncpu", 1)
  107. torch.set_num_threads(ncpu)
  108. if ngpu >= 1 and torch.cuda.is_available():
  109. device = "cuda"
  110. else:
  111. device = "cpu"
  112. # 1. Set random-seed
  113. set_all_random_seed(seed)
  114. text2punc = Text2PuncVADRealtime(train_config, model_file, device)
  115. def _forward(
  116. data_path_and_name_and_type,
  117. raw_inputs: Union[List[Any], bytes, str] = None,
  118. output_dir_v2: Optional[str] = None,
  119. cache: List[Any] = None,
  120. param_dict: dict = None,
  121. ):
  122. results = []
  123. split_size = 10
  124. cache_in = param_dict["cache"]
  125. if raw_inputs != None:
  126. line = raw_inputs.strip()
  127. key = "demo"
  128. if line == "":
  129. item = {'key': key, 'value': ""}
  130. results.append(item)
  131. return results
  132. result, _, cache = text2punc(line, cache_in)
  133. param_dict["cache"] = cache
  134. item = {'key': key, 'value': result}
  135. results.append(item)
  136. return results
  137. return results
  138. return _forward
  139. def inference_launch(mode, **kwargs):
  140. if mode == "punc":
  141. return inference_punc(**kwargs)
  142. if mode == "punc_VadRealtime":
  143. return inference_punc_vad_realtime(**kwargs)
  144. else:
  145. logging.info("Unknown decoding mode: {}".format(mode))
  146. return None
  147. def get_parser():
  148. parser = config_argparse.ArgumentParser(
  149. description="Punctuation inference",
  150. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  151. )
  152. parser.add_argument(
  153. "--log_level",
  154. type=lambda x: x.upper(),
  155. default="INFO",
  156. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  157. help="The verbose level of logging",
  158. )
  159. parser.add_argument("--output_dir", type=str, required=True)
  160. parser.add_argument("--gpuid_list", type=str, required=True)
  161. parser.add_argument(
  162. "--ngpu",
  163. type=int,
  164. default=0,
  165. help="The number of gpus. 0 indicates CPU mode",
  166. )
  167. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  168. parser.add_argument("--njob", type=int, default=1, help="Random seed")
  169. parser.add_argument(
  170. "--dtype",
  171. default="float32",
  172. choices=["float16", "float32", "float64"],
  173. help="Data type",
  174. )
  175. parser.add_argument(
  176. "--num_workers",
  177. type=int,
  178. default=1,
  179. help="The number of workers used for DataLoader",
  180. )
  181. parser.add_argument(
  182. "--batch_size",
  183. type=int,
  184. default=1,
  185. help="The batch size for inference",
  186. )
  187. group = parser.add_argument_group("Input data related")
  188. group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
  189. group.add_argument("--raw_inputs", type=str, required=False)
  190. group.add_argument("--key_file", type=str_or_none)
  191. group.add_argument("--cache", type=list, required=False)
  192. group.add_argument("--param_dict", type=dict, required=False)
  193. group = parser.add_argument_group("The model configuration related")
  194. group.add_argument("--train_config", type=str)
  195. group.add_argument("--model_file", type=str)
  196. group.add_argument("--mode", type=str, default="punc")
  197. return parser
  198. def main(cmd=None):
  199. print(get_commandline_args(), file=sys.stderr)
  200. parser = get_parser()
  201. args = parser.parse_args(cmd)
  202. kwargs = vars(args)
  203. kwargs.pop("config", None)
  204. # set logging messages
  205. logging.basicConfig(
  206. level=args.log_level,
  207. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  208. )
  209. logging.info("Decoding args: {}".format(kwargs))
  210. # gpu setting
  211. if args.ngpu > 0:
  212. jobid = int(args.output_dir.split(".")[-1])
  213. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  214. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  215. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  216. kwargs.pop("gpuid_list", None)
  217. kwargs.pop("njob", None)
  218. inference_pipeline = inference_launch(**kwargs)
  219. return inference_pipeline(kwargs["data_path_and_name_and_type"])
  220. if __name__ == "__main__":
  221. main()