punc_inference_launch.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #!/usr/bin/env python3
  2. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. import argparse
  5. import logging
  6. import os
  7. import sys
  8. from typing import Union, Dict, Any
  9. from funasr.utils import config_argparse
  10. from funasr.utils.cli_utils import get_commandline_args
  11. from funasr.utils.types import str2bool
  12. from funasr.utils.types import str2triple_str
  13. from funasr.utils.types import str_or_none
  14. from funasr.utils.types import float_or_none
  15. def get_parser():
  16. parser = config_argparse.ArgumentParser(
  17. description="Punctuation inference",
  18. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  19. )
  20. parser.add_argument(
  21. "--log_level",
  22. type=lambda x: x.upper(),
  23. default="INFO",
  24. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  25. help="The verbose level of logging",
  26. )
  27. parser.add_argument("--output_dir", type=str, required=True)
  28. parser.add_argument("--gpuid_list", type=str, required=True)
  29. parser.add_argument(
  30. "--ngpu",
  31. type=int,
  32. default=0,
  33. help="The number of gpus. 0 indicates CPU mode",
  34. )
  35. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  36. parser.add_argument("--njob", type=int, default=1, help="Random seed")
  37. parser.add_argument(
  38. "--dtype",
  39. default="float32",
  40. choices=["float16", "float32", "float64"],
  41. help="Data type",
  42. )
  43. parser.add_argument(
  44. "--num_workers",
  45. type=int,
  46. default=1,
  47. help="The number of workers used for DataLoader",
  48. )
  49. parser.add_argument(
  50. "--batch_size",
  51. type=int,
  52. default=1,
  53. help="The batch size for inference",
  54. )
  55. group = parser.add_argument_group("Input data related")
  56. group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
  57. group.add_argument("--raw_inputs", type=str, required=False)
  58. group.add_argument("--key_file", type=str_or_none)
  59. group.add_argument("--cache", type=list, required=False)
  60. group.add_argument("--param_dict", type=dict, required=False)
  61. group = parser.add_argument_group("The model configuration related")
  62. group.add_argument("--train_config", type=str)
  63. group.add_argument("--model_file", type=str)
  64. group.add_argument("--mode", type=str, default="punc")
  65. return parser
  66. def inference_launch(mode, **kwargs):
  67. if mode == "punc":
  68. from funasr.bin.punctuation_infer import inference_modelscope
  69. return inference_modelscope(**kwargs)
  70. if mode == "punc_VadRealtime":
  71. from funasr.bin.punctuation_infer_vadrealtime import inference_modelscope
  72. return inference_modelscope(**kwargs)
  73. else:
  74. logging.info("Unknown decoding mode: {}".format(mode))
  75. return None
  76. def main(cmd=None):
  77. print(get_commandline_args(), file=sys.stderr)
  78. parser = get_parser()
  79. args = parser.parse_args(cmd)
  80. kwargs = vars(args)
  81. kwargs.pop("config", None)
  82. # set logging messages
  83. logging.basicConfig(
  84. level=args.log_level,
  85. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  86. )
  87. logging.info("Decoding args: {}".format(kwargs))
  88. # gpu setting
  89. if args.ngpu > 0:
  90. jobid = int(args.output_dir.split(".")[-1])
  91. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  92. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  93. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  94. kwargs.pop("gpuid_list", None)
  95. kwargs.pop("njob", None)
  96. results = inference_launch(**kwargs)
  97. if __name__ == "__main__":
  98. main()