ws_server_online.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import asyncio
  2. import json
  3. import websockets
  4. import time
  5. from queue import Queue
  6. import threading
  7. import logging
  8. import tracemalloc
  9. import numpy as np
  10. from parse_args import args
  11. from modelscope.pipelines import pipeline
  12. from modelscope.utils.constant import Tasks
  13. from modelscope.utils.logger import get_logger
  14. from funasr_onnx.utils.frontend import load_bytes
  15. tracemalloc.start()
  16. logger = get_logger(log_level=logging.CRITICAL)
  17. logger.setLevel(logging.CRITICAL)
  18. websocket_users = set()
  19. print("model loading")
  20. inference_pipeline_asr_online = pipeline(
  21. task=Tasks.auto_speech_recognition,
  22. model=args.asr_model_online,
  23. model_revision='v1.0.4')
  24. print("model loaded")
  25. async def ws_serve(websocket, path):
  26. frames_online = []
  27. global websocket_users
  28. websocket.send_msg = Queue()
  29. websocket_users.add(websocket)
  30. websocket.param_dict_asr_online = {"cache": dict()}
  31. websocket.speek_online = Queue()
  32. ss_online = threading.Thread(target=asr_online, args=(websocket,))
  33. ss_online.start()
  34. try:
  35. async for message in websocket:
  36. message = json.loads(message)
  37. is_finished = message["is_finished"]
  38. if not is_finished:
  39. audio = bytes(message['audio'], 'ISO-8859-1')
  40. is_speaking = message["is_speaking"]
  41. websocket.param_dict_asr_online["is_final"] = not is_speaking
  42. websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
  43. frames_online.append(audio)
  44. if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking:
  45. audio_in = b"".join(frames_online)
  46. websocket.speek_online.put(audio_in)
  47. frames_online = []
  48. if not websocket.send_msg.empty():
  49. await websocket.send(websocket.send_msg.get())
  50. websocket.send_msg.task_done()
  51. except websockets.ConnectionClosed:
  52. print("ConnectionClosed...", websocket_users) # 链接断开
  53. websocket_users.remove(websocket)
  54. except websockets.InvalidState:
  55. print("InvalidState...") # 无效状态
  56. except Exception as e:
  57. print("Exception:", e)
  58. def asr_online(websocket): # ASR推理
  59. global websocket_users
  60. while websocket in websocket_users:
  61. if not websocket.speek_online.empty():
  62. audio_in = websocket.speek_online.get()
  63. websocket.speek_online.task_done()
  64. if len(audio_in) > 0:
  65. # print(len(audio_in))
  66. audio_in = load_bytes(audio_in)
  67. rec_result = inference_pipeline_asr_online(audio_in=audio_in,
  68. param_dict=websocket.param_dict_asr_online)
  69. if websocket.param_dict_asr_online["is_final"]:
  70. websocket.param_dict_asr_online["cache"] = dict()
  71. if "text" in rec_result:
  72. if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
  73. print(rec_result["text"])
  74. message = json.dumps({"mode": "online", "text": rec_result["text"]})
  75. websocket.send_msg.put(message)
  76. time.sleep(0.005)
  77. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  78. asyncio.get_event_loop().run_until_complete(start_server)
  79. asyncio.get_event_loop().run_forever()