# diar_inference_launch.py
  1. #!/usr/bin/env python3
  2. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  3. # MIT License (https://opensource.org/licenses/MIT)
  4. import argparse
  5. import logging
  6. import os
  7. import sys
  8. from typing import Union, Dict, Any
  9. from funasr.utils import config_argparse
  10. from funasr.utils.cli_utils import get_commandline_args
  11. from funasr.utils.types import str2bool
  12. from funasr.utils.types import str2triple_str
  13. from funasr.utils.types import str_or_none
  14. def get_parser():
  15. parser = config_argparse.ArgumentParser(
  16. description="Speaker Verification",
  17. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  18. )
  19. # Note(kamo): Use '_' instead of '-' as separator.
  20. # '-' is confusing if written in yaml.
  21. parser.add_argument(
  22. "--log_level",
  23. type=lambda x: x.upper(),
  24. default="INFO",
  25. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  26. help="The verbose level of logging",
  27. )
  28. parser.add_argument("--output_dir", type=str, required=False)
  29. parser.add_argument(
  30. "--ngpu",
  31. type=int,
  32. default=0,
  33. help="The number of gpus. 0 indicates CPU mode",
  34. )
  35. parser.add_argument(
  36. "--njob",
  37. type=int,
  38. default=1,
  39. help="The number of jobs for each gpu",
  40. )
  41. parser.add_argument(
  42. "--gpuid_list",
  43. type=str,
  44. default="",
  45. help="The visible gpus",
  46. )
  47. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  48. parser.add_argument(
  49. "--dtype",
  50. default="float32",
  51. choices=["float16", "float32", "float64"],
  52. help="Data type",
  53. )
  54. parser.add_argument(
  55. "--num_workers",
  56. type=int,
  57. default=1,
  58. help="The number of workers used for DataLoader",
  59. )
  60. group = parser.add_argument_group("Input data related")
  61. group.add_argument(
  62. "--data_path_and_name_and_type",
  63. type=str2triple_str,
  64. required=False,
  65. action="append",
  66. )
  67. group.add_argument("--key_file", type=str_or_none)
  68. group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
  69. group = parser.add_argument_group("The model configuration related")
  70. group.add_argument(
  71. "--vad_infer_config",
  72. type=str,
  73. help="VAD infer configuration",
  74. )
  75. group.add_argument(
  76. "--vad_model_file",
  77. type=str,
  78. help="VAD model parameter file",
  79. )
  80. group.add_argument(
  81. "--diar_train_config",
  82. type=str,
  83. help="ASR training configuration",
  84. )
  85. group.add_argument(
  86. "--diar_model_file",
  87. type=str,
  88. help="ASR model parameter file",
  89. )
  90. group.add_argument(
  91. "--cmvn_file",
  92. type=str,
  93. help="Global CMVN file",
  94. )
  95. group.add_argument(
  96. "--model_tag",
  97. type=str,
  98. help="Pretrained model tag. If specify this option, *_train_config and "
  99. "*_file will be overwritten",
  100. )
  101. group = parser.add_argument_group("The inference configuration related")
  102. group.add_argument(
  103. "--batch_size",
  104. type=int,
  105. default=1,
  106. help="The batch size for inference",
  107. )
  108. group.add_argument(
  109. "--diar_smooth_size",
  110. type=int,
  111. default=121,
  112. help="The smoothing size for post-processing"
  113. )
  114. return parser
  115. def inference_launch(mode, **kwargs):
  116. if mode == "sond":
  117. from funasr.bin.sond_inference import inference_modelscope
  118. return inference_modelscope(mode=mode, **kwargs)
  119. elif mode == "sond_demo":
  120. from funasr.bin.sond_inference import inference_modelscope
  121. param_dict = {
  122. "extract_profile": True,
  123. "sv_train_config": "sv.yaml",
  124. "sv_model_file": "sv.pb",
  125. }
  126. if "param_dict" in kwargs and kwargs["param_dict"] is not None:
  127. for key in param_dict:
  128. if key not in kwargs["param_dict"]:
  129. kwargs["param_dict"][key] = param_dict[key]
  130. else:
  131. kwargs["param_dict"] = param_dict
  132. return inference_modelscope(mode=mode, **kwargs)
  133. elif mode == "eend-ola":
  134. from funasr.bin.eend_ola_inference import inference_modelscope
  135. return inference_modelscope(mode=mode, **kwargs)
  136. else:
  137. logging.info("Unknown decoding mode: {}".format(mode))
  138. return None
  139. def main(cmd=None):
  140. print(get_commandline_args(), file=sys.stderr)
  141. parser = get_parser()
  142. parser.add_argument(
  143. "--mode",
  144. type=str,
  145. default="sond",
  146. help="The decoding mode",
  147. )
  148. args = parser.parse_args(cmd)
  149. kwargs = vars(args)
  150. kwargs.pop("config", None)
  151. # set logging messages
  152. logging.basicConfig(
  153. level=args.log_level,
  154. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  155. )
  156. logging.info("Decoding args: {}".format(kwargs))
  157. # gpu setting
  158. if args.ngpu > 0:
  159. jobid = int(args.output_dir.split(".")[-1])
  160. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  161. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  162. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  163. inference_launch(**kwargs)
  164. if __name__ == "__main__":
  165. main()