FunasrWsClient.java 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. //
  2. // Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  3. // Reserved. MIT License (https://opensource.org/licenses/MIT)
  4. //
  5. /*
  6. * // 2022-2023 by zhaomingwork@qq.com
  7. */
  8. // java FunasrWsClient
  9. // usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
  10. // [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE] [--hotwords HOTWORDS]
  11. package websocket;
  12. import java.io.*;
  13. import java.net.URI;
  14. import java.net.URISyntaxException;
  15. import java.nio.*;
  16. import java.util.Map;
  17. import net.sourceforge.argparse4j.ArgumentParsers;
  18. import net.sourceforge.argparse4j.inf.ArgumentParser;
  19. import net.sourceforge.argparse4j.inf.ArgumentParserException;
  20. import net.sourceforge.argparse4j.inf.Namespace;
  21. import org.java_websocket.client.WebSocketClient;
  22. import org.java_websocket.drafts.Draft;
  23. import org.java_websocket.handshake.ServerHandshake;
  24. import org.json.simple.JSONArray;
  25. import org.json.simple.JSONObject;
  26. import org.json.simple.parser.JSONParser;
  27. import org.slf4j.Logger;
  28. import org.slf4j.LoggerFactory;
  29. /** This example demonstrates how to connect to websocket server. */
  30. public class FunasrWsClient extends WebSocketClient {
  31. public class RecWavThread extends Thread {
  32. private FunasrWsClient funasrClient;
  33. public RecWavThread(FunasrWsClient funasrClient) {
  34. this.funasrClient = funasrClient;
  35. }
  36. public void run() {
  37. this.funasrClient.recWav();
  38. }
  39. }
  40. private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
  41. public FunasrWsClient(URI serverUri, Draft draft) {
  42. super(serverUri, draft);
  43. }
  44. public FunasrWsClient(URI serverURI) {
  45. super(serverURI);
  46. }
  47. public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
  48. super(serverUri, httpHeaders);
  49. }
  50. public void getSslContext(String keyfile, String certfile) {
  51. // TODO
  52. return;
  53. }
  54. // send json at first time
  55. public void sendJson(
  56. String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking,String suffix) {
  57. try {
  58. JSONObject obj = new JSONObject();
  59. obj.put("mode", mode);
  60. JSONArray array = new JSONArray();
  61. String[] chunkList = strChunkSize.split(",");
  62. for (int i = 0; i < chunkList.length; i++) {
  63. array.add(Integer.valueOf(chunkList[i].trim()));
  64. }
  65. obj.put("chunk_size", array);
  66. obj.put("chunk_interval", new Integer(chunkInterval));
  67. obj.put("wav_name", wavName);
  68. if(FunasrWsClient.hotwords.trim().length()>0)
  69. {
  70. obj.put("hotwords", FunasrWsClient.hotwords.trim());
  71. }
  72. if(suffix.equals("wav")){
  73. suffix="pcm";
  74. }
  75. obj.put("wav_format", suffix);
  76. if (isSpeaking) {
  77. obj.put("is_speaking", new Boolean(true));
  78. } else {
  79. obj.put("is_speaking", new Boolean(false));
  80. }
  81. logger.info("sendJson: " + obj);
  82. // return;
  83. send(obj.toString());
  84. return;
  85. } catch (Exception e) {
  86. e.printStackTrace();
  87. }
  88. }
  89. // send json at end of wav
  90. public void sendEof() {
  91. try {
  92. JSONObject obj = new JSONObject();
  93. obj.put("is_speaking", new Boolean(false));
  94. logger.info("sendEof: " + obj);
  95. // return;
  96. send(obj.toString());
  97. iseof = true;
  98. return;
  99. } catch (Exception e) {
  100. e.printStackTrace();
  101. }
  102. }
  103. // function for rec wav file
  104. public void recWav() {
  105. String fileName=FunasrWsClient.wavPath;
  106. String suffix=fileName.split("\\.")[fileName.split("\\.").length-1];
  107. sendJson(mode, strChunkSize, chunkInterval, wavName, true,suffix);
  108. File file = new File(FunasrWsClient.wavPath);
  109. int chunkSize = sendChunkSize;
  110. byte[] bytes = new byte[chunkSize];
  111. int readSize = 0;
  112. try (FileInputStream fis = new FileInputStream(file)) {
  113. if (FunasrWsClient.wavPath.endsWith(".wav")) {
  114. fis.read(bytes, 0, 44); //skip first 44 wav header
  115. }
  116. readSize = fis.read(bytes, 0, chunkSize);
  117. while (readSize > 0) {
  118. // send when it is chunk size
  119. if (readSize == chunkSize) {
  120. send(bytes); // send buf to server
  121. } else {
  122. // send when at last or not is chunk size
  123. byte[] tmpBytes = new byte[readSize];
  124. for (int i = 0; i < readSize; i++) {
  125. tmpBytes[i] = bytes[i];
  126. }
  127. send(tmpBytes);
  128. }
  129. // if not in offline mode, we simulate online stream by sleep
  130. if (!mode.equals("offline")) {
  131. Thread.sleep(Integer.valueOf(chunkSize / 32));
  132. }
  133. readSize = fis.read(bytes, 0, chunkSize);
  134. }
  135. if (!mode.equals("offline")) {
  136. // if not offline, we send eof and wait for 3 seconds to close
  137. Thread.sleep(2000);
  138. sendEof();
  139. Thread.sleep(3000);
  140. close();
  141. } else {
  142. // if offline, just send eof
  143. sendEof();
  144. }
  145. } catch (Exception e) {
  146. e.printStackTrace();
  147. }
  148. }
  149. @Override
  150. public void onOpen(ServerHandshake handshakedata) {
  151. RecWavThread thread = new RecWavThread(this);
  152. thread.start();
  153. }
  154. @Override
  155. public void onMessage(String message) {
  156. JSONObject jsonObject = new JSONObject();
  157. JSONParser jsonParser = new JSONParser();
  158. logger.info("received: " + message);
  159. try {
  160. jsonObject = (JSONObject) jsonParser.parse(message);
  161. logger.info("text: " + jsonObject.get("text"));
  162. if(jsonObject.containsKey("timestamp"))
  163. {
  164. logger.info("timestamp: " + jsonObject.get("timestamp"));
  165. }
  166. } catch (org.json.simple.parser.ParseException e) {
  167. e.printStackTrace();
  168. }
  169. if (iseof && mode.equals("offline") && !jsonObject.containsKey("is_final")) {
  170. close();
  171. }
  172. if (iseof && mode.equals("offline") && jsonObject.containsKey("is_final") && jsonObject.get("is_final").equals("false")) {
  173. close();
  174. }
  175. }
  176. @Override
  177. public void onClose(int code, String reason, boolean remote) {
  178. logger.info(
  179. "Connection closed by "
  180. + (remote ? "remote peer" : "us")
  181. + " Code: "
  182. + code
  183. + " Reason: "
  184. + reason);
  185. }
  186. @Override
  187. public void onError(Exception ex) {
  188. logger.info("ex: " + ex);
  189. ex.printStackTrace();
  190. // if the error is fatal then onClose will be called additionally
  191. }
  192. private boolean iseof = false;
  193. public static String wavPath;
  194. static String mode = "online";
  195. static String strChunkSize = "5,10,5";
  196. static int chunkInterval = 10;
  197. static int sendChunkSize = 1920;
  198. static String hotwords="";
  199. String wavName = "javatest";
  200. public static void main(String[] args) throws URISyntaxException {
  201. ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
  202. parser
  203. .addArgument("--port")
  204. .help("Port on which to listen.")
  205. .setDefault("8889")
  206. .type(String.class)
  207. .required(false);
  208. parser
  209. .addArgument("--host")
  210. .help("the IP address of server.")
  211. .setDefault("127.0.0.1")
  212. .type(String.class)
  213. .required(false);
  214. parser
  215. .addArgument("--audio_in")
  216. .help("wav path for decoding.")
  217. .setDefault("asr_example.wav")
  218. .type(String.class)
  219. .required(false);
  220. parser
  221. .addArgument("--num_threads")
  222. .help("num of threads for test.")
  223. .setDefault(1)
  224. .type(Integer.class)
  225. .required(false);
  226. parser
  227. .addArgument("--chunk_size")
  228. .help("chunk size for asr.")
  229. .setDefault("5, 10, 5")
  230. .type(String.class)
  231. .required(false);
  232. parser
  233. .addArgument("--chunk_interval")
  234. .help("chunk for asr.")
  235. .setDefault(10)
  236. .type(Integer.class)
  237. .required(false);
  238. parser
  239. .addArgument("--mode")
  240. .help("mode for asr.")
  241. .setDefault("offline")
  242. .type(String.class)
  243. .required(false);
  244. parser
  245. .addArgument("--hotwords")
  246. .help("hotwords, splited by space")
  247. .setDefault("")
  248. .type(String.class)
  249. .required(false);
  250. String srvIp = "";
  251. String srvPort = "";
  252. String wavPath = "";
  253. int numThreads = 1;
  254. String chunk_size = "";
  255. int chunk_interval = 10;
  256. String strmode = "offline";
  257. String hot="";
  258. try {
  259. Namespace ns = parser.parseArgs(args);
  260. srvIp = ns.get("host");
  261. srvPort = ns.get("port");
  262. wavPath = ns.get("audio_in");
  263. numThreads = ns.get("num_threads");
  264. chunk_size = ns.get("chunk_size");
  265. chunk_interval = ns.get("chunk_interval");
  266. strmode = ns.get("mode");
  267. hot=ns.get("hotwords");
  268. System.out.println(srvPort);
  269. } catch (ArgumentParserException ex) {
  270. ex.getParser().handleError(ex);
  271. return;
  272. }
  273. FunasrWsClient.strChunkSize = chunk_size;
  274. FunasrWsClient.chunkInterval = chunk_interval;
  275. FunasrWsClient.wavPath = wavPath;
  276. FunasrWsClient.mode = strmode;
  277. FunasrWsClient.hotwords=hot;
  278. System.out.println(
  279. "serIp="
  280. + srvIp
  281. + ",srvPort="
  282. + srvPort
  283. + ",wavPath="
  284. + wavPath
  285. + ",strChunkSize"
  286. + strChunkSize);
  287. class ClientThread implements Runnable {
  288. String srvIp;
  289. String srvPort;
  290. ClientThread(String srvIp, String srvPort, String wavPath) {
  291. this.srvIp = srvIp;
  292. this.srvPort = srvPort;
  293. }
  294. public void run() {
  295. try {
  296. int RATE = 16000;
  297. String[] chunkList = strChunkSize.split(",");
  298. int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
  299. int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
  300. int stride =
  301. Integer.valueOf(
  302. 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
  303. System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
  304. System.out.println("CHUNK:" + CHUNK);
  305. System.out.println("stride:" + String.valueOf(stride));
  306. FunasrWsClient.sendChunkSize = CHUNK * 2;
  307. String wsAddress = "ws://" + srvIp + ":" + srvPort;
  308. FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
  309. c.connect();
  310. System.out.println("wsAddress:" + wsAddress);
  311. } catch (Exception e) {
  312. e.printStackTrace();
  313. System.out.println("e:" + e);
  314. }
  315. }
  316. }
  317. for (int i = 0; i < numThreads; i++) {
  318. System.out.println("Thread1 is running...");
  319. Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
  320. t.start();
  321. }
  322. }
  323. }