# funasr_wss_server.py
  1. import asyncio
  2. import json
  3. import websockets
  4. import time
  5. import logging
  6. import tracemalloc
  7. import numpy as np
  8. import argparse
  9. import ssl
  10. from modelscope.pipelines import pipeline
  11. from modelscope.utils.constant import Tasks
  12. from modelscope.utils.logger import get_logger
# Enable allocation tracing so memory usage can be inspected while serving.
tracemalloc.start()

# Silence modelscope's internal logging; only CRITICAL messages get through.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

# Command-line configuration for the websocket ASR server.
parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
                    default="0.0.0.0",
                    required=False,
                    help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port",
                    type=int,
                    default=10095,
                    required=False,
                    help="grpc server port")
parser.add_argument("--asr_model",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    help="model from modelscope")
parser.add_argument("--asr_model_online",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                    help="model from modelscope")
parser.add_argument("--vad_model",
                    type=str,
                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                    help="model from modelscope")
parser.add_argument("--punc_model",
                    type=str,
                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                    help="model from modelscope")
parser.add_argument("--ngpu",
                    type=int,
                    default=1,
                    help="0 for cpu, 1 for gpu")
parser.add_argument("--ncpu",
                    type=int,
                    default=4,
                    help="cpu cores")
parser.add_argument("--certfile",
                    type=str,
                    default="../../ssl_key/server.crt",
                    required=False,
                    help="certfile for ssl")
parser.add_argument("--keyfile",
                    type=str,
                    default="../../ssl_key/server.key",
                    required=False,
                    help="keyfile for ssl")
args = parser.parse_args()

# Registry of currently-connected clients; the handler below disconnects any
# existing client before accepting a new one (single-client server).
websocket_users = set()

print("model loading")
# asr: offline (non-streaming) model used for the final recognition pass.
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision=None)
# vad: streaming voice-activity detection (mode='online').
inference_pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    model_revision=None,
    mode='online',
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)
# punc: optional punctuation restoration; disabled by passing --punc_model "".
if args.punc_model != "":
    inference_pipeline_punc = pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
        model_revision="v1.0.2",
        ngpu=args.ngpu,
        ncpu=args.ncpu,
    )
else:
    inference_pipeline_punc = None
# Streaming ASR model producing online partial results between offline passes.
inference_pipeline_asr_online = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model_online,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision='v1.0.7',
    update_model='v1.0.7',
    mode='paraformer_streaming')
