# ASR_server.py — streaming speech-recognition websocket server:
# ModelScope online VAD segments incoming audio, Paraformer ASR transcribes
# each segment, and an optional punctuation model post-processes the text.
  1. import asyncio
  2. import websockets
  3. import time
  4. from queue import Queue
  5. import threading
  6. import argparse
  7. from modelscope.pipelines import pipeline
  8. from modelscope.utils.constant import Tasks
  9. from modelscope.utils.logger import get_logger
  10. import logging
  11. import tracemalloc
  12. tracemalloc.start()
  13. logger = get_logger(log_level=logging.CRITICAL)
  14. logger.setLevel(logging.CRITICAL)
  15. websocket_users = set() #维护客户端列表
  16. parser = argparse.ArgumentParser()
  17. parser.add_argument("--host",
  18. type=str,
  19. default="0.0.0.0",
  20. required=False,
  21. help="host ip, localhost, 0.0.0.0")
  22. parser.add_argument("--port",
  23. type=int,
  24. default=10095,
  25. required=False,
  26. help="grpc server port")
  27. parser.add_argument("--asr_model",
  28. type=str,
  29. default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
  30. help="model from modelscope")
  31. parser.add_argument("--vad_model",
  32. type=str,
  33. default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
  34. help="model from modelscope")
  35. parser.add_argument("--punc_model",
  36. type=str,
  37. default="",
  38. help="model from modelscope")
  39. parser.add_argument("--ngpu",
  40. type=int,
  41. default=1,
  42. help="0 for cpu, 1 for gpu")
  43. args = parser.parse_args()
  44. print("model loading")
  45. # vad
  46. inference_pipeline_vad = pipeline(
  47. task=Tasks.voice_activity_detection,
  48. model=args.vad_model,
  49. model_revision=None,
  50. output_dir=None,
  51. batch_size=1,
  52. mode='online',
  53. ngpu=args.ngpu,
  54. )
  55. # param_dict_vad = {'in_cache': dict(), "is_final": False}
  56. # asr
  57. param_dict_asr = {}
  58. # param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开
  59. inference_pipeline_asr = pipeline(
  60. task=Tasks.auto_speech_recognition,
  61. model=args.asr_model,
  62. param_dict=param_dict_asr,
  63. ngpu=args.ngpu,
  64. )
  65. if args.punc_model != "":
  66. # param_dict_punc = {'cache': list()}
  67. inference_pipeline_punc = pipeline(
  68. task=Tasks.punctuation,
  69. model=args.punc_model,
  70. model_revision=None,
  71. ngpu=args.ngpu,
  72. )
  73. else:
  74. inference_pipeline_punc = None
  75. print("model loaded")
  76. async def ws_serve(websocket, path):
  77. #speek = Queue()
  78. frames = [] # 存储所有的帧数据
  79. buffer = [] # 存储缓存中的帧数据(最多两个片段)
  80. RECORD_NUM = 0
  81. global websocket_users
  82. speech_start, speech_end = False, False
  83. # 调用asr函数
  84. websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
  85. websocket.param_dict_punc = {'cache': list()}
  86. websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包
  87. websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端
  88. websocket_users.add(websocket)
  89. ss = threading.Thread(target=asr, args=(websocket,))
  90. ss.start()
  91. try:
  92. async for message in websocket:
  93. #voices.put(message)
  94. #print("put")
  95. #await websocket.send("123")
  96. buffer.append(message)
  97. if len(buffer) > 2:
  98. buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个
  99. if speech_start:
  100. frames.append(message)
  101. RECORD_NUM += 1
  102. speech_start_i, speech_end_i = vad(message, websocket)
  103. #print(speech_start_i, speech_end_i)
  104. if speech_start_i:
  105. speech_start = speech_start_i
  106. frames = []
  107. frames.extend(buffer) # 把之前2个语音数据快加入
  108. if speech_end_i or RECORD_NUM > 300:
  109. speech_start = False
  110. audio_in = b"".join(frames)
  111. websocket.speek.put(audio_in)
  112. frames = [] # 清空所有的帧数据
  113. buffer = [] # 清空缓存中的帧数据(最多两个片段)
  114. RECORD_NUM = 0
  115. if not websocket.send_msg.empty():
  116. await websocket.send(websocket.send_msg.get())
  117. websocket.send_msg.task_done()
  118. except websockets.ConnectionClosed:
  119. print("ConnectionClosed...", websocket_users) # 链接断开
  120. websocket_users.remove(websocket)
  121. except websockets.InvalidState:
  122. print("InvalidState...") # 无效状态
  123. except Exception as e:
  124. print("Exception:", e)
  125. def asr(websocket): # ASR推理
  126. global inference_pipeline_asr, inference_pipeline_punc
  127. # global param_dict_punc
  128. global websocket_users
  129. while websocket in websocket_users:
  130. if not websocket.speek.empty():
  131. audio_in = websocket.speek.get()
  132. websocket.speek.task_done()
  133. if len(audio_in) > 0:
  134. rec_result = inference_pipeline_asr(audio_in=audio_in)
  135. if inference_pipeline_punc is not None and 'text' in rec_result:
  136. rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc)
  137. # print(rec_result)
  138. if "text" in rec_result:
  139. websocket.send_msg.put(rec_result["text"]) # 存入发送队列 直接调用send发送不了
  140. time.sleep(0.1)
  141. def vad(data, websocket): # VAD推理
  142. global inference_pipeline_vad
  143. #print(type(data))
  144. # print(param_dict_vad)
  145. segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad)
  146. # print(segments_result)
  147. # print(param_dict_vad)
  148. speech_start = False
  149. speech_end = False
  150. if len(segments_result) == 0 or len(segments_result["text"]) > 1:
  151. return speech_start, speech_end
  152. if segments_result["text"][0][0] != -1:
  153. speech_start = True
  154. if segments_result["text"][0][1] != -1:
  155. speech_end = True
  156. return speech_start, speech_end
  157. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  158. asyncio.get_event_loop().run_until_complete(start_server)
  159. asyncio.get_event_loop().run_forever()