# tp_inference_launch.py — launcher for timestamp prediction inference (FunASR)
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import os
  5. import sys
  6. from typing import Union, Dict, Any
  7. from funasr.utils import config_argparse
  8. from funasr.utils.cli_utils import get_commandline_args
  9. from funasr.utils.types import str2bool
  10. from funasr.utils.types import str2triple_str
  11. from funasr.utils.types import str_or_none
  12. def get_parser():
  13. parser = config_argparse.ArgumentParser(
  14. description="Timestamp Prediction Inference",
  15. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  16. )
  17. # Note(kamo): Use '_' instead of '-' as separator.
  18. # '-' is confusing if written in yaml.
  19. parser.add_argument(
  20. "--log_level",
  21. type=lambda x: x.upper(),
  22. default="INFO",
  23. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  24. help="The verbose level of logging",
  25. )
  26. parser.add_argument("--output_dir", type=str, required=False)
  27. parser.add_argument(
  28. "--ngpu",
  29. type=int,
  30. default=0,
  31. help="The number of gpus. 0 indicates CPU mode",
  32. )
  33. parser.add_argument(
  34. "--njob",
  35. type=int,
  36. default=1,
  37. help="The number of jobs for each gpu",
  38. )
  39. parser.add_argument(
  40. "--gpuid_list",
  41. type=str,
  42. default="",
  43. help="The visible gpus",
  44. )
  45. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  46. parser.add_argument(
  47. "--dtype",
  48. default="float32",
  49. choices=["float16", "float32", "float64"],
  50. help="Data type",
  51. )
  52. parser.add_argument(
  53. "--num_workers",
  54. type=int,
  55. default=1,
  56. help="The number of workers used for DataLoader",
  57. )
  58. group = parser.add_argument_group("Input data related")
  59. group.add_argument(
  60. "--data_path_and_name_and_type",
  61. type=str2triple_str,
  62. required=True,
  63. action="append",
  64. )
  65. group.add_argument("--key_file", type=str_or_none)
  66. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  67. group = parser.add_argument_group("The model configuration related")
  68. group.add_argument(
  69. "--timestamp_infer_config",
  70. type=str,
  71. help="VAD infer configuration",
  72. )
  73. group.add_argument(
  74. "--timestamp_model_file",
  75. type=str,
  76. help="VAD model parameter file",
  77. )
  78. group.add_argument(
  79. "--timestamp_cmvn_file",
  80. type=str,
  81. help="Global CMVN file",
  82. )
  83. group = parser.add_argument_group("The inference configuration related")
  84. group.add_argument(
  85. "--batch_size",
  86. type=int,
  87. default=1,
  88. help="The batch size for inference",
  89. )
  90. return parser
  91. def inference_launch(mode, **kwargs):
  92. if mode == "tp_norm":
  93. from funasr.bin.tp_inference import inference_modelscope
  94. return inference_modelscope(**kwargs)
  95. else:
  96. logging.info("Unknown decoding mode: {}".format(mode))
  97. return None
  98. def main(cmd=None):
  99. print(get_commandline_args(), file=sys.stderr)
  100. parser = get_parser()
  101. parser.add_argument(
  102. "--mode",
  103. type=str,
  104. default="tp_norm",
  105. help="The decoding mode",
  106. )
  107. args = parser.parse_args(cmd)
  108. kwargs = vars(args)
  109. kwargs.pop("config", None)
  110. # set logging messages
  111. logging.basicConfig(
  112. level=args.log_level,
  113. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  114. )
  115. logging.info("Decoding args: {}".format(kwargs))
  116. # gpu setting
  117. if args.ngpu > 0:
  118. jobid = int(args.output_dir.split(".")[-1])
  119. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  120. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  121. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  122. inference_launch(**kwargs)
  123. if __name__ == "__main__":
  124. main()