FunasrWsClient.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. //
  2. // Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  3. // Reserved. MIT License (https://opensource.org/licenses/MIT)
  4. //
  5. /*
  6. * // 2022-2023 by zhaomingwork@qq.com
  7. */
  8. // java FunasrWsClient
  9. // usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
  10. // [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
  11. package websocket;
  12. import java.io.*;
  13. import java.net.URI;
  14. import java.net.URISyntaxException;
  15. import java.nio.*;
  16. import java.util.Map;
  17. import net.sourceforge.argparse4j.ArgumentParsers;
  18. import net.sourceforge.argparse4j.inf.ArgumentParser;
  19. import net.sourceforge.argparse4j.inf.ArgumentParserException;
  20. import net.sourceforge.argparse4j.inf.Namespace;
  21. import org.java_websocket.client.WebSocketClient;
  22. import org.java_websocket.drafts.Draft;
  23. import org.java_websocket.handshake.ServerHandshake;
  24. import org.json.simple.JSONArray;
  25. import org.json.simple.JSONObject;
  26. import org.json.simple.parser.JSONParser;
  27. import org.slf4j.Logger;
  28. import org.slf4j.LoggerFactory;
  29. /** This example demonstrates how to connect to websocket server. */
  30. public class FunasrWsClient extends WebSocketClient {
  31. public class RecWavThread extends Thread {
  32. private FunasrWsClient funasrClient;
  33. public RecWavThread(FunasrWsClient funasrClient) {
  34. this.funasrClient = funasrClient;
  35. }
  36. public void run() {
  37. this.funasrClient.recWav();
  38. }
  39. }
  40. private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
  41. public FunasrWsClient(URI serverUri, Draft draft) {
  42. super(serverUri, draft);
  43. }
  44. public FunasrWsClient(URI serverURI) {
  45. super(serverURI);
  46. }
  47. public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
  48. super(serverUri, httpHeaders);
  49. }
  50. public void getSslContext(String keyfile, String certfile) {
  51. // TODO
  52. return;
  53. }
  54. // send json at first time
  55. public void sendJson(
  56. String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
  57. try {
  58. JSONObject obj = new JSONObject();
  59. obj.put("mode", mode);
  60. JSONArray array = new JSONArray();
  61. String[] chunkList = strChunkSize.split(",");
  62. for (int i = 0; i < chunkList.length; i++) {
  63. array.add(Integer.valueOf(chunkList[i].trim()));
  64. }
  65. obj.put("chunk_size", array);
  66. obj.put("chunk_interval", new Integer(chunkInterval));
  67. obj.put("wav_name", wavName);
  68. if (isSpeaking) {
  69. obj.put("is_speaking", new Boolean(true));
  70. } else {
  71. obj.put("is_speaking", new Boolean(false));
  72. }
  73. logger.info("sendJson: " + obj);
  74. // return;
  75. send(obj.toString());
  76. return;
  77. } catch (Exception e) {
  78. e.printStackTrace();
  79. }
  80. }
  81. // send json at end of wav
  82. public void sendEof() {
  83. try {
  84. JSONObject obj = new JSONObject();
  85. obj.put("is_speaking", new Boolean(false));
  86. logger.info("sendEof: " + obj);
  87. // return;
  88. send(obj.toString());
  89. iseof = true;
  90. return;
  91. } catch (Exception e) {
  92. e.printStackTrace();
  93. }
  94. }
  95. // function for rec wav file
  96. public void recWav() {
  97. sendJson(mode, strChunkSize, chunkInterval, wavName, true);
  98. File file = new File(FunasrWsClient.wavPath);
  99. int chunkSize = sendChunkSize;
  100. byte[] bytes = new byte[chunkSize];
  101. int readSize = 0;
  102. try (FileInputStream fis = new FileInputStream(file)) {
  103. if (FunasrWsClient.wavPath.endsWith(".wav")) {
  104. fis.read(bytes, 0, 44); //skip first 44 wav header
  105. }
  106. readSize = fis.read(bytes, 0, chunkSize);
  107. while (readSize > 0) {
  108. // send when it is chunk size
  109. if (readSize == chunkSize) {
  110. send(bytes); // send buf to server
  111. } else {
  112. // send when at last or not is chunk size
  113. byte[] tmpBytes = new byte[readSize];
  114. for (int i = 0; i < readSize; i++) {
  115. tmpBytes[i] = bytes[i];
  116. }
  117. send(tmpBytes);
  118. }
  119. // if not in offline mode, we simulate online stream by sleep
  120. if (!mode.equals("offline")) {
  121. Thread.sleep(Integer.valueOf(chunkSize / 32));
  122. }
  123. readSize = fis.read(bytes, 0, chunkSize);
  124. }
  125. if (!mode.equals("offline")) {
  126. // if not offline, we send eof and wait for 3 seconds to close
  127. Thread.sleep(2000);
  128. sendEof();
  129. Thread.sleep(3000);
  130. close();
  131. } else {
  132. // if offline, just send eof
  133. sendEof();
  134. }
  135. } catch (Exception e) {
  136. e.printStackTrace();
  137. }
  138. }
  139. @Override
  140. public void onOpen(ServerHandshake handshakedata) {
  141. RecWavThread thread = new RecWavThread(this);
  142. thread.start();
  143. }
  144. @Override
  145. public void onMessage(String message) {
  146. JSONObject jsonObject = new JSONObject();
  147. JSONParser jsonParser = new JSONParser();
  148. logger.info("received: " + message);
  149. try {
  150. jsonObject = (JSONObject) jsonParser.parse(message);
  151. logger.info("text: " + jsonObject.get("text"));
  152. } catch (org.json.simple.parser.ParseException e) {
  153. e.printStackTrace();
  154. }
  155. if (iseof && mode.equals("offline") && !jsonObject.containsKey("is_final")) {
  156. close();
  157. }
  158. if (iseof && mode.equals("offline") && jsonObject.containsKey("is_final") && jsonObject.get("is_final").equals("false")) {
  159. close();
  160. }
  161. }
  162. @Override
  163. public void onClose(int code, String reason, boolean remote) {
  164. logger.info(
  165. "Connection closed by "
  166. + (remote ? "remote peer" : "us")
  167. + " Code: "
  168. + code
  169. + " Reason: "
  170. + reason);
  171. }
  172. @Override
  173. public void onError(Exception ex) {
  174. logger.info("ex: " + ex);
  175. ex.printStackTrace();
  176. // if the error is fatal then onClose will be called additionally
  177. }
  178. private boolean iseof = false;
  179. public static String wavPath;
  180. static String mode = "online";
  181. static String strChunkSize = "5,10,5";
  182. static int chunkInterval = 10;
  183. static int sendChunkSize = 1920;
  184. String wavName = "javatest";
  185. public static void main(String[] args) throws URISyntaxException {
  186. ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
  187. parser
  188. .addArgument("--port")
  189. .help("Port on which to listen.")
  190. .setDefault("8889")
  191. .type(String.class)
  192. .required(false);
  193. parser
  194. .addArgument("--host")
  195. .help("the IP address of server.")
  196. .setDefault("127.0.0.1")
  197. .type(String.class)
  198. .required(false);
  199. parser
  200. .addArgument("--audio_in")
  201. .help("wav path for decoding.")
  202. .setDefault("asr_example.wav")
  203. .type(String.class)
  204. .required(false);
  205. parser
  206. .addArgument("--num_threads")
  207. .help("num of threads for test.")
  208. .setDefault(1)
  209. .type(Integer.class)
  210. .required(false);
  211. parser
  212. .addArgument("--chunk_size")
  213. .help("chunk size for asr.")
  214. .setDefault("5, 10, 5")
  215. .type(String.class)
  216. .required(false);
  217. parser
  218. .addArgument("--chunk_interval")
  219. .help("chunk for asr.")
  220. .setDefault(10)
  221. .type(Integer.class)
  222. .required(false);
  223. parser
  224. .addArgument("--mode")
  225. .help("mode for asr.")
  226. .setDefault("offline")
  227. .type(String.class)
  228. .required(false);
  229. String srvIp = "";
  230. String srvPort = "";
  231. String wavPath = "";
  232. int numThreads = 1;
  233. String chunk_size = "";
  234. int chunk_interval = 10;
  235. String strmode = "offline";
  236. try {
  237. Namespace ns = parser.parseArgs(args);
  238. srvIp = ns.get("host");
  239. srvPort = ns.get("port");
  240. wavPath = ns.get("audio_in");
  241. numThreads = ns.get("num_threads");
  242. chunk_size = ns.get("chunk_size");
  243. chunk_interval = ns.get("chunk_interval");
  244. strmode = ns.get("mode");
  245. System.out.println(srvPort);
  246. } catch (ArgumentParserException ex) {
  247. ex.getParser().handleError(ex);
  248. return;
  249. }
  250. FunasrWsClient.strChunkSize = chunk_size;
  251. FunasrWsClient.chunkInterval = chunk_interval;
  252. FunasrWsClient.wavPath = wavPath;
  253. FunasrWsClient.mode = strmode;
  254. System.out.println(
  255. "serIp="
  256. + srvIp
  257. + ",srvPort="
  258. + srvPort
  259. + ",wavPath="
  260. + wavPath
  261. + ",strChunkSize"
  262. + strChunkSize);
  263. class ClientThread implements Runnable {
  264. String srvIp;
  265. String srvPort;
  266. ClientThread(String srvIp, String srvPort, String wavPath) {
  267. this.srvIp = srvIp;
  268. this.srvPort = srvPort;
  269. }
  270. public void run() {
  271. try {
  272. int RATE = 16000;
  273. String[] chunkList = strChunkSize.split(",");
  274. int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
  275. int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
  276. int stride =
  277. Integer.valueOf(
  278. 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
  279. System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
  280. System.out.println("CHUNK:" + CHUNK);
  281. System.out.println("stride:" + String.valueOf(stride));
  282. FunasrWsClient.sendChunkSize = CHUNK * 2;
  283. String wsAddress = "ws://" + srvIp + ":" + srvPort;
  284. FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
  285. c.connect();
  286. System.out.println("wsAddress:" + wsAddress);
  287. } catch (Exception e) {
  288. e.printStackTrace();
  289. System.out.println("e:" + e);
  290. }
  291. }
  292. }
  293. for (int i = 0; i < numThreads; i++) {
  294. System.out.println("Thread1 is running...");
  295. Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
  296. t.start();
  297. }
  298. }
  299. }