# funasr_wss_server.py (11 KB)

import argparse
import asyncio
import json
import logging
import ssl
import time
import tracemalloc

import numpy as np
import websockets
from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
  14. tracemalloc.start()
  15. logger = get_logger(log_level=logging.CRITICAL)
  16. logger.setLevel(logging.CRITICAL)
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument("--host",
  19. type=str,
  20. default="0.0.0.0",
  21. required=False,
  22. help="host ip, localhost, 0.0.0.0")
  23. parser.add_argument("--port",
  24. type=int,
  25. default=10095,
  26. required=False,
  27. help="grpc server port")
  28. parser.add_argument("--asr_model",
  29. type=str,
  30. default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
  31. help="model from modelscope")
  32. parser.add_argument("--asr_model_online",
  33. type=str,
  34. default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
  35. help="model from modelscope")
  36. parser.add_argument("--vad_model",
  37. type=str,
  38. default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
  39. help="model from modelscope")
  40. parser.add_argument("--punc_model",
  41. type=str,
  42. default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
  43. help="model from modelscope")
  44. parser.add_argument("--ngpu",
  45. type=int,
  46. default=1,
  47. help="0 for cpu, 1 for gpu")
  48. parser.add_argument("--ncpu",
  49. type=int,
  50. default=4,
  51. help="cpu cores")
  52. parser.add_argument("--certfile",
  53. type=str,
  54. default="./ssl_key/server.crt",
  55. required=False,
  56. help="certfile for ssl")
  57. parser.add_argument("--keyfile",
  58. type=str,
  59. default="./ssl_key/server.key",
  60. required=False,
  61. help="keyfile for ssl")
  62. args = parser.parse_args()
  63. websocket_users = set()
  64. print("model loading")
  65. # asr
  66. inference_pipeline_asr = pipeline(
  67. task=Tasks.auto_speech_recognition,
  68. model=args.asr_model,
  69. ngpu=args.ngpu,
  70. ncpu=args.ncpu,
  71. model_revision=None)
  72. # vad
  73. inference_pipeline_vad = pipeline(
  74. task=Tasks.voice_activity_detection,
  75. model=args.vad_model,
  76. model_revision=None,
  77. mode='online',
  78. ngpu=args.ngpu,
  79. ncpu=args.ncpu,
  80. )
  81. if args.punc_model != "":
  82. inference_pipeline_punc = pipeline(
  83. task=Tasks.punctuation,
  84. model=args.punc_model,
  85. model_revision="v1.0.2",
  86. ngpu=args.ngpu,
  87. ncpu=args.ncpu,
  88. )
  89. else:
  90. inference_pipeline_punc = None
  91. inference_pipeline_asr_online = pipeline(
  92. task=Tasks.auto_speech_recognition,
  93. model=args.asr_model_online,
  94. ngpu=args.ngpu,
  95. ncpu=args.ncpu,
  96. model_revision='v1.0.4',
  97. update_model='v1.0.4',
  98. mode='paraformer_streaming')
  99. print("model loaded! only support one client at the same time now!!!!")
  100. async def ws_reset(websocket):
  101. print("ws reset now, total num is ",len(websocket_users))
  102. websocket.param_dict_asr_online = {"cache": dict()}
  103. websocket.param_dict_vad = {'in_cache': dict(), "is_final": True}
  104. websocket.param_dict_asr_online["is_final"]=True
  105. # audio_in=b''.join(np.zeros(int(16000),dtype=np.int16))
  106. # inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
  107. # inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online)
  108. await websocket.close()
  109. async def clear_websocket():
  110. for websocket in websocket_users:
  111. await ws_reset(websocket)
  112. websocket_users.clear()
  113. async def ws_serve(websocket, path):
  114. frames = []
  115. frames_asr = []
  116. frames_asr_online = []
  117. global websocket_users
  118. await clear_websocket()
  119. websocket_users.add(websocket)
  120. websocket.param_dict_asr = {}
  121. websocket.param_dict_asr_online = {"cache": dict()}
  122. websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
  123. websocket.param_dict_punc = {'cache': list()}
  124. websocket.vad_pre_idx = 0
  125. speech_start = False
  126. speech_end_i = -1
  127. websocket.wav_name = "microphone"
  128. websocket.mode = "2pass"
  129. print("new user connected", flush=True)
  130. try:
  131. async for message in websocket:
  132. if isinstance(message, str):
  133. messagejson = json.loads(message)
  134. if "is_speaking" in messagejson:
  135. websocket.is_speaking = messagejson["is_speaking"]
  136. websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
  137. if "chunk_interval" in messagejson:
  138. websocket.chunk_interval = messagejson["chunk_interval"]
  139. if "wav_name" in messagejson:
  140. websocket.wav_name = messagejson.get("wav_name")
  141. if "chunk_size" in messagejson:
  142. websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
  143. if "mode" in messagejson:
  144. websocket.mode = messagejson["mode"]
  145. if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
  146. if not isinstance(message, str):
  147. frames.append(message)
  148. duration_ms = len(message)//32
  149. websocket.vad_pre_idx += duration_ms
  150. # asr online
  151. frames_asr_online.append(message)
  152. websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
  153. if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
  154. if websocket.mode == "2pass" or websocket.mode == "online":
  155. audio_in = b"".join(frames_asr_online)
  156. await async_asr_online(websocket, audio_in)
  157. frames_asr_online = []
  158. if speech_start:
  159. frames_asr.append(message)
  160. # vad online
  161. speech_start_i, speech_end_i = await async_vad(websocket, message)
  162. if speech_start_i != -1:
  163. speech_start = True
  164. beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
  165. frames_pre = frames[-beg_bias:]
  166. frames_asr = []
  167. frames_asr.extend(frames_pre)
  168. # asr punc offline
  169. if speech_end_i != -1 or not websocket.is_speaking:
  170. # print("vad end point")
  171. if websocket.mode == "2pass" or websocket.mode == "offline":
  172. audio_in = b"".join(frames_asr)
  173. await async_asr(websocket, audio_in)
  174. frames_asr = []
  175. speech_start = False
  176. # frames_asr_online = []
  177. # websocket.param_dict_asr_online = {"cache": dict()}
  178. if not websocket.is_speaking:
  179. websocket.vad_pre_idx = 0
  180. frames = []
  181. websocket.param_dict_vad = {'in_cache': dict()}
  182. else:
  183. frames = frames[-20:]
  184. except websockets.ConnectionClosed:
  185. print("ConnectionClosed...", websocket_users,flush=True)
  186. await ws_reset(websocket)
  187. websocket_users.remove(websocket)
  188. except websockets.InvalidState:
  189. print("InvalidState...")
  190. except Exception as e:
  191. print("Exception:", e)
  192. async def async_vad(websocket, audio_in):
  193. segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
  194. speech_start = -1
  195. speech_end = -1
  196. if len(segments_result) == 0 or len(segments_result["text"]) > 1:
  197. return speech_start, speech_end
  198. if segments_result["text"][0][0] != -1:
  199. speech_start = segments_result["text"][0][0]
  200. if segments_result["text"][0][1] != -1:
  201. speech_end = segments_result["text"][0][1]
  202. return speech_start, speech_end
  203. async def async_asr(websocket, audio_in):
  204. if len(audio_in) > 0:
  205. # print(len(audio_in))
  206. audio_in = load_bytes(audio_in)
  207. rec_result = inference_pipeline_asr(audio_in=audio_in,
  208. param_dict=websocket.param_dict_asr)
  209. # print(rec_result)
  210. if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
  211. rec_result = inference_pipeline_punc(text_in=rec_result['text'],
  212. param_dict=websocket.param_dict_punc)
  213. # print("offline", rec_result)
  214. if 'text' in rec_result:
  215. message = json.dumps({"mode": "2pass-offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
  216. await websocket.send(message)
  217. async def async_asr_online(websocket, audio_in):
  218. if len(audio_in) > 0:
  219. audio_in = load_bytes(audio_in)
  220. # print(websocket.param_dict_asr_online.get("is_final", False))
  221. rec_result = inference_pipeline_asr_online(audio_in=audio_in,
  222. param_dict=websocket.param_dict_asr_online)
  223. # print(rec_result)
  224. if websocket.mode == "2pass" and websocket.param_dict_asr_online.get("is_final", False):
  225. return
  226. # websocket.param_dict_asr_online["cache"] = dict()
  227. if "text" in rec_result:
  228. if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
  229. # print("online", rec_result)
  230. message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
  231. await websocket.send(message)
  232. if len(args.certfile)>0:
  233. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  234. # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
  235. ssl_cert = args.certfile
  236. ssl_key = args.keyfile
  237. ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
  238. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
  239. else:
  240. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  241. asyncio.get_event_loop().run_until_complete(start_server)
  242. asyncio.get_event_loop().run_forever()