FunasrWsClient.java 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. //
  2. // Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  3. // Reserved. MIT License (https://opensource.org/licenses/MIT)
  4. //
  5. /*
  6. * // 2022-2023 by zhaomingwork@qq.com
  7. */
  8. // java FunasrWsClient
  9. // usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
  10. // [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
  11. package websocket;
import java.io.*;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.*;
import java.util.Arrays;
import java.util.Map;
import net.sourceforge.argparse4j.ArgumentParsers;
import net.sourceforge.argparse4j.inf.ArgumentParser;
import net.sourceforge.argparse4j.inf.ArgumentParserException;
import net.sourceforge.argparse4j.inf.Namespace;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft;
import org.java_websocket.handshake.ServerHandshake;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
  29. /** This example demonstrates how to connect to websocket server. */
  30. public class FunasrWsClient extends WebSocketClient {
  31. public class RecWavThread extends Thread {
  32. private FunasrWsClient funasrClient;
  33. public RecWavThread(FunasrWsClient funasrClient) {
  34. this.funasrClient = funasrClient;
  35. }
  36. public void run() {
  37. this.funasrClient.recWav();
  38. }
  39. }
  40. private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
  41. public FunasrWsClient(URI serverUri, Draft draft) {
  42. super(serverUri, draft);
  43. }
  44. public FunasrWsClient(URI serverURI) {
  45. super(serverURI);
  46. }
  47. public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
  48. super(serverUri, httpHeaders);
  49. }
  50. public void getSslContext(String keyfile, String certfile) {
  51. // TODO
  52. return;
  53. }
  54. // send json at first time
  55. public void sendJson(
  56. String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
  57. try {
  58. JSONObject obj = new JSONObject();
  59. obj.put("mode", mode);
  60. JSONArray array = new JSONArray();
  61. String[] chunkList = strChunkSize.split(",");
  62. for (int i = 0; i < chunkList.length; i++) {
  63. array.add(Integer.valueOf(chunkList[i].trim()));
  64. }
  65. obj.put("chunk_size", array);
  66. obj.put("chunk_interval", new Integer(chunkInterval));
  67. obj.put("wav_name", wavName);
  68. if (isSpeaking) {
  69. obj.put("is_speaking", new Boolean(true));
  70. } else {
  71. obj.put("is_speaking", new Boolean(false));
  72. }
  73. logger.info("sendJson: " + obj);
  74. // return;
  75. send(obj.toString());
  76. return;
  77. } catch (Exception e) {
  78. e.printStackTrace();
  79. }
  80. }
  81. // send json at end of wav
  82. public void sendEof() {
  83. try {
  84. JSONObject obj = new JSONObject();
  85. obj.put("is_speaking", new Boolean(false));
  86. logger.info("sendEof: " + obj);
  87. // return;
  88. send(obj.toString());
  89. iseof = true;
  90. return;
  91. } catch (Exception e) {
  92. e.printStackTrace();
  93. }
  94. }
  95. // function for rec wav file
  96. public void recWav() {
  97. sendJson(mode, strChunkSize, chunkInterval, wavName, true);
  98. File file = new File(FunasrWsClient.wavPath);
  99. int chunkSize = sendChunkSize;
  100. byte[] bytes = new byte[chunkSize];
  101. int readSize = 0;
  102. try (FileInputStream fis = new FileInputStream(file)) {
  103. if (FunasrWsClient.wavPath.endsWith(".wav")) {
  104. fis.read(bytes, 0, 44); //skip first 44 wav header
  105. }
  106. readSize = fis.read(bytes, 0, chunkSize);
  107. while (readSize > 0) {
  108. // send when it is chunk size
  109. if (readSize == chunkSize) {
  110. send(bytes); // send buf to server
  111. } else {
  112. // send when at last or not is chunk size
  113. byte[] tmpBytes = new byte[readSize];
  114. for (int i = 0; i < readSize; i++) {
  115. tmpBytes[i] = bytes[i];
  116. }
  117. send(tmpBytes);
  118. }
  119. // if not in offline mode, we simulate online stream by sleep
  120. if (!mode.equals("offline")) {
  121. Thread.sleep(Integer.valueOf(chunkSize / 32));
  122. }
  123. readSize = fis.read(bytes, 0, chunkSize);
  124. }
  125. if (!mode.equals("offline")) {
  126. // if not offline, we send eof and wait for 3 seconds to close
  127. Thread.sleep(2000);
  128. sendEof();
  129. Thread.sleep(3000);
  130. close();
  131. } else {
  132. // if offline, just send eof
  133. sendEof();
  134. }
  135. } catch (Exception e) {
  136. e.printStackTrace();
  137. }
  138. }
  139. @Override
  140. public void onOpen(ServerHandshake handshakedata) {
  141. RecWavThread thread = new RecWavThread(this);
  142. thread.start();
  143. }
  144. @Override
  145. public void onMessage(String message) {
  146. JSONObject jsonObject = new JSONObject();
  147. JSONParser jsonParser = new JSONParser();
  148. logger.info("received: " + message);
  149. try {
  150. jsonObject = (JSONObject) jsonParser.parse(message);
  151. logger.info("text: " + jsonObject.get("text"));
  152. } catch (org.json.simple.parser.ParseException e) {
  153. e.printStackTrace();
  154. }
  155. if (iseof && mode.equals("offline")) {
  156. close();
  157. }
  158. }
  159. @Override
  160. public void onClose(int code, String reason, boolean remote) {
  161. logger.info(
  162. "Connection closed by "
  163. + (remote ? "remote peer" : "us")
  164. + " Code: "
  165. + code
  166. + " Reason: "
  167. + reason);
  168. }
  169. @Override
  170. public void onError(Exception ex) {
  171. logger.info("ex: " + ex);
  172. ex.printStackTrace();
  173. // if the error is fatal then onClose will be called additionally
  174. }
  175. private boolean iseof = false;
  176. public static String wavPath;
  177. static String mode = "online";
  178. static String strChunkSize = "5,10,5";
  179. static int chunkInterval = 10;
  180. static int sendChunkSize = 1920;
  181. String wavName = "javatest";
  182. public static void main(String[] args) throws URISyntaxException {
  183. ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
  184. parser
  185. .addArgument("--port")
  186. .help("Port on which to listen.")
  187. .setDefault("8889")
  188. .type(String.class)
  189. .required(false);
  190. parser
  191. .addArgument("--host")
  192. .help("the IP address of server.")
  193. .setDefault("127.0.0.1")
  194. .type(String.class)
  195. .required(false);
  196. parser
  197. .addArgument("--audio_in")
  198. .help("wav path for decoding.")
  199. .setDefault("asr_example.wav")
  200. .type(String.class)
  201. .required(false);
  202. parser
  203. .addArgument("--num_threads")
  204. .help("num of threads for test.")
  205. .setDefault(1)
  206. .type(Integer.class)
  207. .required(false);
  208. parser
  209. .addArgument("--chunk_size")
  210. .help("chunk size for asr.")
  211. .setDefault("5, 10, 5")
  212. .type(String.class)
  213. .required(false);
  214. parser
  215. .addArgument("--chunk_interval")
  216. .help("chunk for asr.")
  217. .setDefault(10)
  218. .type(Integer.class)
  219. .required(false);
  220. parser
  221. .addArgument("--mode")
  222. .help("mode for asr.")
  223. .setDefault("offline")
  224. .type(String.class)
  225. .required(false);
  226. String srvIp = "";
  227. String srvPort = "";
  228. String wavPath = "";
  229. int numThreads = 1;
  230. String chunk_size = "";
  231. int chunk_interval = 10;
  232. String strmode = "offline";
  233. try {
  234. Namespace ns = parser.parseArgs(args);
  235. srvIp = ns.get("host");
  236. srvPort = ns.get("port");
  237. wavPath = ns.get("audio_in");
  238. numThreads = ns.get("num_threads");
  239. chunk_size = ns.get("chunk_size");
  240. chunk_interval = ns.get("chunk_interval");
  241. strmode = ns.get("mode");
  242. System.out.println(srvPort);
  243. } catch (ArgumentParserException ex) {
  244. ex.getParser().handleError(ex);
  245. return;
  246. }
  247. FunasrWsClient.strChunkSize = chunk_size;
  248. FunasrWsClient.chunkInterval = chunk_interval;
  249. FunasrWsClient.wavPath = wavPath;
  250. FunasrWsClient.mode = strmode;
  251. System.out.println(
  252. "serIp="
  253. + srvIp
  254. + ",srvPort="
  255. + srvPort
  256. + ",wavPath="
  257. + wavPath
  258. + ",strChunkSize"
  259. + strChunkSize);
  260. class ClientThread implements Runnable {
  261. String srvIp;
  262. String srvPort;
  263. ClientThread(String srvIp, String srvPort, String wavPath) {
  264. this.srvIp = srvIp;
  265. this.srvPort = srvPort;
  266. }
  267. public void run() {
  268. try {
  269. int RATE = 16000;
  270. String[] chunkList = strChunkSize.split(",");
  271. int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
  272. int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
  273. int stride =
  274. Integer.valueOf(
  275. 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
  276. System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
  277. System.out.println("CHUNK:" + CHUNK);
  278. System.out.println("stride:" + String.valueOf(stride));
  279. FunasrWsClient.sendChunkSize = CHUNK * 2;
  280. String wsAddress = "ws://" + srvIp + ":" + srvPort;
  281. FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
  282. c.connect();
  283. System.out.println("wsAddress:" + wsAddress);
  284. } catch (Exception e) {
  285. e.printStackTrace();
  286. System.out.println("e:" + e);
  287. }
  288. }
  289. }
  290. for (int i = 0; i < numThreads; i++) {
  291. System.out.println("Thread1 is running...");
  292. Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
  293. t.start();
  294. }
  295. }
  296. }