# funasr_client_api.py (5.0 KB)
'''
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2022-2023 by zhaomingwork@qq.com
'''
  6. # pip install websocket-client
  7. import ssl
  8. from websocket import ABNF
  9. from websocket import create_connection
  10. from queue import Queue
  11. import threading
  12. import traceback
  13. import json
  14. import time
  15. import numpy as np
  16. # class for recognizer in websocket
  17. class Funasr_websocket_recognizer():
  18. '''
  19. python asr recognizer lib
  20. '''
  21. def __init__(self, host="127.0.0.1",
  22. port="30035",
  23. is_ssl=True,
  24. chunk_size="0, 10, 5",
  25. chunk_interval=10,
  26. mode="offline",
  27. wav_name="default"):
  28. '''
  29. host: server host ip
  30. port: server port
  31. is_ssl: True for wss protocal, False for ws
  32. '''
  33. try:
  34. if is_ssl == True:
  35. ssl_context = ssl.SSLContext()
  36. ssl_context.check_hostname = False
  37. ssl_context.verify_mode = ssl.CERT_NONE
  38. uri = "wss://{}:{}".format(host, port)
  39. ssl_opt={"cert_reqs": ssl.CERT_NONE}
  40. else:
  41. uri = "ws://{}:{}".format(host, port)
  42. ssl_context = None
  43. ssl_opt=None
  44. self.host = host
  45. self.port = port
  46. self.msg_queue = Queue() # used for recognized result text
  47. print("connect to url",uri)
  48. self.websocket=create_connection(uri, ssl=ssl_context, sslopt=ssl_opt)
  49. self.thread_msg = threading.Thread(target=Funasr_websocket_recognizer.thread_rec_msg, args=(self,))
  50. self.thread_msg.start()
  51. chunk_size = [int(x) for x in chunk_size.split(",")]
  52. stride = int(60 * chunk_size[1] / chunk_interval / 1000 * 16000 * 2)
  53. chunk_num = (len(audio_bytes) - 1) // stride + 1
  54. message = json.dumps({"mode": mode,
  55. "chunk_size": chunk_size,
  56. "encoder_chunk_look_back": 4,
  57. "decoder_chunk_look_back": 1,
  58. "chunk_interval": chunk_interval,
  59. "wav_name": wav_name,
  60. "is_speaking": True})
  61. self.websocket.send(message)
  62. print("send json",message)
  63. except Exception as e:
  64. print("Exception:", e)
  65. traceback.print_exc()
  66. # threads for rev msg
  67. def thread_rec_msg(self):
  68. try:
  69. while(True):
  70. msg=self.websocket.recv()
  71. if msg is None or len(msg) == 0:
  72. continue
  73. msg = json.loads(msg)
  74. self.msg_queue.put(msg)
  75. except Exception as e:
  76. print("client closed")
  77. # feed data to asr engine, wait_time means waiting for result until time out
  78. def feed_chunk(self, chunk, wait_time=0.01):
  79. try:
  80. self.websocket.send(chunk, ABNF.OPCODE_BINARY)
  81. # loop to check if there is a message, timeout in 0.01s
  82. while(True):
  83. msg = self.msg_queue.get(timeout=wait_time)
  84. if self.msg_queue.empty():
  85. break
  86. return msg
  87. except:
  88. return ""
  89. def close(self,timeout=1):
  90. message = json.dumps({"is_speaking": False})
  91. self.websocket.send(message)
  92. # sleep for timeout seconds to wait for result
  93. time.sleep(timeout)
  94. msg=""
  95. while(not self.msg_queue.empty()):
  96. msg = self.msg_queue.get()
  97. self.websocket.close()
  98. # only resturn the last msg
  99. return msg
  100. if __name__ == '__main__':
  101. print('example for Funasr_websocket_recognizer')
  102. import wave
  103. wav_path = "/Users/zhifu/Downloads/modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
  104. with wave.open(wav_path, "rb") as wav_file:
  105. params = wav_file.getparams()
  106. frames = wav_file.readframes(wav_file.getnframes())
  107. audio_bytes = bytes(frames)
  108. stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
  109. chunk_num = (len(audio_bytes) - 1) // stride + 1
  110. # create an recognizer
  111. rcg = Funasr_websocket_recognizer(host="127.0.0.1",
  112. port="10095",
  113. is_ssl=True,
  114. mode="2pass",
  115. chunk_size="0,10,5")
  116. # loop to send chunk
  117. for i in range(chunk_num):
  118. beg = i * stride
  119. data = audio_bytes[beg:beg + stride]
  120. text = rcg.feed_chunk(data,wait_time=0.02)
  121. if len(text)>0:
  122. print("text",text)
  123. time.sleep(0.05)
  124. # get last message
  125. text = rcg.close(timeout=3)
  126. print("text",text)