# ws_server_online.py
  1. import asyncio
  2. import json
  3. import websockets
  4. import time
  5. from queue import Queue
  6. import threading
  7. import logging
  8. import tracemalloc
  9. import numpy as np
  10. from parse_args import args
  11. from modelscope.pipelines import pipeline
  12. from modelscope.utils.constant import Tasks
  13. from modelscope.utils.logger import get_logger
  14. from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
  15. tracemalloc.start()
  16. logger = get_logger(log_level=logging.CRITICAL)
  17. logger.setLevel(logging.CRITICAL)
  18. websocket_users = set()
  19. print("model loading")
  20. inference_pipeline_asr_online = pipeline(
  21. task=Tasks.auto_speech_recognition,
  22. model=args.asr_model_online,
  23. ngpu=args.ngpu,
  24. ncpu=args.ncpu,
  25. model_revision='v1.0.4')
  26. print("model loaded")
  27. async def ws_serve(websocket, path):
  28. frames_asr_online = []
  29. global websocket_users
  30. websocket_users.add(websocket)
  31. websocket.param_dict_asr_online = {"cache": dict()}
  32. print("new user connected",flush=True)
  33. try:
  34. async for message in websocket:
  35. if isinstance(message,str):
  36. messagejson = json.loads(message)
  37. if "is_speaking" in messagejson:
  38. websocket.is_speaking = messagejson["is_speaking"]
  39. websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
  40. if "is_finished" in messagejson:
  41. websocket.is_speaking = False
  42. websocket.param_dict_asr_online["is_final"] = True
  43. if "chunk_interval" in messagejson:
  44. websocket.chunk_interval=messagejson["chunk_interval"]
  45. if "wav_name" in messagejson:
  46. websocket.wav_name = messagejson.get("wav_name", "demo")
  47. if "chunk_size" in messagejson:
  48. websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
  49. # if has bytes in buffer or message is bytes
  50. if len(frames_asr_online)>0 or not isinstance(message,str):
  51. if not isinstance(message,str):
  52. frames_asr_online.append(message)
  53. if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
  54. audio_in = b"".join(frames_asr_online)
  55. if not websocket.is_speaking:
  56. #padding 0.5s at end gurantee that asr engine can fire out last word
  57. audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
  58. await async_asr_online(websocket,audio_in)
  59. frames_asr_online = []
  60. except websockets.ConnectionClosed:
  61. print("ConnectionClosed...", websocket_users)
  62. websocket_users.remove(websocket)
  63. except websockets.InvalidState:
  64. print("InvalidState...")
  65. except Exception as e:
  66. print("Exception:", e)
  67. async def async_asr_online(websocket,audio_in):
  68. if len(audio_in) > 0:
  69. audio_in = load_bytes(audio_in)
  70. rec_result = inference_pipeline_asr_online(audio_in=audio_in,
  71. param_dict=websocket.param_dict_asr_online)
  72. if websocket.param_dict_asr_online["is_final"]:
  73. websocket.param_dict_asr_online["cache"] = dict()
  74. if "text" in rec_result:
  75. if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
  76. message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
  77. await websocket.send(message)
  78. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  79. asyncio.get_event_loop().run_until_complete(start_server)
  80. asyncio.get_event_loop().run_forever()