# server.py (5.7 KB)

  1. import argparse
  2. import logging
  3. import os
  4. import uuid
  5. import aiofiles
  6. import ffmpeg
  7. import uvicorn
  8. from fastapi import FastAPI, File, UploadFile
  9. from modelscope.utils.logger import get_logger
  10. from funasr import AutoModel
  11. logger = get_logger(log_level=logging.INFO)
  12. logger.setLevel(logging.INFO)
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument("--host",
  15. type=str,
  16. default="0.0.0.0",
  17. required=False,
  18. help="host ip, localhost, 0.0.0.0")
  19. parser.add_argument("--port",
  20. type=int,
  21. default=8000,
  22. required=False,
  23. help="server port")
  24. parser.add_argument("--asr_model",
  25. type=str,
  26. default="paraformer-zh",
  27. help="asr model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
  28. parser.add_argument("--asr_model_revision",
  29. type=str,
  30. default="v2.0.4",
  31. help="")
  32. parser.add_argument("--vad_model",
  33. type=str,
  34. default="fsmn-vad",
  35. help="vad model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
  36. parser.add_argument("--vad_model_revision",
  37. type=str,
  38. default="v2.0.4",
  39. help="")
  40. parser.add_argument("--punc_model",
  41. type=str,
  42. default="ct-punc-c",
  43. help="model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
  44. parser.add_argument("--punc_model_revision",
  45. type=str,
  46. default="v2.0.4",
  47. help="")
  48. parser.add_argument("--ngpu",
  49. type=int,
  50. default=1,
  51. help="0 for cpu, 1 for gpu")
  52. parser.add_argument("--device",
  53. type=str,
  54. default="cuda",
  55. help="cuda, cpu")
  56. parser.add_argument("--ncpu",
  57. type=int,
  58. default=4,
  59. help="cpu cores")
  60. parser.add_argument("--hotword_path",
  61. type=str,
  62. default='hotwords.txt',
  63. help="hot word txt path, only the hot word model works")
  64. parser.add_argument("--certfile",
  65. type=str,
  66. default=None,
  67. required=False,
  68. help="certfile for ssl")
  69. parser.add_argument("--keyfile",
  70. type=str,
  71. default=None,
  72. required=False,
  73. help="keyfile for ssl")
  74. parser.add_argument("--temp_dir",
  75. type=str,
  76. default="temp_dir/",
  77. required=False,
  78. help="temp dir")
  79. args = parser.parse_args()
  80. logger.info("----------- Configuration Arguments -----------")
  81. for arg, value in vars(args).items():
  82. logger.info("%s: %s" % (arg, value))
  83. logger.info("------------------------------------------------")
  84. os.makedirs(args.temp_dir, exist_ok=True)
  85. logger.info("model loading")
  86. # load funasr model
  87. model = AutoModel(model=args.asr_model,
  88. model_revision=args.asr_model_revision,
  89. vad_model=args.vad_model,
  90. vad_model_revision=args.vad_model_revision,
  91. punc_model=args.punc_model,
  92. punc_model_revision=args.punc_model_revision,
  93. ngpu=args.ngpu,
  94. ncpu=args.ncpu,
  95. device=args.device,
  96. disable_pbar=True,
  97. disable_log=True)
  98. logger.info("loaded models!")
  99. app = FastAPI(title="FunASR")
  100. param_dict = {"sentence_timestamp": True, "batch_size_s": 300}
  101. if args.hotword_path is not None and os.path.exists(args.hotword_path):
  102. with open(args.hotword_path, 'r', encoding='utf-8') as f:
  103. lines = f.readlines()
  104. lines = [line.strip() for line in lines]
  105. hotword = ' '.join(lines)
  106. logger.info(f'热词:{hotword}')
  107. param_dict['hotword'] = hotword
  108. @app.post("/recognition")
  109. async def api_recognition(audio: UploadFile = File(..., description="audio file")):
  110. suffix = audio.filename.split('.')[-1]
  111. audio_path = f'{args.temp_dir}/{str(uuid.uuid1())}.{suffix}'
  112. async with aiofiles.open(audio_path, 'wb') as out_file:
  113. content = await audio.read()
  114. await out_file.write(content)
  115. try:
  116. audio_bytes, _ = (
  117. ffmpeg.input(audio_path, threads=0)
  118. .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
  119. .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
  120. )
  121. except Exception as e:
  122. logger.error(f'读取音频文件发生错误,错误信息:{e}')
  123. return {"msg": "读取音频文件发生错误", "code": 1}
  124. rec_results = model.generate(input=audio_bytes, is_final=True, **param_dict)
  125. # 结果为空
  126. if len(rec_results) == 0:
  127. return {"text": "", "sentences": [], "code": 0}
  128. elif len(rec_results) == 1:
  129. # 解析识别结果
  130. rec_result = rec_results[0]
  131. text = rec_result['text']
  132. sentences = []
  133. for sentence in rec_result['sentence_info']:
  134. # 每句话的时间戳
  135. sentences.append({'text': sentence['text'], 'start': sentence['start'], 'end': sentence['start']})
  136. ret = {"text": text, "sentences": sentences, "code": 0}
  137. logger.info(f'识别结果:{ret}')
  138. return ret
  139. else:
  140. logger.info(f'识别结果:{rec_results}')
  141. return {"msg": "未知错误", "code": -1}
  142. if __name__ == '__main__':
  143. uvicorn.run(app, host=args.host, port=args.port, ssl_keyfile=args.keyfile, ssl_certfile=args.certfile)