wss_client_asr.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # -*- encoding: utf-8 -*-
  2. import os
  3. import time
  4. import websockets,ssl
  5. import asyncio
  6. # import threading
  7. import argparse
  8. import json
  9. import traceback
  10. from multiprocessing import Process
  11. from funasr.fileio.datadir_writer import DatadirWriter
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("--host",
  16. type=str,
  17. default="localhost",
  18. required=False,
  19. help="host ip, localhost, 0.0.0.0")
  20. parser.add_argument("--port",
  21. type=int,
  22. default=10095,
  23. required=False,
  24. help="grpc server port")
  25. parser.add_argument("--chunk_size",
  26. type=str,
  27. default="5, 10, 5",
  28. help="chunk")
  29. parser.add_argument("--chunk_interval",
  30. type=int,
  31. default=10,
  32. help="chunk")
  33. parser.add_argument("--audio_in",
  34. type=str,
  35. default=None,
  36. help="audio_in")
  37. parser.add_argument("--send_without_sleep",
  38. action="store_true",
  39. default=False,
  40. help="if audio_in is set, send_without_sleep")
  41. parser.add_argument("--test_thread_num",
  42. type=int,
  43. default=1,
  44. help="test_thread_num")
  45. parser.add_argument("--words_max_print",
  46. type=int,
  47. default=10000,
  48. help="chunk")
  49. parser.add_argument("--output_dir",
  50. type=str,
  51. default=None,
  52. help="output_dir")
  53. parser.add_argument("--ssl",
  54. type=int,
  55. default=1,
  56. help="1 for ssl connect, 0 for no ssl")
  57. parser.add_argument("--mode",
  58. type=str,
  59. default="2pass",
  60. help="offline, online, 2pass")
  61. args = parser.parse_args()
  62. args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
  63. print(args)
  64. # voices = asyncio.Queue()
  65. from queue import Queue
  66. voices = Queue()
  67. ibest_writer = None
  68. if args.output_dir is not None:
  69. writer = DatadirWriter(args.output_dir)
  70. ibest_writer = writer[f"1best_recog"]
  71. async def record_microphone():
  72. is_finished = False
  73. import pyaudio
  74. #print("2")
  75. global voices
  76. FORMAT = pyaudio.paInt16
  77. CHANNELS = 1
  78. RATE = 16000
  79. chunk_size = 60*args.chunk_size[1]/args.chunk_interval
  80. CHUNK = int(RATE / 1000 * chunk_size)
  81. p = pyaudio.PyAudio()
  82. stream = p.open(format=FORMAT,
  83. channels=CHANNELS,
  84. rate=RATE,
  85. input=True,
  86. frames_per_buffer=CHUNK)
  87. message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
  88. voices.put(message)
  89. while True:
  90. data = stream.read(CHUNK)
  91. message = data
  92. voices.put(message)
  93. await asyncio.sleep(0.005)
  94. async def record_from_scp(chunk_begin,chunk_size):
  95. import wave
  96. global voices
  97. is_finished = False
  98. if args.audio_in.endswith(".scp"):
  99. f_scp = open(args.audio_in)
  100. wavs = f_scp.readlines()
  101. else:
  102. wavs = [args.audio_in]
  103. if chunk_size>0:
  104. wavs=wavs[chunk_begin:chunk_begin+chunk_size]
  105. for wav in wavs:
  106. wav_splits = wav.strip().split()
  107. wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
  108. wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
  109. # bytes_f = open(wav_path, "rb")
  110. # bytes_data = bytes_f.read()
  111. with wave.open(wav_path, "rb") as wav_file:
  112. params = wav_file.getparams()
  113. # header_length = wav_file.getheaders()[0][1]
  114. # wav_file.setpos(header_length)
  115. frames = wav_file.readframes(wav_file.getnframes())
  116. audio_bytes = bytes(frames)
  117. # stride = int(args.chunk_size/1000*16000*2)
  118. stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
  119. chunk_num = (len(audio_bytes)-1)//stride + 1
  120. # print(stride)
  121. # send first time
  122. message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
  123. voices.put(message)
  124. is_speaking = True
  125. for i in range(chunk_num):
  126. beg = i*stride
  127. data = audio_bytes[beg:beg+stride]
  128. message = data
  129. voices.put(message)
  130. if i == chunk_num-1:
  131. is_speaking = False
  132. message = json.dumps({"is_speaking": is_speaking})
  133. voices.put(message)
  134. # print("data_chunk: ", len(data_chunk))
  135. # print(voices.qsize())
  136. sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
  137. await asyncio.sleep(sleep_duration)
  138. async def ws_send():
  139. global voices
  140. global websocket
  141. print("started to sending data!")
  142. while True:
  143. while not voices.empty():
  144. data = voices.get()
  145. voices.task_done()
  146. try:
  147. await websocket.send(data)
  148. except Exception as e:
  149. print('Exception occurred:', e)
  150. traceback.print_exc()
  151. exit(0)
  152. await asyncio.sleep(0.005)
  153. await asyncio.sleep(0.005)
  154. async def message(id):
  155. global websocket
  156. text_print = ""
  157. text_print_2pass_online = ""
  158. text_print_2pass_offline = ""
  159. while True:
  160. try:
  161. meg = await websocket.recv()
  162. meg = json.loads(meg)
  163. wav_name = meg.get("wav_name", "demo")
  164. # print(wav_name)
  165. text = meg["text"]
  166. if ibest_writer is not None:
  167. ibest_writer["text"][wav_name] = text
  168. if meg["mode"] == "online":
  169. text_print += "{}".format(text)
  170. text_print = text_print[-args.words_max_print:]
  171. os.system('clear')
  172. print("\rpid"+str(id)+": "+text_print)
  173. elif meg["mode"] == "offline":
  174. text_print += "{}".format(text)
  175. text_print = text_print[-args.words_max_print:]
  176. os.system('clear')
  177. print("\rpid"+str(id)+": "+text_print)
  178. else:
  179. if meg["mode"] == "2pass-online":
  180. text_print_2pass_online += "{}".format(text)
  181. text_print = text_print_2pass_offline + text_print_2pass_online
  182. else:
  183. text_print_2pass_online = ""
  184. text_print = text_print_2pass_offline + "{}".format(text)
  185. text_print_2pass_offline += "{}".format(text)
  186. text_print = text_print[-args.words_max_print:]
  187. os.system('clear')
  188. print("\rpid" + str(id) + ": " + text_print)
  189. except Exception as e:
  190. print("Exception:", e)
  191. traceback.print_exc()
  192. exit(0)
  193. async def print_messge():
  194. global websocket
  195. while True:
  196. try:
  197. meg = await websocket.recv()
  198. meg = json.loads(meg)
  199. print(meg)
  200. except Exception as e:
  201. print("Exception:", e)
  202. traceback.print_exc()
  203. exit(0)
  204. async def ws_client(id,chunk_begin,chunk_size):
  205. global websocket
  206. if args.ssl==1:
  207. ssl_context = ssl.SSLContext()
  208. ssl_context.check_hostname = False
  209. ssl_context.verify_mode = ssl.CERT_NONE
  210. uri = "wss://{}:{}".format(args.host, args.port)
  211. else:
  212. uri = "ws://{}:{}".format(args.host, args.port)
  213. ssl_context=None
  214. print("connect to",uri)
  215. async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
  216. if args.audio_in is not None:
  217. task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
  218. else:
  219. task = asyncio.create_task(record_microphone())
  220. task2 = asyncio.create_task(ws_send())
  221. task3 = asyncio.create_task(message(id))
  222. await asyncio.gather(task, task2, task3)
  223. def one_thread(id,chunk_begin,chunk_size):
  224. asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
  225. asyncio.get_event_loop().run_forever()
  226. if __name__ == '__main__':
  227. # for microphone
  228. if args.audio_in is None:
  229. p = Process(target=one_thread,args=(0, 0, 0))
  230. p.start()
  231. p.join()
  232. print('end')
  233. else:
  234. # calculate the number of wavs for each preocess
  235. if args.audio_in.endswith(".scp"):
  236. f_scp = open(args.audio_in)
  237. wavs = f_scp.readlines()
  238. else:
  239. wavs = [args.audio_in]
  240. total_len=len(wavs)
  241. if total_len>=args.test_thread_num:
  242. chunk_size=int((total_len)/args.test_thread_num)
  243. remain_wavs=total_len-chunk_size*args.test_thread_num
  244. else:
  245. chunk_size=1
  246. remain_wavs=0
  247. process_list = []
  248. chunk_begin=0
  249. for i in range(args.test_thread_num):
  250. now_chunk_size= chunk_size
  251. if remain_wavs>0:
  252. now_chunk_size=chunk_size+1
  253. remain_wavs=remain_wavs-1
  254. # process i handle wavs at chunk_begin and size of now_chunk_size
  255. p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
  256. chunk_begin=chunk_begin+now_chunk_size
  257. p.start()
  258. process_list.append(p)
  259. for i in process_list:
  260. p.join()
  261. print('end')