# ws_server_online.py (3.7 KB)

  1. import asyncio
  2. import json
  3. import websockets
  4. import time
  5. from queue import Queue
  6. import threading
  7. import logging
  8. import tracemalloc
  9. import numpy as np
  10. import ssl
  11. from parse_args import args
  12. from modelscope.pipelines import pipeline
  13. from modelscope.utils.constant import Tasks
  14. from modelscope.utils.logger import get_logger
  15. from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
  16. tracemalloc.start()
  17. logger = get_logger(log_level=logging.CRITICAL)
  18. logger.setLevel(logging.CRITICAL)
  19. websocket_users = set()
  20. print("model loading")
  21. inference_pipeline_asr_online = pipeline(
  22. task=Tasks.auto_speech_recognition,
  23. model=args.asr_model_online,
  24. ngpu=args.ngpu,
  25. ncpu=args.ncpu,
  26. model_revision='v1.0.4')
  27. print("model loaded")
  28. async def ws_serve(websocket, path):
  29. frames_asr_online = []
  30. global websocket_users
  31. websocket_users.add(websocket)
  32. websocket.param_dict_asr_online = {"cache": dict()}
  33. websocket.wav_name = "microphone"
  34. print("new user connected",flush=True)
  35. try:
  36. async for message in websocket:
  37. if isinstance(message, str):
  38. messagejson = json.loads(message)
  39. if "is_speaking" in messagejson:
  40. websocket.is_speaking = messagejson["is_speaking"]
  41. websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
  42. # need to fire engine manually if no data received any more
  43. if not websocket.is_speaking:
  44. await async_asr_online(websocket,b"")
  45. if "chunk_interval" in messagejson:
  46. websocket.chunk_interval=messagejson["chunk_interval"]
  47. if "wav_name" in messagejson:
  48. websocket.wav_name = messagejson.get("wav_name")
  49. if "chunk_size" in messagejson:
  50. websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
  51. # if has bytes in buffer or message is bytes
  52. if len(frames_asr_online) > 0 or not isinstance(message, str):
  53. if not isinstance(message,str):
  54. frames_asr_online.append(message)
  55. if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
  56. audio_in = b"".join(frames_asr_online)
  57. # if not websocket.is_speaking:
  58. #padding 0.5s at end gurantee that asr engine can fire out last word
  59. # audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
  60. await async_asr_online(websocket,audio_in)
  61. frames_asr_online = []
  62. except websockets.ConnectionClosed:
  63. print("ConnectionClosed...", websocket_users)
  64. websocket_users.remove(websocket)
  65. except websockets.InvalidState:
  66. print("InvalidState...")
  67. except Exception as e:
  68. print("Exception:", e)
  69. async def async_asr_online(websocket,audio_in):
  70. if len(audio_in) >=0:
  71. audio_in = load_bytes(audio_in)
  72. rec_result = inference_pipeline_asr_online(audio_in=audio_in,
  73. param_dict=websocket.param_dict_asr_online)
  74. if websocket.param_dict_asr_online.get("is_final", False):
  75. websocket.param_dict_asr_online["cache"] = dict()
  76. if "text" in rec_result:
  77. if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
  78. message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
  79. await websocket.send(message)
  80. if len(args.certfile)>0:
  81. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  82. # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
  83. ssl_cert = args.certfile
  84. ssl_key = args.keyfile
  85. ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
  86. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
  87. else:
  88. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  89. asyncio.get_event_loop().run_until_complete(start_server)
  90. asyncio.get_event_loop().run_forever()