print("model loaded! only support one client at the same time now!!!!")
  99. async def ws_reset(websocket):
  100. print("ws reset now, total num is ",len(websocket_users))
  101. websocket.param_dict_asr_online = {"cache": dict()}
  102. websocket.param_dict_vad = {'in_cache': dict(), "is_final": True}
  103. websocket.param_dict_asr_online["is_final"]=True
  104. # audio_in=b''.join(np.zeros(int(16000),dtype=np.int16))
  105. # inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
  106. # inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online)
  107. await websocket.close()
  108. async def clear_websocket():
  109. for websocket in websocket_users:
  110. await ws_reset(websocket)
  111. websocket_users.clear()
  112. async def ws_serve(websocket, path):
  113. frames = []
  114. frames_asr = []
  115. frames_asr_online = []
  116. global websocket_users
  117. await clear_websocket()
  118. websocket_users.add(websocket)
  119. websocket.param_dict_asr = {}
  120. websocket.param_dict_asr_online = {"cache": dict()}
  121. websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
  122. websocket.param_dict_punc = {'cache': list()}
  123. websocket.vad_pre_idx = 0
  124. speech_start = False
  125. speech_end_i = -1
  126. websocket.wav_name = "microphone"
  127. websocket.mode = "2pass"
  128. print("new user connected", flush=True)
  129. try:
  130. async for message in websocket:
  131. if isinstance(message, str):
  132. messagejson = json.loads(message)
  133. if "is_speaking" in messagejson:
  134. websocket.is_speaking = messagejson["is_speaking"]
  135. websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
  136. if "chunk_interval" in messagejson:
  137. websocket.chunk_interval = messagejson["chunk_interval"]
  138. if "wav_name" in messagejson:
  139. websocket.wav_name = messagejson.get("wav_name")
  140. if "chunk_size" in messagejson:
  141. websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
  142. if "encoder_chunk_look_back" in messagejson:
  143. websocket.param_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
  144. if "decoder_chunk_look_back" in messagejson:
  145. websocket.param_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
  146. if "mode" in messagejson:
  147. websocket.mode = messagejson["mode"]
  148. if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
  149. if not isinstance(message, str):
  150. frames.append(message)
  151. duration_ms = len(message)//32
  152. websocket.vad_pre_idx += duration_ms
  153. # asr online
  154. frames_asr_online.append(message)
  155. websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
  156. if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
  157. if websocket.mode == "2pass" or websocket.mode == "online":
  158. audio_in = b"".join(frames_asr_online)
  159. await async_asr_online(websocket, audio_in)
  160. frames_asr_online = []
  161. if speech_start:
  162. frames_asr.append(message)
  163. # vad online
  164. speech_start_i, speech_end_i = await async_vad(websocket, message)
  165. if speech_start_i != -1:
  166. speech_start = True
  167. beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
  168. frames_pre = frames[-beg_bias:]
  169. frames_asr = []
  170. frames_asr.extend(frames_pre)
  171. # asr punc offline
  172. if speech_end_i != -1 or not websocket.is_speaking:
  173. # print("vad end point")
  174. if websocket.mode == "2pass" or websocket.mode == "offline":
  175. audio_in = b"".join(frames_asr)
  176. await async_asr(websocket, audio_in)
  177. frames_asr = []
  178. speech_start = False
  179. # frames_asr_online = []
  180. # websocket.param_dict_asr_online = {"cache": dict()}
  181. if not websocket.is_speaking:
  182. websocket.vad_pre_idx = 0
  183. frames = []
  184. websocket.param_dict_vad = {'in_cache': dict()}
  185. else:
  186. frames = frames[-20:]
  187. except websockets.ConnectionClosed:
  188. print("ConnectionClosed...", websocket_users,flush=True)
  189. await ws_reset(websocket)
  190. websocket_users.remove(websocket)
  191. except websockets.InvalidState:
  192. print("InvalidState...")
  193. except Exception as e:
  194. print("Exception:", e)
  195. async def async_vad(websocket, audio_in):
  196. segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
  197. speech_start = -1
  198. speech_end = -1
  199. if len(segments_result) == 0 or len(segments_result["text"]) > 1:
  200. return speech_start, speech_end
  201. if segments_result["text"][0][0] != -1:
  202. speech_start = segments_result["text"][0][0]
  203. if segments_result["text"][0][1] != -1:
  204. speech_end = segments_result["text"][0][1]
  205. return speech_start, speech_end
  206. async def async_asr(websocket, audio_in):
  207. if len(audio_in) > 0:
  208. # print(len(audio_in))
  209. rec_result = inference_pipeline_asr(audio_in=audio_in,
  210. param_dict=websocket.param_dict_asr)
  211. # print(rec_result)
  212. if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
  213. rec_result = inference_pipeline_punc(text_in=rec_result['text'],
  214. param_dict=websocket.param_dict_punc)
  215. # print("offline", rec_result)
  216. if 'text' in rec_result:
  217. mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
  218. message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name,"is_final":websocket.is_speaking})
  219. await websocket.send(message)
  220. async def async_asr_online(websocket, audio_in):
  221. if len(audio_in) > 0:
  222. # print(websocket.param_dict_asr_online.get("is_final", False))
  223. rec_result = inference_pipeline_asr_online(audio_in=audio_in,
  224. param_dict=websocket.param_dict_asr_online)
  225. # print(rec_result)
  226. if websocket.mode == "2pass" and websocket.param_dict_asr_online.get("is_final", False):
  227. return
  228. # websocket.param_dict_asr_online["cache"] = dict()
  229. if "text" in rec_result:
  230. if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
  231. # print("online", rec_result)
  232. mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
  233. message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name,"is_final":websocket.is_speaking})
  234. await websocket.send(message)
  235. if len(args.certfile)>0:
  236. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  237. # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
  238. ssl_cert = args.certfile
  239. ssl_key = args.keyfile
  240. ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
  241. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
  242. else:
  243. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  244. asyncio.get_event_loop().run_until_complete(start_server)
  245. asyncio.get_event_loop().run_forever()