# funasr_wss_server.py (11 KB)
  import argparse
  import asyncio
  import json
  import logging
  import math
  import ssl
  import time
  import tracemalloc

  import numpy as np
  import websockets
  # Command-line options. Every *_model option takes a ModelScope model id and has
  # a matching *_revision option. An empty --punc_model disables punctuation and an
  # empty --certfile disables TLS.
  parser = argparse.ArgumentParser()
  parser.add_argument("--host",
                      type=str,
                      default="0.0.0.0",
                      required=False,
                      help="host ip, localhost, 0.0.0.0")
  parser.add_argument("--port",
                      type=int,
                      default=10095,
                      required=False,
                      help="websocket server port")
  parser.add_argument("--asr_model",
                      type=str,
                      default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                      help="model from modelscope")
  parser.add_argument("--asr_model_revision",
                      type=str,
                      default="v2.0.4",
                      help="revision of --asr_model")
  parser.add_argument("--asr_model_online",
                      type=str,
                      default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                      help="model from modelscope")
  parser.add_argument("--asr_model_online_revision",
                      type=str,
                      default="v2.0.4",
                      help="revision of --asr_model_online")
  parser.add_argument("--vad_model",
                      type=str,
                      default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                      help="model from modelscope")
  parser.add_argument("--vad_model_revision",
                      type=str,
                      default="v2.0.4",
                      help="revision of --vad_model")
  parser.add_argument("--punc_model",
                      type=str,
                      default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                      help='model from modelscope; pass "" to disable punctuation')
  parser.add_argument("--punc_model_revision",
                      type=str,
                      default="v2.0.4",
                      help="revision of --punc_model")
  parser.add_argument("--ngpu",
                      type=int,
                      default=1,
                      help="0 for cpu, 1 for gpu")
  parser.add_argument("--device",
                      type=str,
                      default="cuda",
                      help="cuda, cpu")
  parser.add_argument("--ncpu",
                      type=int,
                      default=4,
                      help="cpu cores")
  parser.add_argument("--certfile",
                      type=str,
                      default="../../ssl_key/server.crt",
                      required=False,
                      help='certfile for ssl; pass "" to serve plain ws:// without ssl')
  parser.add_argument("--keyfile",
                      type=str,
                      default="../../ssl_key/server.key",
                      required=False,
                      help="keyfile for ssl (used only when --certfile is set)")
  args = parser.parse_args()
  # Every currently connected client socket; ws_serve adds each new connection.
  websocket_users = set()

  print("model loading")
  # funasr is imported here, after the "model loading" banner, not at the top of the file.
  from funasr import AutoModel


  def _load_model(model, revision):
      """Build an AutoModel for *model*/*revision* with the shared device and logging options."""
      return AutoModel(model=model,
                       model_revision=revision,
                       ngpu=args.ngpu,
                       ncpu=args.ncpu,
                       device=args.device,
                       disable_pbar=True,
                       disable_log=True,
                       )


  # Offline (non-streaming) ASR: transcribes each finished speech segment.
  model_asr = _load_model(args.asr_model, args.asr_model_revision)
  # Streaming ASR: produces partial text while audio is still arriving.
  model_asr_streaming = _load_model(args.asr_model_online, args.asr_model_online_revision)
  # Voice activity detection: finds where speech segments start and end.
  model_vad = _load_model(args.vad_model, args.vad_model_revision)
  # Punctuation restoration is optional; an empty --punc_model leaves it as None.
  model_punc = (_load_model(args.punc_model, args.punc_model_revision)
                if args.punc_model != "" else None)
  print("model loaded! only support one client at the same time now!!!!")
  async def ws_reset(websocket):
      """Reset one connection's streaming state and close it.

      Gives the streaming-ASR and VAD state dicts fresh caches and sets their
      ``is_final`` flag. Empties the punctuation cache. Then awaits
      ``websocket.close()``.
      """
      print("ws reset now, total num is ",len(websocket_users))
      websocket.status_dict_asr_online.update(cache={}, is_final=True)
      websocket.status_dict_vad.update(cache={}, is_final=True)
      websocket.status_dict_punc.update(cache={})
      await websocket.close()
  async def clear_websocket():
      """Reset and close every tracked connection, then forget all of them.

      Iterates over a snapshot of ``websocket_users``. ``ws_reset`` awaits
      ``websocket.close()``, which lets other connection handlers run, and those
      handlers may remove their own socket from the set. Mutating a set while
      iterating over it raises ``RuntimeError``.
      """
      for websocket in list(websocket_users):
          await ws_reset(websocket)
      websocket_users.clear()
  def _apply_client_config(websocket, messagejson):
      """Apply one JSON control message from the client to the connection's state.

      Recognised keys: is_speaking, chunk_interval, wav_name, chunk_size,
      encoder_chunk_look_back, decoder_chunk_look_back, hotword, mode. A key
      that is absent leaves the corresponding setting unchanged.
      """
      if "is_speaking" in messagejson:
          websocket.is_speaking = messagejson["is_speaking"]
          websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
      if "chunk_interval" in messagejson:
          websocket.chunk_interval = messagejson["chunk_interval"]
      if "wav_name" in messagejson:
          websocket.wav_name = messagejson.get("wav_name")
      if "chunk_size" in messagejson:
          chunk_size = messagejson["chunk_size"]
          # Accept both the comma-separated string form ("5,10,5") and a JSON list ([5, 10, 5]).
          if isinstance(chunk_size, str):
              chunk_size = chunk_size.split(",")
          websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
      if "encoder_chunk_look_back" in messagejson:
          websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
      if "decoder_chunk_look_back" in messagejson:
          websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
      if "hotword" in messagejson:
          websocket.status_dict_asr["hotword"] = messagejson["hotword"]
      if "mode" in messagejson:
          websocket.mode = messagejson["mode"]


  async def ws_serve(websocket, path=None):
      """Serve one client connection.

      Text frames are JSON control messages (handled by _apply_client_config).
      Binary frames are audio chunks. Each chunk is buffered for the streaming
      ASR and decoded every ``chunk_interval`` chunks or when a segment ends
      (modes "2pass"/"online"). Each chunk is also fed to the VAD. When the VAD
      reports a segment end, or the client stops speaking, the buffered segment
      goes through the offline ASR + punctuation pass (modes "2pass"/"offline").

      ``path`` is unused. It defaults to None so that websockets versions which
      call handlers with the connection alone can use this handler too.
      """
      frames = []             # recent audio chunks, used to back-fill a segment's start
      frames_asr = []         # audio of the current speech segment (offline pass)
      frames_asr_online = []  # audio waiting for the next streaming decode
      websocket_users.add(websocket)
      websocket.status_dict_asr = {}
      websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
      websocket.status_dict_vad = {"cache": {}, "is_final": False}
      websocket.status_dict_punc = {"cache": {}}
      websocket.chunk_interval = 10
      websocket.vad_pre_idx = 0  # running total of len(chunk) // 32 over received audio
      # Assume the client is speaking until told otherwise. Otherwise audio that
      # arrives before an "is_speaking" message would read an undefined attribute.
      websocket.is_speaking = True
      speech_start = False
      speech_end_i = -1
      websocket.wav_name = "microphone"
      websocket.mode = "2pass"
      print("new user connected", flush=True)

      try:
          async for message in websocket:
              is_audio = not isinstance(message, str)
              if not is_audio:
                  _apply_client_config(websocket, json.loads(message))

              # VAD chunk length derived from the streaming chunk size; presumably
              # chunk_size counts 60 ms frames — confirm. Requires that the client
              # has sent "chunk_size" first.
              websocket.status_dict_vad["chunk_size"] = int(
                  websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval)

              # A text frame only matters while audio is buffered, e.g. an
              # "is_speaking": false message that flushes the final segment.
              if not (frames_asr_online or frames_asr or is_audio):
                  continue

              if is_audio:
                  frames.append(message)
                  # Assumes 32 bytes per ms (e.g. 16 kHz 16-bit mono PCM) — confirm with clients.
                  duration_ms = len(message) // 32
                  websocket.vad_pre_idx += duration_ms

                  # Streaming ASR pass.
                  frames_asr_online.append(message)
                  websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
                  if (len(frames_asr_online) % websocket.chunk_interval == 0
                          or websocket.status_dict_asr_online["is_final"]):
                      if websocket.mode in ("2pass", "online"):
                          audio_in = b"".join(frames_asr_online)
                          try:
                              await async_asr_online(websocket, audio_in)
                          except Exception as e:
                              print(f"error in asr streaming, {websocket.status_dict_asr_online}, {e!r}")
                      frames_asr_online = []
                  if speech_start:
                      frames_asr.append(message)

                  # Streaming VAD.
                  try:
                      speech_start_i, speech_end_i = await async_vad(websocket, message)
                  except Exception as e:
                      print(f"error in vad, {e!r}")
                      # Treat a failed VAD call as "no event". Otherwise the previous
                      # chunk's result would be reused, or the name would be unbound
                      # on the first chunk.
                      speech_start_i, speech_end_i = -1, -1
                  if speech_start_i != -1:
                      speech_start = True
                      # Number of trailing chunks that reach back to the detected
                      # start: round up, and always include the current chunk.
                      # Floor division could yield 0, and frames[-0:] is *every*
                      # buffered frame.
                      if duration_ms > 0:
                          lead_ms = websocket.vad_pre_idx - speech_start_i
                          beg_bias = max(1, math.ceil(lead_ms / duration_ms))
                      else:
                          beg_bias = 1
                      frames_asr = frames[-beg_bias:]

              # Offline ASR + punctuation once the segment has ended.
              if speech_end_i != -1 or not websocket.is_speaking:
                  if websocket.mode in ("2pass", "offline"):
                      audio_in = b"".join(frames_asr)
                      try:
                          await async_asr(websocket, audio_in)
                      except Exception as e:
                          print(f"error in asr offline, {e!r}")
                  frames_asr = []
                  speech_start = False
                  frames_asr_online = []
                  websocket.status_dict_asr_online["cache"] = {}
                  if not websocket.is_speaking:
                      # Client finished talking: restart VAD timing and state.
                      websocket.vad_pre_idx = 0
                      frames = []
                      websocket.status_dict_vad["cache"] = {}
                  else:
                      # Keep a short tail so the next segment's start can be back-filled.
                      frames = frames[-20:]
      except websockets.ConnectionClosed:
          print("ConnectionClosed...", websocket_users, flush=True)
          await ws_reset(websocket)
      except websockets.InvalidState:
          print("InvalidState...")
      except Exception as e:
          print("Exception:", e)
      finally:
          # Forget the connection however the handler exits. This includes a clean
          # close, where the ``async for`` simply ends without raising.
          websocket_users.discard(websocket)
  async def async_vad(websocket, audio_in):
      """Feed one audio chunk to the streaming VAD model.

      Returns ``(speech_start, speech_end)``. Each is the value the model
      reported for the single detected segment, or -1 when there is none. A
      result with zero segments or with several segments is reported as
      ``(-1, -1)``. The values are presumably milliseconds, since ws_serve
      compares them with ``vad_pre_idx`` — confirm.
      """
      segments = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
      if len(segments) != 1:
          return -1, -1
      only = segments[0]
      # Normalise the "no event" sentinel to a plain -1.
      start = only[0] if only[0] != -1 else -1
      end = only[1] if only[1] != -1 else -1
      return start, end
  async def async_asr(websocket, audio_in):
      """Transcribe a finished speech segment offline and send the text to the client.

      Runs the punctuation model on the transcript when one is loaded. Nothing
      is sent for empty audio or an empty transcript.
      """
      if len(audio_in) == 0:
          return
      rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
      if model_punc is not None and len(rec_result["text"]) > 0:
          # The punctuated result replaces the raw ASR result.
          rec_result = model_punc.generate(input=rec_result['text'], **websocket.status_dict_punc)[0]
      if len(rec_result["text"]) == 0:
          return
      # In two-pass sessions this is tagged as the offline half.
      mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
      # NOTE(review): "is_final" carries is_speaking (True while the client is still
      # speaking). That looks inverted, but clients may rely on it — confirm before changing.
      payload = {"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name,"is_final":websocket.is_speaking}
      await websocket.send(json.dumps(payload))
  async def async_asr_online(websocket, audio_in):
      """Run the streaming ASR model on buffered audio and send any partial text.

      In "2pass" mode, a decode made with ``is_final`` set is not sent; the
      offline pass supplies that segment's text instead. Empty audio and empty
      text send nothing.
      """
      if len(audio_in) == 0:
          return
      rec_result = model_asr_streaming.generate(input=audio_in, **websocket.status_dict_asr_online)[0]
      if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
          return
      text = rec_result["text"]
      if len(text) == 0:
          return
      mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
      # NOTE(review): "is_final" carries is_speaking, matching async_asr — confirm intent.
      payload = {"mode": mode, "text": text, "wav_name": websocket.wav_name,"is_final":websocket.is_speaking}
      await websocket.send(json.dumps(payload))
  # TLS is enabled whenever --certfile is non-empty; otherwise the server speaks plain ws://.
  ssl_context = None
  if len(args.certfile) > 0:
      ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
      # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
      ssl_cert = args.certfile
      ssl_key = args.keyfile
      ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)

  # Create and install the event loop explicitly. asyncio.get_event_loop() with
  # no running loop is deprecated (and raises on recent Python versions). Some
  # websockets releases also look up the current loop when serve() is
  # constructed, so the loop must be installed before websockets.serve() runs.
  loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)
  # ssl=None is equivalent to omitting ssl: a plain, unencrypted server.
  start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"],
                                  ping_interval=None, ssl=ssl_context)
  loop.run_until_complete(start_server)
  loop.run_forever()