# ws_server_online.py
import asyncio
import json
import logging
import threading
import time
import tracemalloc
from queue import Queue

import numpy as np
import websockets
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes

from parse_args import args
  15. tracemalloc.start()
  16. logger = get_logger(log_level=logging.CRITICAL)
  17. logger.setLevel(logging.CRITICAL)
  18. websocket_users = set()
  19. print("model loading")
  20. inference_pipeline_asr_online = pipeline(
  21. task=Tasks.auto_speech_recognition,
  22. model=args.asr_model_online,
  23. ngpu=args.ngpu,
  24. ncpu=args.ncpu,
  25. model_revision='v1.0.4')
  26. print("model loaded")
  27. async def ws_serve(websocket, path):
  28. frames_online = []
  29. global websocket_users
  30. websocket.send_msg = Queue()
  31. websocket_users.add(websocket)
  32. websocket.param_dict_asr_online = {"cache": dict()}
  33. websocket.speek_online = Queue()
  34. try:
  35. async for message in websocket:
  36. message = json.loads(message)
  37. is_finished = message["is_finished"]
  38. if not is_finished:
  39. audio = bytes(message['audio'], 'ISO-8859-1')
  40. is_speaking = message["is_speaking"]
  41. websocket.param_dict_asr_online["is_final"] = not is_speaking
  42. websocket.wav_name = message.get("wav_name", "demo")
  43. websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
  44. frames_online.append(audio)
  45. if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking:
  46. audio_in = b"".join(frames_online)
  47. await async_asr_online(websocket,audio_in)
  48. frames_online = []
  49. except websockets.ConnectionClosed:
  50. print("ConnectionClosed...", websocket_users)
  51. websocket_users.remove(websocket)
  52. except websockets.InvalidState:
  53. print("InvalidState...")
  54. except Exception as e:
  55. print("Exception:", e)
  56. async def async_asr_online(websocket,audio_in):
  57. if len(audio_in) > 0:
  58. audio_in = load_bytes(audio_in)
  59. rec_result = inference_pipeline_asr_online(audio_in=audio_in,
  60. param_dict=websocket.param_dict_asr_online)
  61. if websocket.param_dict_asr_online["is_final"]:
  62. websocket.param_dict_asr_online["cache"] = dict()
  63. if "text" in rec_result:
  64. if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
  65. # if len(rec_result["text"])>0:
  66. # rec_result["text"][0]=rec_result["text"][0] #.replace(" ","")
  67. message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
  68. await websocket.send(message)
  69. start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
  70. asyncio.get_event_loop().run_until_complete(start_server)
  71. asyncio.get_event_loop().run_forever()