# diar_inference_launch.py
  1. #!/usr/bin/env python3
  2. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  3. # MIT License (https://opensource.org/licenses/MIT)
  4. import torch
  5. torch.set_num_threads(1)
  6. import argparse
  7. import logging
  8. import os
  9. import sys
  10. from typing import Union, Dict, Any
  11. from funasr.utils import config_argparse
  12. from funasr.utils.cli_utils import get_commandline_args
  13. from funasr.utils.types import str2bool
  14. from funasr.utils.types import str2triple_str
  15. from funasr.utils.types import str_or_none
  16. def get_parser():
  17. parser = config_argparse.ArgumentParser(
  18. description="Speaker Verification",
  19. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  20. )
  21. # Note(kamo): Use '_' instead of '-' as separator.
  22. # '-' is confusing if written in yaml.
  23. parser.add_argument(
  24. "--log_level",
  25. type=lambda x: x.upper(),
  26. default="INFO",
  27. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  28. help="The verbose level of logging",
  29. )
  30. parser.add_argument("--output_dir", type=str, required=False)
  31. parser.add_argument(
  32. "--ngpu",
  33. type=int,
  34. default=0,
  35. help="The number of gpus. 0 indicates CPU mode",
  36. )
  37. parser.add_argument(
  38. "--njob",
  39. type=int,
  40. default=1,
  41. help="The number of jobs for each gpu",
  42. )
  43. parser.add_argument(
  44. "--gpuid_list",
  45. type=str,
  46. default="",
  47. help="The visible gpus",
  48. )
  49. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  50. parser.add_argument(
  51. "--dtype",
  52. default="float32",
  53. choices=["float16", "float32", "float64"],
  54. help="Data type",
  55. )
  56. parser.add_argument(
  57. "--num_workers",
  58. type=int,
  59. default=1,
  60. help="The number of workers used for DataLoader",
  61. )
  62. group = parser.add_argument_group("Input data related")
  63. group.add_argument(
  64. "--data_path_and_name_and_type",
  65. type=str2triple_str,
  66. required=False,
  67. action="append",
  68. )
  69. group.add_argument("--key_file", type=str_or_none)
  70. group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
  71. group = parser.add_argument_group("The model configuration related")
  72. group.add_argument(
  73. "--vad_infer_config",
  74. type=str,
  75. help="VAD infer configuration",
  76. )
  77. group.add_argument(
  78. "--vad_model_file",
  79. type=str,
  80. help="VAD model parameter file",
  81. )
  82. group.add_argument(
  83. "--diar_train_config",
  84. type=str,
  85. help="ASR training configuration",
  86. )
  87. group.add_argument(
  88. "--diar_model_file",
  89. type=str,
  90. help="ASR model parameter file",
  91. )
  92. group.add_argument(
  93. "--cmvn_file",
  94. type=str,
  95. help="Global CMVN file",
  96. )
  97. group.add_argument(
  98. "--model_tag",
  99. type=str,
  100. help="Pretrained model tag. If specify this option, *_train_config and "
  101. "*_file will be overwritten",
  102. )
  103. group = parser.add_argument_group("The inference configuration related")
  104. group.add_argument(
  105. "--batch_size",
  106. type=int,
  107. default=1,
  108. help="The batch size for inference",
  109. )
  110. group.add_argument(
  111. "--diar_smooth_size",
  112. type=int,
  113. default=121,
  114. help="The smoothing size for post-processing"
  115. )
  116. return parser
  117. def inference_launch(mode, **kwargs):
  118. if mode == "sond":
  119. from funasr.bin.sond_inference import inference_modelscope
  120. return inference_modelscope(mode=mode, **kwargs)
  121. elif mode == "sond_demo":
  122. from funasr.bin.sond_inference import inference_modelscope
  123. param_dict = {
  124. "extract_profile": True,
  125. "sv_train_config": "sv.yaml",
  126. "sv_model_file": "sv.pb",
  127. }
  128. if "param_dict" in kwargs and kwargs["param_dict"] is not None:
  129. for key in param_dict:
  130. if key not in kwargs["param_dict"]:
  131. kwargs["param_dict"][key] = param_dict[key]
  132. else:
  133. kwargs["param_dict"] = param_dict
  134. return inference_modelscope(mode=mode, **kwargs)
  135. elif mode == "eend-ola":
  136. from funasr.bin.eend_ola_inference import inference_modelscope
  137. return inference_modelscope(mode=mode, **kwargs)
  138. else:
  139. logging.info("Unknown decoding mode: {}".format(mode))
  140. return None
def main(cmd=None):
    """Command-line entry point for diarization inference.

    Parses arguments (optionally from *cmd* instead of sys.argv), configures
    logging, restricts the process to one GPU when requested, and launches
    the selected decoding pipeline.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="sond",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    # Drop the "config" entry contributed by config_argparse so it is not
    # forwarded as a keyword to inference_launch.
    kwargs.pop("config", None)
    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))
    # gpu setting
    # NOTE(review): this assumes output_dir ends with ".<jobid>" (1-based) and
    # that gpuid_list has enough comma-separated entries for jobid/njob —
    # confirm against the shell launch script that spawns these jobs.
    if args.ngpu > 0:
        jobid = int(args.output_dir.split(".")[-1])
        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
    inference_launch(**kwargs)


if __name__ == "__main__":
    main()