# server.py — FunASR HTTP speech-recognition server
  1. import argparse
  2. import logging
  3. import os
  4. import random
  5. import time
  6. import aiofiles
  7. import ffmpeg
  8. import uvicorn
  9. from fastapi import FastAPI, File, UploadFile, Body
  10. from modelscope.pipelines import pipeline
  11. from modelscope.utils.constant import Tasks
  12. from modelscope.utils.logger import get_logger
  13. logger = get_logger(log_level=logging.CRITICAL)
  14. logger.setLevel(logging.CRITICAL)
  15. parser = argparse.ArgumentParser()
  16. parser.add_argument("--host",
  17. type=str,
  18. default="0.0.0.0",
  19. required=False,
  20. help="host ip, localhost, 0.0.0.0")
  21. parser.add_argument("--port",
  22. type=int,
  23. default=8000,
  24. required=False,
  25. help="server port")
  26. parser.add_argument("--asr_model",
  27. type=str,
  28. default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
  29. help="model from modelscope")
  30. parser.add_argument("--punc_model",
  31. type=str,
  32. default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
  33. help="model from modelscope")
  34. parser.add_argument("--ngpu",
  35. type=int,
  36. default=1,
  37. help="0 for cpu, 1 for gpu")
  38. parser.add_argument("--ncpu",
  39. type=int,
  40. default=4,
  41. help="cpu cores")
  42. parser.add_argument("--certfile",
  43. type=str,
  44. default=None,
  45. required=False,
  46. help="certfile for ssl")
  47. parser.add_argument("--keyfile",
  48. type=str,
  49. default=None,
  50. required=False,
  51. help="keyfile for ssl")
  52. parser.add_argument("--temp_dir",
  53. type=str,
  54. default="temp_dir/",
  55. required=False,
  56. help="temp dir")
  57. args = parser.parse_args()
  58. os.makedirs(args.temp_dir, exist_ok=True)
  59. print("model loading")
  60. # asr
  61. inference_pipeline_asr = pipeline(task=Tasks.auto_speech_recognition,
  62. model=args.asr_model,
  63. ngpu=args.ngpu,
  64. ncpu=args.ncpu,
  65. model_revision=None)
  66. print(f'loaded asr models.')
  67. if args.punc_model != "":
  68. inference_pipeline_punc = pipeline(task=Tasks.punctuation,
  69. model=args.punc_model,
  70. model_revision="v1.0.2",
  71. ngpu=args.ngpu,
  72. ncpu=args.ncpu)
  73. print(f'loaded pun models.')
  74. else:
  75. inference_pipeline_punc = None
  76. app = FastAPI(title="FunASR")
  77. @app.post("/recognition")
  78. async def api_recognition(audio: UploadFile = File(..., description="audio file"),
  79. add_pun: int = Body(1, description="add punctuation", embed=True)):
  80. suffix = audio.filename.split('.')[-1]
  81. audio_path = f'{args.temp_dir}/{int(time.time() * 1000)}_{random.randint(100, 999)}.{suffix}'
  82. async with aiofiles.open(audio_path, 'wb') as out_file:
  83. content = await audio.read()
  84. await out_file.write(content)
  85. audio_bytes, _ = (
  86. ffmpeg.input(audio_path, threads=0)
  87. .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
  88. .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
  89. )
  90. rec_result = inference_pipeline_asr(audio_in=audio_bytes, param_dict={})
  91. if add_pun:
  92. rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict={'cache': list()})
  93. ret = {"results": rec_result['text'], "code": 0}
  94. return ret
  95. if __name__ == '__main__':
  96. uvicorn.run(app, host=args.host, port=args.port, ssl_keyfile=args.keyfile, ssl_certfile=args.certfile)