# ws_server_online.py
  1. import asyncio
  2. import json
  3. import websockets
  4. import time
  5. from queue import Queue
  6. import threading
  7. import logging
  8. import tracemalloc
  9. import numpy as np
  10. import ssl
  11. from parse_args import args
  12. from modelscope.pipelines import pipeline
  13. from modelscope.utils.constant import Tasks
  14. from modelscope.utils.logger import get_logger
  15. from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
  16. tracemalloc.start()
  17. logger = get_logger(log_level=logging.CRITICAL)
  18. logger.setLevel(logging.CRITICAL)
  19. websocket_users = set()
  20. print("model loading")
  21. inference_pipeline_asr_online = pipeline(
  22. task=Tasks.auto_speech_recognition,
  23. model=args.asr_model_online,
  24. ngpu=args.ngpu,
  25. ncpu=args.ncpu,
  26. model_revision='v1.0.4')
  27. # vad
  28. inference_pipeline_vad = pipeline(
  29. task=Tasks.voice_activity_detection,
  30. model=args.vad_model,
  31. model_revision=None,
  32. output_dir=None,
  33. batch_size=1,
  34. mode='online',
  35. ngpu=args.ngpu,
  36. ncpu=1,
  37. )
  38. print("model loaded")
  39. async def ws_serve(websocket, path):
  40. frames = []
  41. frames_asr_online = []
  42. global websocket_users
  43. websocket_users.add(websocket)
  44. websocket.param_dict_asr_online = {"cache": dict()}
  45. websocket.param_dict_vad = {'in_cache': dict()}
  46. websocket.wav_name = "microphone"
  47. print("new user connected",flush=True)
  48. try:
  49. async for message in websocket:
  50. if isinstance(message, str):
  51. messagejson = json.loads(message)
  52. if "is_speaking" in messagejson:
  53. websocket.is_speaking = messagejson["is_speaking"]
  54. websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
  55. websocket.param_dict_vad["is_final"] = not websocket.is_speaking
  56. # need to fire engine manually if no data received any more
  57. if not websocket.is_speaking:
  58. await async_asr_online(websocket, b"")
  59. if "chunk_interval" in messagejson:
  60. websocket.chunk_interval=messagejson["chunk_interval"]
  61. if "wav_name" in messagejson:
  62. websocket.wav_name = messagejson.get("wav_name")
  63. if "chunk_size" in messagejson:
  64. websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
  65. # if has bytes in buffer or message is bytes
  66. if len(frames_asr_online) > 0 or not isinstance(message, str):
  67. if not isinstance(message, str):
  68. frames_asr_online.append(message)
  69. # frames.append(message)
  70. # duration_ms = len(message) // 32
  71. # websocket.vad_pre_idx += duration_ms
  72. speech_start_i, speech_end_i = await async_vad(websocket, message)
  73. websocket.is_speaking = not speech_end_i
  74. if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
  75. websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
  76. audio_in = b"".join(frames_asr_online)
  77. await async_asr_online(websocket, audio_in)
  78. frames_asr_online = []
  79. except websockets.ConnectionClosed:
  80. print("ConnectionClosed...", websocket_users)
  81. websocket_users.remove(websocket)
  82. except websockets.InvalidState:
  83. print("InvalidState...")
  84. except Exception as e:
  85. print("Exception:", e)
  86. async def async_asr_online(websocket,audio_in):
  87. if len(audio_in) >= 0:
  88. audio_in = load_bytes(audio_in)
  89. # print(websocket.param_dict_asr_online.get("is_final", False))
  90. rec_result = inference_pipeline_asr_online(audio_in=audio_in,
  91. param_dict=websocket.param_dict_asr_online)
  92. # print(rec_result)
  93. if websocket.param_dict_asr_online.get("is_final", False):
  94. websocket.param_dict_asr_online["cache"] = dict()
  95. if "text" in rec_result:
  96. if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
  97. message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
  98. await websocket.send(message)
  99. async def async_vad(websocket, audio_in):
  100. segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
  101. speech_start = False
  102. speech_end = False
  103. if len(segments_result) == 0 or len(segments_result["text"]) > 1:
  104. return speech_start, speech_end
  105. if segments_result["text"][0][0] != -1:
  106. speech_start = segments_result["text"][0][0]
  107. if segments_result["text"][0][1] != -1:
  108. speech_end = True
  109. return speech_start, speech_end
  110. if len(args.certfile)>0:
  111. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  112. # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
  113. ssl_cert = args.certfile
  114. ssl_key = args.keyfile
  115. ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
  116. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
  117. else:
  118. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  119. asyncio.get_event_loop().run_until_complete(start_server)
  120. asyncio.get_event_loop().run_forever()