ws_client.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # -*- encoding: utf-8 -*-
  2. import os
  3. import time
  4. import websockets
  5. import asyncio
  6. # import threading
  7. import argparse
  8. import json
  9. import traceback
  10. from multiprocessing import Process
  11. from funasr.fileio.datadir_writer import DatadirWriter
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("--host",
  16. type=str,
  17. default="localhost",
  18. required=False,
  19. help="host ip, localhost, 0.0.0.0")
  20. parser.add_argument("--port",
  21. type=int,
  22. default=10095,
  23. required=False,
  24. help="grpc server port")
  25. parser.add_argument("--chunk_size",
  26. type=str,
  27. default="5, 10, 5",
  28. help="chunk")
  29. parser.add_argument("--chunk_interval",
  30. type=int,
  31. default=10,
  32. help="chunk")
  33. parser.add_argument("--audio_in",
  34. type=str,
  35. default=None,
  36. help="audio_in")
  37. parser.add_argument("--send_without_sleep",
  38. action="store_true",
  39. default=False,
  40. help="if audio_in is set, send_without_sleep")
  41. parser.add_argument("--test_thread_num",
  42. type=int,
  43. default=1,
  44. help="test_thread_num")
  45. parser.add_argument("--words_max_print",
  46. type=int,
  47. default=100,
  48. help="chunk")
  49. parser.add_argument("--output_dir",
  50. type=str,
  51. default=None,
  52. help="output_dir")
  53. args = parser.parse_args()
  54. args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
  55. print(args)
  56. # voices = asyncio.Queue()
  57. from queue import Queue
  58. voices = Queue()
  59. ibest_writer = None
  60. if args.output_dir is not None:
  61. writer = DatadirWriter(args.output_dir)
  62. ibest_writer = writer[f"1best_recog"]
  63. async def record_microphone():
  64. is_finished = False
  65. import pyaudio
  66. #print("2")
  67. global voices
  68. FORMAT = pyaudio.paInt16
  69. CHANNELS = 1
  70. RATE = 16000
  71. chunk_size = 60*args.chunk_size[1]/args.chunk_interval
  72. CHUNK = int(RATE / 1000 * chunk_size)
  73. p = pyaudio.PyAudio()
  74. stream = p.open(format=FORMAT,
  75. channels=CHANNELS,
  76. rate=RATE,
  77. input=True,
  78. frames_per_buffer=CHUNK)
  79. message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
  80. voices.put(message)
  81. while True:
  82. data = stream.read(CHUNK)
  83. message = data
  84. voices.put(message)
  85. await asyncio.sleep(0.005)
  86. async def record_from_scp(chunk_begin,chunk_size):
  87. import wave
  88. global voices
  89. is_finished = False
  90. if args.audio_in.endswith(".scp"):
  91. f_scp = open(args.audio_in)
  92. wavs = f_scp.readlines()
  93. else:
  94. wavs = [args.audio_in]
  95. if chunk_size>0:
  96. wavs=wavs[chunk_begin:chunk_begin+chunk_size]
  97. for wav in wavs:
  98. wav_splits = wav.strip().split()
  99. wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
  100. wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
  101. # bytes_f = open(wav_path, "rb")
  102. # bytes_data = bytes_f.read()
  103. with wave.open(wav_path, "rb") as wav_file:
  104. params = wav_file.getparams()
  105. # header_length = wav_file.getheaders()[0][1]
  106. # wav_file.setpos(header_length)
  107. frames = wav_file.readframes(wav_file.getnframes())
  108. audio_bytes = bytes(frames)
  109. # stride = int(args.chunk_size/1000*16000*2)
  110. stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
  111. chunk_num = (len(audio_bytes)-1)//stride + 1
  112. # print(stride)
  113. # send first time
  114. message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
  115. voices.put(message)
  116. is_speaking = True
  117. for i in range(chunk_num):
  118. beg = i*stride
  119. data = audio_bytes[beg:beg+stride]
  120. message = data
  121. voices.put(message)
  122. if i == chunk_num-1:
  123. is_speaking = False
  124. message = json.dumps({"is_speaking": is_speaking})
  125. voices.put(message)
  126. # print("data_chunk: ", len(data_chunk))
  127. # print(voices.qsize())
  128. sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
  129. await asyncio.sleep(sleep_duration)
  130. async def ws_send():
  131. global voices
  132. global websocket
  133. print("started to sending data!")
  134. while True:
  135. while not voices.empty():
  136. data = voices.get()
  137. voices.task_done()
  138. try:
  139. await websocket.send(data)
  140. except Exception as e:
  141. print('Exception occurred:', e)
  142. traceback.print_exc()
  143. exit(0)
  144. await asyncio.sleep(0.005)
  145. await asyncio.sleep(0.005)
  146. async def message(id):
  147. global websocket
  148. text_print = ""
  149. text_print_2pass_online = ""
  150. text_print_2pass_offline = ""
  151. while True:
  152. try:
  153. meg = await websocket.recv()
  154. meg = json.loads(meg)
  155. wav_name = meg.get("wav_name", "demo")
  156. # print(wav_name)
  157. text = meg["text"]
  158. if ibest_writer is not None:
  159. ibest_writer["text"][wav_name] = text
  160. if meg["mode"] == "online":
  161. text_print += " {}".format(text)
  162. text_print = text_print[-args.words_max_print:]
  163. os.system('clear')
  164. print("\rpid"+str(id)+": "+text_print)
  165. elif meg["mode"] == "online":
  166. text_print += "{}".format(text)
  167. text_print = text_print[-args.words_max_print:]
  168. os.system('clear')
  169. print("\rpid"+str(id)+": "+text_print)
  170. else:
  171. if meg["mode"] == "2pass-online":
  172. text_print_2pass_online += " {}".format(text)
  173. text_print = text_print_2pass_offline + text_print_2pass_online
  174. else:
  175. text_print_2pass_online = ""
  176. text_print = text_print_2pass_offline + "{}".format(text)
  177. text_print_2pass_offline += "{}".format(text)
  178. text_print = text_print[-args.words_max_print:]
  179. os.system('clear')
  180. print("\rpid" + str(id) + ": " + text_print)
  181. except Exception as e:
  182. print("Exception:", e)
  183. traceback.print_exc()
  184. exit(0)
  185. async def print_messge():
  186. global websocket
  187. while True:
  188. try:
  189. meg = await websocket.recv()
  190. meg = json.loads(meg)
  191. print(meg)
  192. except Exception as e:
  193. print("Exception:", e)
  194. traceback.print_exc()
  195. exit(0)
  196. async def ws_client(id,chunk_begin,chunk_size):
  197. global websocket
  198. uri = "ws://{}:{}".format(args.host, args.port)
  199. async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
  200. if args.audio_in is not None:
  201. task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
  202. else:
  203. task = asyncio.create_task(record_microphone())
  204. task2 = asyncio.create_task(ws_send())
  205. task3 = asyncio.create_task(message(id))
  206. await asyncio.gather(task, task2, task3)
  207. def one_thread(id,chunk_begin,chunk_size):
  208. asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
  209. asyncio.get_event_loop().run_forever()
  210. if __name__ == '__main__':
  211. # for microphone
  212. if args.audio_in is None:
  213. p = Process(target=one_thread,args=(0, 0, 0))
  214. p.start()
  215. p.join()
  216. print('end')
  217. else:
  218. # calculate the number of wavs for each preocess
  219. if args.audio_in.endswith(".scp"):
  220. f_scp = open(args.audio_in)
  221. wavs = f_scp.readlines()
  222. else:
  223. wavs = [args.audio_in]
  224. total_len=len(wavs)
  225. if total_len>=args.test_thread_num:
  226. chunk_size=int((total_len)/args.test_thread_num)
  227. remain_wavs=total_len-chunk_size*args.test_thread_num
  228. else:
  229. chunk_size=1
  230. remain_wavs=0
  231. process_list = []
  232. chunk_begin=0
  233. for i in range(args.test_thread_num):
  234. now_chunk_size= chunk_size
  235. if remain_wavs>0:
  236. now_chunk_size=chunk_size+1
  237. remain_wavs=remain_wavs-1
  238. # process i handle wavs at chunk_begin and size of now_chunk_size
  239. p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
  240. chunk_begin=chunk_begin+now_chunk_size
  241. p.start()
  242. process_list.append(p)
  243. for i in process_list:
  244. p.join()
  245. print('end')