inference_cli.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import os
  6. import logging
  7. import torch
  8. import numpy as np
  9. from funasr.utils.download_and_prepare_model import prepare_model
  10. from funasr.utils.types import str2bool
  11. def infer(task_name: str = "asr",
  12. model: str = None,
  13. # mode: str = None,
  14. vad_model: str = None,
  15. disable_vad: bool = False,
  16. punc_model: str = None,
  17. disable_punc: bool = False,
  18. model_hub: str = "ms",
  19. cache_dir: str = None,
  20. **kwargs,
  21. ):
  22. # set logging messages
  23. logging.basicConfig(
  24. level=logging.ERROR,
  25. )
  26. model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
  27. if task_name == "asr":
  28. from funasr.bin.asr_inference_launch import inference_launch
  29. inference_pipeline = inference_launch(**kwargs)
  30. elif task_name == "":
  31. pipeline = 1
  32. elif task_name == "":
  33. pipeline = 2
  34. elif task_name == "":
  35. pipeline = 2
  36. def _infer_fn(input, **kwargs):
  37. data_type = kwargs.get('data_type', 'sound')
  38. data_path_and_name_and_type = [input, 'speech', data_type]
  39. raw_inputs = None
  40. if isinstance(input, torch.Tensor):
  41. input = input.numpy()
  42. if isinstance(input, np.ndarray):
  43. data_path_and_name_and_type = None
  44. raw_inputs = input
  45. return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
  46. return _infer_fn
  47. def main(cmd=None):
  48. # print(get_commandline_args(), file=sys.stderr)
  49. from funasr.bin.argument import get_parser
  50. parser = get_parser()
  51. parser.add_argument('input', help='input file to transcribe')
  52. parser.add_argument(
  53. "--task_name",
  54. type=str,
  55. default="asr",
  56. help="The decoding mode",
  57. )
  58. parser.add_argument(
  59. "-m",
  60. "--model",
  61. type=str,
  62. default="paraformer-zh",
  63. help="The asr mode name",
  64. )
  65. parser.add_argument(
  66. "-v",
  67. "--vad_model",
  68. type=str,
  69. default="fsmn-vad",
  70. help="vad model name",
  71. )
  72. parser.add_argument(
  73. "-dv",
  74. "--disable_vad",
  75. type=str2bool,
  76. default=False,
  77. help="",
  78. )
  79. parser.add_argument(
  80. "-p",
  81. "--punc_model",
  82. type=str,
  83. default="ct-punc",
  84. help="",
  85. )
  86. parser.add_argument(
  87. "-dp",
  88. "--disable_punc",
  89. type=str2bool,
  90. default=False,
  91. help="",
  92. )
  93. parser.add_argument(
  94. "--batch_size_token",
  95. type=int,
  96. default=5000,
  97. help="",
  98. )
  99. parser.add_argument(
  100. "--batch_size_token_threshold_s",
  101. type=int,
  102. default=35,
  103. help="",
  104. )
  105. parser.add_argument(
  106. "--max_single_segment_time",
  107. type=int,
  108. default=5000,
  109. help="",
  110. )
  111. args = parser.parse_args(cmd)
  112. kwargs = vars(args)
  113. # set logging messages
  114. logging.basicConfig(
  115. level=logging.ERROR,
  116. )
  117. logging.info("Decoding args: {}".format(kwargs))
  118. # kwargs["ncpu"] = 2 #os.cpu_count()
  119. kwargs.pop("data_path_and_name_and_type")
  120. print("args: {}".format(kwargs))
  121. p = infer(**kwargs)
  122. res = p(**kwargs)
  123. print(res)