funasr_client_api.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. '''
  2. Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  3. Reserved. MIT License (https://opensource.org/licenses/MIT)
  4. 2022-2023 by zhaomingwork@qq.com
  5. '''
  6. # pip install websocket-client
  7. import ssl
  8. from websocket import ABNF
  9. from websocket import create_connection
  10. from queue import Queue
  11. import threading
  12. import traceback
  13. import json
  14. import time
  15. import numpy as np
  16. # class for recognizer in websocket
  17. class Funasr_websocket_recognizer():
  18. '''
  19. python asr recognizer lib
  20. '''
  21. def __init__(self, host="127.0.0.1", port="30035", is_ssl=True,chunk_size="5, 10, 5",chunk_interval=10,mode="offline",wav_name="default"):
  22. '''
  23. host: server host ip
  24. port: server port
  25. is_ssl: True for wss protocal, False for ws
  26. '''
  27. try:
  28. if is_ssl == True:
  29. ssl_context = ssl.SSLContext()
  30. ssl_context.check_hostname = False
  31. ssl_context.verify_mode = ssl.CERT_NONE
  32. uri = "wss://{}:{}".format(host, port)
  33. ssl_opt={"cert_reqs": ssl.CERT_NONE}
  34. else:
  35. uri = "ws://{}:{}".format(host, port)
  36. ssl_context = None
  37. ssl_opt=None
  38. self.host = host
  39. self.port = port
  40. self.msg_queue = Queue() # used for recognized result text
  41. print("connect to url",uri)
  42. self.websocket=create_connection(uri,ssl=ssl_context,sslopt=ssl_opt)
  43. self.thread_msg = threading.Thread(target=Funasr_websocket_recognizer.thread_rec_msg,args=(self,))
  44. self.thread_msg.start()
  45. chunk_size = [int(x) for x in chunk_size.split(",")]
  46. stride = int(60 * chunk_size[1]/ chunk_interval / 1000 * 16000 * 2)
  47. chunk_num = (len(audio_bytes) - 1) // stride + 1
  48. message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "encoder_chunk_look_back": 4,
  49. "decoder_chunk_look_back": 1, "chunk_interval": args.chunk_interval,
  50. "wav_name": wav_name, "is_speaking": True})
  51. self.websocket.send(message)
  52. print("send json",message)
  53. except Exception as e:
  54. print("Exception:", e)
  55. traceback.print_exc()
  56. # threads for rev msg
  57. def thread_rec_msg(self):
  58. try:
  59. while(True):
  60. msg=self.websocket.recv()
  61. if msg is None or len(msg)==0:
  62. continue
  63. msg = json.loads(msg)
  64. self.msg_queue.put(msg)
  65. except Exception as e:
  66. print("client closed")
  67. # feed data to asr engine, wait_time means waiting for result until time out
  68. def feed_chunk(self, chunk,wait_time=0.01):
  69. try:
  70. self.websocket.send(chunk, ABNF.OPCODE_BINARY)
  71. # loop to check if there is a message, timeout in 0.01s
  72. while(True):
  73. msg = self.msg_queue.get(timeout=wait_time)
  74. if self.msg_queue.empty():
  75. break
  76. return msg
  77. except:
  78. return ""
  79. def close(self,timeout=1):
  80. message = json.dumps({"is_speaking": False})
  81. self.websocket.send(message)
  82. # sleep for timeout seconds to wait for result
  83. time.sleep(timeout)
  84. msg=""
  85. while(not self.msg_queue.empty()):
  86. msg = self.msg_queue.get()
  87. self.websocket.close()
  88. # only resturn the last msg
  89. return msg
  90. if __name__ == '__main__':
  91. print('example for Funasr_websocket_recognizer')
  92. import wave
  93. wav_path="asr_example.wav"
  94. with wave.open(wav_path, "rb") as wav_file:
  95. params = wav_file.getparams()
  96. frames = wav_file.readframes(wav_file.getnframes())
  97. audio_bytes = bytes(frames)
  98. stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
  99. chunk_num = (len(audio_bytes) - 1) // stride + 1
  100. # create an recognizer
  101. rcg=Funasr_websocket_recognizer(host="127.0.0.1",port="30035",is_ssl=True,mode="2pass")
  102. # loop to send chunk
  103. for i in range(chunk_num):
  104. beg = i * stride
  105. data = audio_bytes[beg:beg + stride]
  106. text=rcg.feed_chunk(data,wait_time=0.02)
  107. if len(text)>0:
  108. print("text",text)
  109. time.sleep(0.05)
  110. # get last message
  111. text=rcg.close(timeout=3)
  112. print("text",text)