ws_client.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # -*- encoding: utf-8 -*-
  2. import os
  3. import time
  4. import websockets
  5. import asyncio
  6. # import threading
  7. import argparse
  8. import json
  9. import traceback
  10. from multiprocessing import Process
  11. from funasr.fileio.datadir_writer import DatadirWriter
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument("--host",
  14. type=str,
  15. default="localhost",
  16. required=False,
  17. help="host ip, localhost, 0.0.0.0")
  18. parser.add_argument("--port",
  19. type=int,
  20. default=10095,
  21. required=False,
  22. help="grpc server port")
  23. parser.add_argument("--chunk_size",
  24. type=str,
  25. default="5, 10, 5",
  26. help="chunk")
  27. parser.add_argument("--chunk_interval",
  28. type=int,
  29. default=10,
  30. help="chunk")
  31. parser.add_argument("--audio_in",
  32. type=str,
  33. default=None,
  34. help="audio_in")
  35. parser.add_argument("--send_without_sleep",
  36. action="store_true",
  37. default=False,
  38. help="if audio_in is set, send_without_sleep")
  39. parser.add_argument("--test_thread_num",
  40. type=int,
  41. default=1,
  42. help="test_thread_num")
  43. parser.add_argument("--words_max_print",
  44. type=int,
  45. default=100,
  46. help="chunk")
  47. parser.add_argument("--output_dir",
  48. type=str,
  49. default=None,
  50. help="output_dir")
  51. args = parser.parse_args()
  52. args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
  53. print(args)
  54. # voices = asyncio.Queue()
  55. from queue import Queue
  56. voices = Queue()
  57. ibest_writer = None
  58. if args.output_dir is not None:
  59. writer = DatadirWriter(args.output_dir)
  60. ibest_writer = writer[f"1best_recog"]
  61. async def record_microphone():
  62. is_finished = False
  63. import pyaudio
  64. #print("2")
  65. global voices
  66. FORMAT = pyaudio.paInt16
  67. CHANNELS = 1
  68. RATE = 16000
  69. chunk_size = 60*args.chunk_size[1]/args.chunk_interval
  70. CHUNK = int(RATE / 1000 * chunk_size)
  71. p = pyaudio.PyAudio()
  72. stream = p.open(format=FORMAT,
  73. channels=CHANNELS,
  74. rate=RATE,
  75. input=True,
  76. frames_per_buffer=CHUNK)
  77. is_speaking = True
  78. while True:
  79. data = stream.read(CHUNK)
  80. data = data.decode('ISO-8859-1')
  81. message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished})
  82. voices.put(message)
  83. await asyncio.sleep(0.005)
  84. async def record_from_scp():
  85. import wave
  86. global voices
  87. is_finished = False
  88. if args.audio_in.endswith(".scp"):
  89. f_scp = open(args.audio_in)
  90. wavs = f_scp.readlines()
  91. else:
  92. wavs = [args.audio_in]
  93. for wav in wavs:
  94. wav_splits = wav.strip().split()
  95. wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
  96. wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
  97. # bytes_f = open(wav_path, "rb")
  98. # bytes_data = bytes_f.read()
  99. with wave.open(wav_path, "rb") as wav_file:
  100. params = wav_file.getparams()
  101. # header_length = wav_file.getheaders()[0][1]
  102. # wav_file.setpos(header_length)
  103. frames = wav_file.readframes(wav_file.getnframes())
  104. audio_bytes = bytes(frames)
  105. # stride = int(args.chunk_size/1000*16000*2)
  106. stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
  107. chunk_num = (len(audio_bytes)-1)//stride + 1
  108. # print(stride)
  109. is_speaking = True
  110. for i in range(chunk_num):
  111. if i == chunk_num-1:
  112. is_speaking = False
  113. beg = i*stride
  114. data = audio_bytes[beg:beg+stride]
  115. data = data.decode('ISO-8859-1')
  116. message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished, "wav_name": wav_name})
  117. voices.put(message)
  118. # print("data_chunk: ", len(data_chunk))
  119. # print(voices.qsize())
  120. sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
  121. await asyncio.sleep(sleep_duration)
  122. is_finished = True
  123. message = json.dumps({"is_finished": is_finished})
  124. voices.put(message)
  125. async def ws_send():
  126. global voices
  127. global websocket
  128. print("started to sending data!")
  129. while True:
  130. while not voices.empty():
  131. data = voices.get()
  132. voices.task_done()
  133. try:
  134. await websocket.send(data)
  135. except Exception as e:
  136. print('Exception occurred:', e)
  137. traceback.print_exc()
  138. exit(0)
  139. await asyncio.sleep(0.005)
  140. await asyncio.sleep(0.005)
  141. async def message(id):
  142. global websocket
  143. text_print = ""
  144. while True:
  145. try:
  146. meg = await websocket.recv()
  147. meg = json.loads(meg)
  148. # print(meg, end = '')
  149. # print("\r")
  150. # print(meg)
  151. wav_name = meg.get("wav_name", "demo")
  152. print(wav_name)
  153. text = meg["text"]
  154. if ibest_writer is not None:
  155. ibest_writer["text"][wav_name] = text
  156. if meg["mode"] == "online":
  157. text_print += " {}".format(text)
  158. else:
  159. text_print += "{}".format(text)
  160. text_print = text_print[-args.words_max_print:]
  161. os.system('clear')
  162. print("\rpid"+str(id)+": "+text_print)
  163. except Exception as e:
  164. print("Exception:", e)
  165. traceback.print_exc()
  166. exit(0)
  167. async def print_messge():
  168. global websocket
  169. while True:
  170. try:
  171. meg = await websocket.recv()
  172. meg = json.loads(meg)
  173. print(meg)
  174. except Exception as e:
  175. print("Exception:", e)
  176. traceback.print_exc()
  177. exit(0)
  178. async def ws_client(id):
  179. global websocket
  180. uri = "ws://{}:{}".format(args.host, args.port)
  181. async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
  182. if args.audio_in is not None:
  183. task = asyncio.create_task(record_from_scp())
  184. else:
  185. task = asyncio.create_task(record_microphone())
  186. task2 = asyncio.create_task(ws_send())
  187. task3 = asyncio.create_task(message(id))
  188. await asyncio.gather(task, task2, task3)
  189. def one_thread(id):
  190. asyncio.get_event_loop().run_until_complete(ws_client(id)) # 启动协程
  191. asyncio.get_event_loop().run_forever()
  192. if __name__ == '__main__':
  193. process_list = []
  194. for i in range(args.test_thread_num):
  195. p = Process(target=one_thread,args=(i,))
  196. p.start()
  197. process_list.append(p)
  198. for i in process_list:
  199. p.join()
  200. print('end')