# funasr_wss_client.py (18 KB)
  1. # -*- encoding: utf-8 -*-
  2. import os
  3. import time
  4. import websockets, ssl
  5. import asyncio
  6. # import threading
  7. import argparse
  8. import json
  9. import traceback
  10. from multiprocessing import Process
  11. # from funasr.fileio.datadir_writer import DatadirWriter
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. logger = logging.getLogger("funasr")
  15. logger.setLevel(logging.INFO)
  16. parser = argparse.ArgumentParser()
  17. parser.add_argument("--host",
  18. type=str,
  19. default="localhost",
  20. required=False,
  21. help="host ip, localhost, 0.0.0.0")
  22. parser.add_argument("--port",
  23. type=int,
  24. default=10095,
  25. required=False,
  26. help="grpc server port")
  27. parser.add_argument("--chunk_size",
  28. type=str,
  29. default="5, 10, 5",
  30. help="chunk")
  31. parser.add_argument("--encoder_chunk_look_back",
  32. type=int,
  33. default=4,
  34. help="chunk")
  35. parser.add_argument("--decoder_chunk_look_back",
  36. type=int,
  37. default=0,
  38. help="chunk")
  39. parser.add_argument("--chunk_interval",
  40. type=int,
  41. default=10,
  42. help="chunk")
  43. parser.add_argument("--hotword",
  44. type=str,
  45. default="",
  46. help="hotword file path, one hotword perline (e.g.:阿里巴巴 20)")
  47. parser.add_argument("--audio_in",
  48. type=str,
  49. default=None,
  50. help="audio_in")
  51. parser.add_argument("--audio_fs",
  52. type=int,
  53. default=16000,
  54. help="audio_fs")
  55. parser.add_argument("--send_without_sleep",
  56. action="store_true",
  57. default=True,
  58. help="if audio_in is set, send_without_sleep")
  59. parser.add_argument("--thread_num",
  60. type=int,
  61. default=1,
  62. help="thread_num")
  63. parser.add_argument("--words_max_print",
  64. type=int,
  65. default=10000,
  66. help="chunk")
  67. parser.add_argument("--output_dir",
  68. type=str,
  69. default=None,
  70. help="output_dir")
  71. parser.add_argument("--output_json",
  72. type=str,
  73. default=None,
  74. help="output json result")
  75. parser.add_argument("--ssl",
  76. type=int,
  77. default=1,
  78. help="1 for ssl connect, 0 for no ssl")
  79. parser.add_argument("--use_itn",
  80. type=int,
  81. default=1,
  82. help="1 for using itn, 0 for not itn")
  83. parser.add_argument("--mode",
  84. type=str,
  85. default="2pass",
  86. help="offline, online, 2pass")
  87. # args = parser.parse_args()
  88. args, unparsed = parser.parse_known_args()
  89. args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
  90. logger.debug(f"args: {args}")
  91. # voices = asyncio.Queue()
  92. from queue import Queue
  93. voices = Queue()
  94. offline_msg_done=False
  95. if args.output_dir is not None:
  96. # if os.path.exists(args.output_dir):
  97. # os.remove(args.output_dir)
  98. if not os.path.exists(args.output_dir):
  99. os.makedirs(args.output_dir)
  100. async def record_microphone():
  101. is_finished = False
  102. import pyaudio
  103. # print("2")
  104. global voices
  105. FORMAT = pyaudio.paInt16
  106. CHANNELS = 1
  107. RATE = 16000
  108. chunk_size = 60 * args.chunk_size[1] / args.chunk_interval
  109. CHUNK = int(RATE / 1000 * chunk_size)
  110. p = pyaudio.PyAudio()
  111. stream = p.open(format=FORMAT,
  112. channels=CHANNELS,
  113. rate=RATE,
  114. input=True,
  115. frames_per_buffer=CHUNK)
  116. # hotwords
  117. fst_dict = {}
  118. hotword_msg = ""
  119. if args.hotword.strip() != "":
  120. if os.path.exists(args.hotword):
  121. f_scp = open(args.hotword)
  122. hot_lines = f_scp.readlines()
  123. for line in hot_lines:
  124. words = line.strip().split(" ")
  125. if len(words) < 2:
  126. print("Please checkout format of hotwords")
  127. continue
  128. try:
  129. fst_dict[" ".join(words[:-1])] = int(words[-1])
  130. except ValueError:
  131. print("Please checkout format of hotwords")
  132. hotword_msg = json.dumps(fst_dict)
  133. else:
  134. hotword_msg = args.hotword
  135. use_itn = True
  136. if args.use_itn == 0:
  137. use_itn=False
  138. message = json.dumps({"mode": args.mode,
  139. "chunk_size": args.chunk_size,
  140. "chunk_interval": args.chunk_interval,
  141. "encoder_chunk_look_back": args.encoder_chunk_look_back,
  142. "decoder_chunk_look_back": args.decoder_chunk_look_back,
  143. "wav_name": "microphone",
  144. "is_speaking": True,
  145. "hotwords": hotword_msg,
  146. "itn": use_itn,
  147. })
  148. #voices.put(message)
  149. await websocket.send(message)
  150. while True:
  151. data = stream.read(CHUNK)
  152. message = data
  153. #voices.put(message)
  154. await websocket.send(message)
  155. await asyncio.sleep(0.005)
  156. async def record_from_scp(chunk_begin, chunk_size):
  157. global voices
  158. is_finished = False
  159. print(f"args: {args}")
  160. print(f"chunk_begin: {chunk_begin}")
  161. print(f"chunk_size: {chunk_size}")
  162. if args.audio_in.endswith(".scp"):
  163. f_scp = open(args.audio_in)
  164. wavs = f_scp.readlines()
  165. else:
  166. wavs = [args.audio_in]
  167. # hotwords
  168. fst_dict = {}
  169. hotword_msg = ""
  170. if args.hotword.strip() != "":
  171. if os.path.exists(args.hotword):
  172. f_scp = open(args.hotword)
  173. hot_lines = f_scp.readlines()
  174. for line in hot_lines:
  175. words = line.strip().split(" ")
  176. if len(words) < 2:
  177. print("Please checkout format of hotwords")
  178. continue
  179. try:
  180. fst_dict[" ".join(words[:-1])] = int(words[-1])
  181. except ValueError:
  182. print("Please checkout format of hotwords")
  183. hotword_msg = json.dumps(fst_dict)
  184. else:
  185. hotword_msg = args.hotword
  186. print ("hotword_msg: ", hotword_msg)
  187. sample_rate = args.audio_fs
  188. wav_format = "pcm"
  189. use_itn=True
  190. if args.use_itn == 0:
  191. use_itn=False
  192. if chunk_size > 0:
  193. wavs = wavs[chunk_begin:chunk_begin + chunk_size]
  194. for wav in wavs:
  195. wav_splits = wav.strip().split()
  196. wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
  197. wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
  198. if not len(wav_path.strip())>0:
  199. continue
  200. if wav_path.endswith(".pcm"):
  201. with open(wav_path, "rb") as f:
  202. audio_bytes = f.read()
  203. elif wav_path.endswith(".wav"):
  204. import wave
  205. with wave.open(wav_path, "rb") as wav_file:
  206. params = wav_file.getparams()
  207. sample_rate = wav_file.getframerate()
  208. frames = wav_file.readframes(wav_file.getnframes())
  209. audio_bytes = bytes(frames)
  210. else:
  211. wav_format = "others"
  212. with open(wav_path, "rb") as f:
  213. audio_bytes = f.read()
  214. stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * sample_rate * 2)
  215. chunk_num = (len(audio_bytes) - 1) // stride + 1
  216. # print(stride)
  217. # send first time
  218. message = json.dumps({"mode": args.mode,
  219. "chunk_size": args.chunk_size,
  220. "chunk_interval": args.chunk_interval,
  221. "encoder_chunk_look_back": args.encoder_chunk_look_back,
  222. "decoder_chunk_look_back": args.decoder_chunk_look_back,
  223. "audio_fs":sample_rate,
  224. "wav_name": wav_name,
  225. "wav_format": wav_format,
  226. "is_speaking": True,
  227. "hotwords": hotword_msg,
  228. "itn": use_itn})
  229. #voices.put(message)
  230. await websocket.send(message)
  231. is_speaking = True
  232. # 在一次测试中 chunk_num=218 ,则 i 从 0 到 218 不停地对服务器发送数据 message
  233. for i in range(chunk_num):
  234. beg = i * stride
  235. data = audio_bytes[beg:beg + stride]
  236. message = data
  237. #voices.put(message)
  238. # message 是二进制音频数据
  239. await websocket.send(message)
  240. if i == chunk_num - 1:
  241. is_speaking = False
  242. message = json.dumps({"is_speaking": is_speaking})
  243. #voices.put(message)
  244. await websocket.send(message)
  245. sleep_duration = 0.001 if args.mode == "offline" else 60 * args.chunk_size[1] / args.chunk_interval / 1000
  246. await asyncio.sleep(sleep_duration)
  247. if not args.mode=="offline":
  248. await asyncio.sleep(2)
  249. # offline model need to wait for message recved
  250. if args.mode=="offline":
  251. global offline_msg_done
  252. while not offline_msg_done:
  253. await asyncio.sleep(1)
  254. await websocket.close()
  255. {
  256. "is_final":False,
  257. "mode":"offline",
  258. "stamp_sents":
  259. [{"end":2270,"punc":",","start":430,"text_seg":"正 是 因 为 存 在 绝 对 正 义",
  260. "ts_list":[[430,670],[670,810],[810,1030],[1030,1130],[1130,1330],[1330,1510],[1510,1670],[1670,1810],[1810,1970],[1970,2270]]},
  261. {"end":4505,"punc":"。","start":2270,"text_seg":"所 以 我 们 接 受 现 实 的 相 对 正 义",
  262. "ts_list":[[2270,2389],[2389,2490],[2490,2570],[2570,2709],[2709,2969],[2969,3310],[3310,3570],[3570,3730],[3730,3830],[3830,3969],[3969,4150],[4150,4270],[4270,4505]]},
  263. {"end":7490,"punc":",","start":5310,"text_seg":"但 是 不 要 因 为 现 实 的 相 对 正 义",
  264. "ts_list":[[5310,5470],[5470,5610],[5610,5710],[5710,5910],[5910,6069],[6069,6210],[6210,6470],[6470,6650],[6650,6750],[6750,6950],[6950,7130],[7130,7250],[7250,7490]]},
  265. {"end":9965,"punc":"。","start":7490,"text_seg":"我 们 就 认 为 这 个 世 界 没 有 正 义",
  266. "ts_list":[[7490,7590],[7590,7710],[7710,7910],[7910,8069],[8069,8290],[8290,8430],[8430,8550],[8550,8710],[8710,9050],[9050,9370],[9370,9550],[9550,9790],[9790,9965]]},
  267. {"end":12915,"punc":"。","start":10600,"text_seg":"因 为 如 果 当 你 认 为 这 个 世 界 没 有 正 义",
  268. "ts_list":[[10600,10760],[10760,10900],[10900,11120],[11120,11300],[11300,11400],[11400,11580],[11580,11700],[11700,11800],[11800,11920],[11920,12020],[12020,12160],[12160,12320],[12320,12440],[12440,12560],[12560,12740],[12740,12915]]
  269. }],
  270. "text":"正是因为存在绝对正义,所以我们接受现实的相对正义。但是不要因为现实的相对正义,我们就认为这个世界没有正义。因为如果当你认为这个世界没有正义。",
  271. "timestamp":"[[430,670],[670,810],[810,1030],[1030,1130],[1130,1330],[1330,1510],[1510,1670],[1670,1810],[1810,1970],[1970,2270],[2270,2389],[2389,2490],[2490,2570],[2570,2709],[2709,2969],[2969,3310],[3310,3570],[3570,3730],[3730,3830],[3830,3969],[3969,4150],[4150,4270],[4270,4505],[5310,5470],[5470,5610],[5610,5710],[5710,5910],[5910,6069],[6069,6210],[6210,6470],[6470,6650],[6650,6750],[6750,6950],[6950,7130],[7130,7250],[7250,7490],[7490,7590],[7590,7710],[7710,7910],[7910,8069],[8069,8290],[8290,8430],[8430,8550],[8550,8710],[8710,9050],[9050,9370],[9370,9550],[9550,9790],[9790,9965],[10600,10760],[10760,10900],[10900,11120],[11120,11300],[11300,11400],[11400,11580],[11580,11700],[11700,11800],[11800,11920],[11920,12020],[12020,12160],[12160,12320],[12320,12440],[12440,12560],[12560,12740],[12740,12915]]","wav_name":"demo"}
  272. async def message(id):
  273. global websocket,voices,offline_msg_done
  274. text_print = ""
  275. text_print_2pass_online = ""
  276. text_print_2pass_offline = ""
  277. if args.output_dir is not None:
  278. ibest_writer = open(os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8")
  279. else:
  280. ibest_writer = None
  281. try:
  282. while True:
  283. meg = await websocket.recv()
  284. meg = json.loads(meg)
  285. wav_name = meg.get("wav_name", "demo")
  286. text = meg["text"]
  287. timestamp=""
  288. offline_msg_done = meg.get("is_final", False)
  289. if "timestamp" in meg:
  290. timestamp = meg["timestamp"]
  291. if ibest_writer is not None:
  292. if timestamp !="":
  293. text_write_line = "{}\t{}\t{}\n".format(wav_name, text, timestamp)
  294. else:
  295. text_write_line = "{}\t{}\n".format(wav_name, text)
  296. ibest_writer.write(text_write_line)
  297. if 'mode' not in meg:
  298. continue
  299. if args.output_json:
  300. json.dump(meg, open(args.output_json, 'w'))
  301. if meg["mode"] == "online":
  302. text_print += "{}".format(text)
  303. text_print = text_print[-args.words_max_print:]
  304. os.system('clear')
  305. print("\rpid" + str(id) + ": " + text_print)
  306. elif meg["mode"] == "offline":
  307. if timestamp !="":
  308. text_print += "{} timestamp: {}".format(text, timestamp)
  309. else:
  310. text_print += "{}".format(text)
  311. # text_print = text_print[-args.words_max_print:]
  312. # os.system('clear')
  313. print("\rpid" + str(id) + ": " + wav_name + ": " + text_print)
  314. offline_msg_done = True
  315. else:
  316. if meg["mode"] == "2pass-online":
  317. text_print_2pass_online += "{}".format(text)
  318. text_print = text_print_2pass_offline + text_print_2pass_online
  319. else:
  320. text_print_2pass_online = ""
  321. text_print = text_print_2pass_offline + "{}".format(text)
  322. text_print_2pass_offline += "{}".format(text)
  323. text_print = text_print[-args.words_max_print:]
  324. os.system('clear')
  325. print("\rpid" + str(id) + ": " + text_print)
  326. # offline_msg_done=True
  327. except Exception as e:
  328. print("Exception:", e)
  329. #traceback.print_exc()
  330. #await websocket.close()
  331. import builtins
  332. async def ws_client(id, chunk_begin, chunk_size, proc_args=None):
  333. global args
  334. args = proc_args
  335. logger.info(f"sub process: {args}")
  336. if args.audio_in is None:
  337. chunk_begin=0
  338. chunk_size=1
  339. global websocket,voices,offline_msg_done
  340. for i in range(chunk_begin,chunk_begin+chunk_size):
  341. offline_msg_done=False
  342. voices = Queue()
  343. if args.ssl == 1:
  344. ssl_context = ssl.SSLContext()
  345. ssl_context.check_hostname = False
  346. ssl_context.verify_mode = ssl.CERT_NONE
  347. uri = "wss://{}:{}".format(args.host, args.port)
  348. else:
  349. uri = "ws://{}:{}".format(args.host, args.port)
  350. ssl_context = None
  351. print("connect to", uri)
  352. async with websockets.connect(uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context) as websocket:
  353. if args.audio_in is not None:
  354. print(f"i = {i}")
  355. task = asyncio.create_task(record_from_scp(i, 1))
  356. else:
  357. task = asyncio.create_task(record_microphone())
  358. task3 = asyncio.create_task(message(str(id)+"_"+str(i))) #processid+fileid
  359. await asyncio.gather(task, task3)
  360. os._exit(0)
  361. # builtins.exit(0)
  362. # python mtest/funasr_wss_client.py --host "10.0.0.32" --port 10095 --mode offline --output_dir "./results" --ssl 0 --audio_in "/home/user/program/modelscope-whisper/funasr-runtime-resources/models/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/example/asr_example.wav"
  363. def one_thread(id, chunk_begin, chunk_size, args):
  364. asyncio.get_event_loop().run_until_complete(ws_client(id, chunk_begin, chunk_size,args))
  365. asyncio.get_event_loop().run_forever()
  366. if __name__ == '__main__':
  367. import multiprocessing
  368. multiprocessing.freeze_support()
  369. # for microphone
  370. if args.audio_in is None:
  371. logger.info(f"args.audio_in is None")
  372. p = Process(target=one_thread, args=(0, 0, 0))
  373. p.start()
  374. p.join()
  375. print('end')
  376. else:
  377. # calculate the number of wavs for each preocess
  378. if args.audio_in.endswith(".scp"):
  379. f_scp = open(args.audio_in)
  380. wavs = f_scp.readlines()
  381. else:
  382. wavs = [args.audio_in]
  383. for wav in wavs:
  384. wav_splits = wav.strip().split()
  385. wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
  386. wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
  387. audio_type = os.path.splitext(wav_path)[-1].lower()
  388. logger.debug(f"wavs: {wavs}")
  389. total_len = len(wavs)
  390. if total_len >= args.thread_num:
  391. chunk_size = int(total_len / args.thread_num)
  392. remain_wavs = total_len - chunk_size * args.thread_num
  393. else:
  394. chunk_size = 1
  395. remain_wavs = 0
  396. process_list = []
  397. chunk_begin = 0
  398. for i in range(args.thread_num):
  399. now_chunk_size = chunk_size
  400. if remain_wavs > 0:
  401. now_chunk_size = chunk_size + 1
  402. remain_wavs = remain_wavs - 1
  403. # process i handle wavs at chunk_begin and size of now_chunk_size
  404. print("start:", i, chunk_begin, now_chunk_size, args)
  405. p = Process(target=one_thread, args=(i, chunk_begin, now_chunk_size, args))
  406. chunk_begin = chunk_begin + now_chunk_size
  407. p.start()
  408. process_list.append(p)
  409. for i in process_list:
  410. p.join()
  411. print('end')