ws_client.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # -*- encoding: utf-8 -*-
  2. import os
  3. import time
  4. import websockets,ssl
  5. import asyncio
  6. # import threading
  7. import argparse
  8. import json
  9. import traceback
  10. from multiprocessing import Process
  11. from funasr.fileio.datadir_writer import DatadirWriter
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("--host",
  16. type=str,
  17. default="localhost",
  18. required=False,
  19. help="host ip, localhost, 0.0.0.0")
  20. parser.add_argument("--port",
  21. type=int,
  22. default=10095,
  23. required=False,
  24. help="grpc server port")
  25. parser.add_argument("--chunk_size",
  26. type=str,
  27. default="5, 10, 5",
  28. help="chunk")
  29. parser.add_argument("--chunk_interval",
  30. type=int,
  31. default=10,
  32. help="chunk")
  33. parser.add_argument("--audio_in",
  34. type=str,
  35. default=None,
  36. help="audio_in")
  37. parser.add_argument("--send_without_sleep",
  38. action="store_true",
  39. default=False,
  40. help="if audio_in is set, send_without_sleep")
  41. parser.add_argument("--test_thread_num",
  42. type=int,
  43. default=1,
  44. help="test_thread_num")
  45. parser.add_argument("--words_max_print",
  46. type=int,
  47. default=10000,
  48. help="chunk")
  49. parser.add_argument("--output_dir",
  50. type=str,
  51. default=None,
  52. help="output_dir")
  53. parser.add_argument("--ssl",
  54. type=int,
  55. default=1,
  56. help="1 for ssl connect, 0 for no ssl")
  57. args = parser.parse_args()
  58. args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
  59. print(args)
  60. # voices = asyncio.Queue()
  61. from queue import Queue
  62. voices = Queue()
  63. ibest_writer = None
  64. if args.output_dir is not None:
  65. writer = DatadirWriter(args.output_dir)
  66. ibest_writer = writer[f"1best_recog"]
  67. async def record_microphone():
  68. is_finished = False
  69. import pyaudio
  70. #print("2")
  71. global voices
  72. FORMAT = pyaudio.paInt16
  73. CHANNELS = 1
  74. RATE = 16000
  75. chunk_size = 60*args.chunk_size[1]/args.chunk_interval
  76. CHUNK = int(RATE / 1000 * chunk_size)
  77. p = pyaudio.PyAudio()
  78. stream = p.open(format=FORMAT,
  79. channels=CHANNELS,
  80. rate=RATE,
  81. input=True,
  82. frames_per_buffer=CHUNK)
  83. message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
  84. voices.put(message)
  85. while True:
  86. data = stream.read(CHUNK)
  87. message = data
  88. voices.put(message)
  89. await asyncio.sleep(0.005)
  90. async def record_from_scp(chunk_begin,chunk_size):
  91. import wave
  92. global voices
  93. is_finished = False
  94. if args.audio_in.endswith(".scp"):
  95. f_scp = open(args.audio_in)
  96. wavs = f_scp.readlines()
  97. else:
  98. wavs = [args.audio_in]
  99. if chunk_size>0:
  100. wavs=wavs[chunk_begin:chunk_begin+chunk_size]
  101. for wav in wavs:
  102. wav_splits = wav.strip().split()
  103. wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
  104. wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
  105. # bytes_f = open(wav_path, "rb")
  106. # bytes_data = bytes_f.read()
  107. with wave.open(wav_path, "rb") as wav_file:
  108. params = wav_file.getparams()
  109. # header_length = wav_file.getheaders()[0][1]
  110. # wav_file.setpos(header_length)
  111. frames = wav_file.readframes(wav_file.getnframes())
  112. audio_bytes = bytes(frames)
  113. # stride = int(args.chunk_size/1000*16000*2)
  114. stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
  115. chunk_num = (len(audio_bytes)-1)//stride + 1
  116. # print(stride)
  117. # send first time
  118. message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
  119. voices.put(message)
  120. is_speaking = True
  121. for i in range(chunk_num):
  122. beg = i*stride
  123. data = audio_bytes[beg:beg+stride]
  124. message = data
  125. voices.put(message)
  126. if i == chunk_num-1:
  127. is_speaking = False
  128. message = json.dumps({"is_speaking": is_speaking})
  129. voices.put(message)
  130. # print("data_chunk: ", len(data_chunk))
  131. # print(voices.qsize())
  132. sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
  133. await asyncio.sleep(sleep_duration)
  134. async def ws_send():
  135. global voices
  136. global websocket
  137. print("started to sending data!")
  138. while True:
  139. while not voices.empty():
  140. data = voices.get()
  141. voices.task_done()
  142. try:
  143. await websocket.send(data)
  144. except Exception as e:
  145. print('Exception occurred:', e)
  146. traceback.print_exc()
  147. exit(0)
  148. await asyncio.sleep(0.005)
  149. await asyncio.sleep(0.005)
  150. async def message(id):
  151. global websocket
  152. text_print = ""
  153. text_print_2pass_online = ""
  154. text_print_2pass_offline = ""
  155. while True:
  156. try:
  157. meg = await websocket.recv()
  158. meg = json.loads(meg)
  159. wav_name = meg.get("wav_name", "demo")
  160. # print(wav_name)
  161. text = meg["text"]
  162. if ibest_writer is not None:
  163. ibest_writer["text"][wav_name] = text
  164. if meg["mode"] == "online":
  165. text_print += "{}".format(text)
  166. text_print = text_print[-args.words_max_print:]
  167. os.system('clear')
  168. print("\rpid"+str(id)+": "+text_print)
  169. elif meg["mode"] == "online":
  170. text_print += "{}".format(text)
  171. text_print = text_print[-args.words_max_print:]
  172. os.system('clear')
  173. print("\rpid"+str(id)+": "+text_print)
  174. else:
  175. if meg["mode"] == "2pass-online":
  176. text_print_2pass_online += "{}".format(text)
  177. text_print = text_print_2pass_offline + text_print_2pass_online
  178. else:
  179. text_print_2pass_online = ""
  180. text_print = text_print_2pass_offline + "{}".format(text)
  181. text_print_2pass_offline += "{}".format(text)
  182. text_print = text_print[-args.words_max_print:]
  183. os.system('clear')
  184. print("\rpid" + str(id) + ": " + text_print)
  185. except Exception as e:
  186. print("Exception:", e)
  187. traceback.print_exc()
  188. exit(0)
  189. async def print_messge():
  190. global websocket
  191. while True:
  192. try:
  193. meg = await websocket.recv()
  194. meg = json.loads(meg)
  195. print(meg)
  196. except Exception as e:
  197. print("Exception:", e)
  198. traceback.print_exc()
  199. exit(0)
  200. async def ws_client(id,chunk_begin,chunk_size):
  201. global websocket
  202. if args.ssl==1:
  203. ssl_context = ssl.SSLContext()
  204. ssl_context.check_hostname = False
  205. ssl_context.verify_mode = ssl.CERT_NONE
  206. uri = "wss://{}:{}".format(args.host, args.port)
  207. else:
  208. uri = "ws://{}:{}".format(args.host, args.port)
  209. ssl_context=None
  210. print("connect to",uri)
  211. async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
  212. if args.audio_in is not None:
  213. task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
  214. else:
  215. task = asyncio.create_task(record_microphone())
  216. task2 = asyncio.create_task(ws_send())
  217. task3 = asyncio.create_task(message(id))
  218. await asyncio.gather(task, task2, task3)
  219. def one_thread(id,chunk_begin,chunk_size):
  220. asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
  221. asyncio.get_event_loop().run_forever()
  222. if __name__ == '__main__':
  223. # for microphone
  224. if args.audio_in is None:
  225. p = Process(target=one_thread,args=(0, 0, 0))
  226. p.start()
  227. p.join()
  228. print('end')
  229. else:
  230. # calculate the number of wavs for each preocess
  231. if args.audio_in.endswith(".scp"):
  232. f_scp = open(args.audio_in)
  233. wavs = f_scp.readlines()
  234. else:
  235. wavs = [args.audio_in]
  236. total_len=len(wavs)
  237. if total_len>=args.test_thread_num:
  238. chunk_size=int((total_len)/args.test_thread_num)
  239. remain_wavs=total_len-chunk_size*args.test_thread_num
  240. else:
  241. chunk_size=1
  242. remain_wavs=0
  243. process_list = []
  244. chunk_begin=0
  245. for i in range(args.test_thread_num):
  246. now_chunk_size= chunk_size
  247. if remain_wavs>0:
  248. now_chunk_size=chunk_size+1
  249. remain_wavs=remain_wavs-1
  250. # process i handle wavs at chunk_begin and size of now_chunk_size
  251. p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
  252. chunk_begin=chunk_begin+now_chunk_size
  253. p.start()
  254. process_list.append(p)
  255. for i in process_list:
  256. p.join()
  257. print('end')