# server.py

import argparse
import contextlib
import logging
import os
import re
import uuid

import aiofiles
import ffmpeg
import uvicorn
from fastapi import FastAPI, File, UploadFile, Body
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
  12. logger = get_logger(log_level=logging.CRITICAL)
  13. logger.setLevel(logging.CRITICAL)
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("--host",
  16. type=str,
  17. default="0.0.0.0",
  18. required=False,
  19. help="host ip, localhost, 0.0.0.0")
  20. parser.add_argument("--port",
  21. type=int,
  22. default=8000,
  23. required=False,
  24. help="server port")
  25. parser.add_argument("--asr_model",
  26. type=str,
  27. default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
  28. help="offline asr model from modelscope")
  29. parser.add_argument("--vad_model",
  30. type=str,
  31. default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
  32. help="vad model from modelscope")
  33. parser.add_argument("--punc_model",
  34. type=str,
  35. default="damo/punc_ct-transformer_cn-en-common-vocab471067-large",
  36. help="punc model from modelscope")
  37. parser.add_argument("--ngpu",
  38. type=int,
  39. default=1,
  40. help="0 for cpu, 1 for gpu")
  41. parser.add_argument("--ncpu",
  42. type=int,
  43. default=4,
  44. help="cpu cores")
  45. parser.add_argument("--hotword_path",
  46. type=str,
  47. default=None,
  48. help="hot word txt path, only the hot word model works")
  49. parser.add_argument("--certfile",
  50. type=str,
  51. default=None,
  52. required=False,
  53. help="certfile for ssl")
  54. parser.add_argument("--keyfile",
  55. type=str,
  56. default=None,
  57. required=False,
  58. help="keyfile for ssl")
  59. parser.add_argument("--temp_dir",
  60. type=str,
  61. default="temp_dir/",
  62. required=False,
  63. help="temp dir")
  64. args = parser.parse_args()
  65. print("----------- Configuration Arguments -----------")
  66. for arg, value in vars(args).items():
  67. print("%s: %s" % (arg, value))
  68. print("------------------------------------------------")
  69. os.makedirs(args.temp_dir, exist_ok=True)
  70. print("model loading")
  71. param_dict = {}
  72. if args.hotword_path is not None and os.path.exists(args.hotword_path):
  73. param_dict['hotword'] = args.hotword_path
  74. # asr
  75. inference_pipeline_asr = pipeline(task=Tasks.auto_speech_recognition,
  76. model=args.asr_model,
  77. vad_model=args.vad_model,
  78. ngpu=args.ngpu,
  79. ncpu=args.ncpu,
  80. param_dict=param_dict)
  81. print(f'loaded asr models.')
  82. if args.punc_model != "":
  83. inference_pipeline_punc = pipeline(task=Tasks.punctuation,
  84. model=args.punc_model,
  85. ngpu=args.ngpu,
  86. ncpu=args.ncpu)
  87. print(f'loaded pun models.')
  88. else:
  89. inference_pipeline_punc = None
  90. app = FastAPI(title="FunASR")
  91. @app.post("/recognition")
  92. async def api_recognition(audio: UploadFile = File(..., description="audio file"),
  93. add_pun: int = Body(1, description="add punctuation", embed=True)):
  94. suffix = audio.filename.split('.')[-1]
  95. audio_path = f'{args.temp_dir}/{str(uuid.uuid1())}.{suffix}'
  96. async with aiofiles.open(audio_path, 'wb') as out_file:
  97. content = await audio.read()
  98. await out_file.write(content)
  99. audio_bytes, _ = (
  100. ffmpeg.input(audio_path, threads=0)
  101. .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
  102. .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
  103. )
  104. rec_result = inference_pipeline_asr(audio_in=audio_bytes, param_dict={})
  105. if add_pun:
  106. rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict={'cache': list()})
  107. ret = {"results": rec_result['text'], "code": 0}
  108. print(ret)
  109. return ret
  110. if __name__ == '__main__':
  111. uvicorn.run(app, host=args.host, port=args.port, ssl_keyfile=args.keyfile, ssl_certfile=args.certfile)