FunasrWsClient.java 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. //
  2. // Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  3. // Reserved. MIT License (https://opensource.org/licenses/MIT)
  4. //
  5. /*
  6. * // 2022-2023 by zhaomingwork@qq.com
  7. */
  8. // java FunasrWsClient
  9. // usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
  10. // [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE] [--hotwords HOTWORDS]
  11. package websocket;
  12. import java.io.*;
  13. import java.net.URI;
  14. import java.net.URISyntaxException;
  15. import java.nio.*;
  16. import java.util.Map;
  17. import net.sourceforge.argparse4j.ArgumentParsers;
  18. import net.sourceforge.argparse4j.inf.ArgumentParser;
  19. import net.sourceforge.argparse4j.inf.ArgumentParserException;
  20. import net.sourceforge.argparse4j.inf.Namespace;
  21. import org.java_websocket.client.WebSocketClient;
  22. import org.java_websocket.drafts.Draft;
  23. import org.java_websocket.handshake.ServerHandshake;
  24. import org.json.simple.JSONArray;
  25. import org.json.simple.JSONObject;
  26. import org.json.simple.parser.JSONParser;
  27. import org.slf4j.Logger;
  28. import org.slf4j.LoggerFactory;
  29. import java.util.regex.Matcher;
  30. import java.util.regex.Pattern;
  31. /** This example demonstrates how to connect to websocket server. */
  32. public class FunasrWsClient extends WebSocketClient {
  33. public class RecWavThread extends Thread {
  34. private FunasrWsClient funasrClient;
  35. public RecWavThread(FunasrWsClient funasrClient) {
  36. this.funasrClient = funasrClient;
  37. }
  38. public void run() {
  39. this.funasrClient.recWav();
  40. }
  41. }
  42. private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
  43. public FunasrWsClient(URI serverUri, Draft draft) {
  44. super(serverUri, draft);
  45. }
  46. public FunasrWsClient(URI serverURI) {
  47. super(serverURI);
  48. }
  49. public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
  50. super(serverUri, httpHeaders);
  51. }
  52. public void getSslContext(String keyfile, String certfile) {
  53. // TODO
  54. return;
  55. }
  56. // send json at first time
  57. public void sendJson(
  58. String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking,String suffix) {
  59. try {
  60. JSONObject obj = new JSONObject();
  61. obj.put("mode", mode);
  62. JSONArray array = new JSONArray();
  63. String[] chunkList = strChunkSize.split(",");
  64. for (int i = 0; i < chunkList.length; i++) {
  65. array.add(Integer.valueOf(chunkList[i].trim()));
  66. }
  67. obj.put("chunk_size", array);
  68. obj.put("chunk_interval", new Integer(chunkInterval));
  69. obj.put("wav_name", wavName);
  70. if(FunasrWsClient.hotwords.trim().length()>0)
  71. {
  72. String regex = "\\d+";
  73. JSONObject jsonitems = new JSONObject();
  74. String[] items=FunasrWsClient.hotwords.trim().split(" ");
  75. Pattern pattern = Pattern.compile(regex);
  76. String tmpWords="";
  77. for(int i=0;i<items.length;i++)
  78. {
  79. Matcher matcher = pattern.matcher(items[i]);
  80. if (matcher.matches()) {
  81. jsonitems.put(tmpWords.trim(), items[i].trim());
  82. tmpWords="";
  83. continue;
  84. }
  85. tmpWords=tmpWords+items[i]+" ";
  86. }
  87. obj.put("hotwords", jsonitems.toString());
  88. }
  89. if(suffix.equals("wav")){
  90. suffix="pcm";
  91. }
  92. obj.put("wav_format", suffix);
  93. if (isSpeaking) {
  94. obj.put("is_speaking", new Boolean(true));
  95. } else {
  96. obj.put("is_speaking", new Boolean(false));
  97. }
  98. logger.info("sendJson: " + obj);
  99. // return;
  100. send(obj.toString());
  101. return;
  102. } catch (Exception e) {
  103. e.printStackTrace();
  104. }
  105. }
  106. // send json at end of wav
  107. public void sendEof() {
  108. try {
  109. JSONObject obj = new JSONObject();
  110. obj.put("is_speaking", new Boolean(false));
  111. logger.info("sendEof: " + obj);
  112. // return;
  113. send(obj.toString());
  114. iseof = true;
  115. return;
  116. } catch (Exception e) {
  117. e.printStackTrace();
  118. }
  119. }
  120. // function for rec wav file
  121. public void recWav() {
  122. String fileName=FunasrWsClient.wavPath;
  123. String suffix=fileName.split("\\.")[fileName.split("\\.").length-1];
  124. sendJson(mode, strChunkSize, chunkInterval, wavName, true,suffix);
  125. File file = new File(FunasrWsClient.wavPath);
  126. int chunkSize = sendChunkSize;
  127. byte[] bytes = new byte[chunkSize];
  128. int readSize = 0;
  129. try (FileInputStream fis = new FileInputStream(file)) {
  130. if (FunasrWsClient.wavPath.endsWith(".wav")) {
  131. fis.read(bytes, 0, 44); //skip first 44 wav header
  132. }
  133. readSize = fis.read(bytes, 0, chunkSize);
  134. while (readSize > 0) {
  135. // send when it is chunk size
  136. if (readSize == chunkSize) {
  137. send(bytes); // send buf to server
  138. } else {
  139. // send when at last or not is chunk size
  140. byte[] tmpBytes = new byte[readSize];
  141. for (int i = 0; i < readSize; i++) {
  142. tmpBytes[i] = bytes[i];
  143. }
  144. send(tmpBytes);
  145. }
  146. // if not in offline mode, we simulate online stream by sleep
  147. if (!mode.equals("offline")) {
  148. Thread.sleep(Integer.valueOf(chunkSize / 32));
  149. }
  150. readSize = fis.read(bytes, 0, chunkSize);
  151. }
  152. if (!mode.equals("offline")) {
  153. // if not offline, we send eof and wait for 3 seconds to close
  154. Thread.sleep(2000);
  155. sendEof();
  156. Thread.sleep(3000);
  157. close();
  158. } else {
  159. // if offline, just send eof
  160. sendEof();
  161. }
  162. } catch (Exception e) {
  163. e.printStackTrace();
  164. }
  165. }
  166. @Override
  167. public void onOpen(ServerHandshake handshakedata) {
  168. RecWavThread thread = new RecWavThread(this);
  169. thread.start();
  170. }
  171. @Override
  172. public void onMessage(String message) {
  173. JSONObject jsonObject = new JSONObject();
  174. JSONParser jsonParser = new JSONParser();
  175. logger.info("received: " + message);
  176. try {
  177. jsonObject = (JSONObject) jsonParser.parse(message);
  178. logger.info("text: " + jsonObject.get("text"));
  179. if(jsonObject.containsKey("timestamp"))
  180. {
  181. logger.info("timestamp: " + jsonObject.get("timestamp"));
  182. }
  183. } catch (org.json.simple.parser.ParseException e) {
  184. e.printStackTrace();
  185. }
  186. if (iseof && mode.equals("offline") && !jsonObject.containsKey("is_final")) {
  187. close();
  188. }
  189. if (iseof && mode.equals("offline") && jsonObject.containsKey("is_final") && jsonObject.get("is_final").equals("false")) {
  190. close();
  191. }
  192. }
  193. @Override
  194. public void onClose(int code, String reason, boolean remote) {
  195. logger.info(
  196. "Connection closed by "
  197. + (remote ? "remote peer" : "us")
  198. + " Code: "
  199. + code
  200. + " Reason: "
  201. + reason);
  202. }
  203. @Override
  204. public void onError(Exception ex) {
  205. logger.info("ex: " + ex);
  206. ex.printStackTrace();
  207. // if the error is fatal then onClose will be called additionally
  208. }
  209. private boolean iseof = false;
  210. public static String wavPath;
  211. static String mode = "online";
  212. static String strChunkSize = "5,10,5";
  213. static int chunkInterval = 10;
  214. static int sendChunkSize = 1920;
  215. static String hotwords="";
  216. static String fsthotwords="";
  217. String wavName = "javatest";
  218. public static void main(String[] args) throws URISyntaxException {
  219. ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
  220. parser
  221. .addArgument("--port")
  222. .help("Port on which to listen.")
  223. .setDefault("8889")
  224. .type(String.class)
  225. .required(false);
  226. parser
  227. .addArgument("--host")
  228. .help("the IP address of server.")
  229. .setDefault("127.0.0.1")
  230. .type(String.class)
  231. .required(false);
  232. parser
  233. .addArgument("--audio_in")
  234. .help("wav path for decoding.")
  235. .setDefault("asr_example.wav")
  236. .type(String.class)
  237. .required(false);
  238. parser
  239. .addArgument("--num_threads")
  240. .help("num of threads for test.")
  241. .setDefault(1)
  242. .type(Integer.class)
  243. .required(false);
  244. parser
  245. .addArgument("--chunk_size")
  246. .help("chunk size for asr.")
  247. .setDefault("5, 10, 5")
  248. .type(String.class)
  249. .required(false);
  250. parser
  251. .addArgument("--chunk_interval")
  252. .help("chunk for asr.")
  253. .setDefault(10)
  254. .type(Integer.class)
  255. .required(false);
  256. parser
  257. .addArgument("--mode")
  258. .help("mode for asr.")
  259. .setDefault("offline")
  260. .type(String.class)
  261. .required(false);
  262. parser
  263. .addArgument("--hotwords")
  264. .help("hotwords, splited by space, hello 30 nihao 40")
  265. .setDefault("")
  266. .type(String.class)
  267. .required(false);
  268. String srvIp = "";
  269. String srvPort = "";
  270. String wavPath = "";
  271. int numThreads = 1;
  272. String chunk_size = "";
  273. int chunk_interval = 10;
  274. String strmode = "offline";
  275. String hot="";
  276. try {
  277. Namespace ns = parser.parseArgs(args);
  278. srvIp = ns.get("host");
  279. srvPort = ns.get("port");
  280. wavPath = ns.get("audio_in");
  281. numThreads = ns.get("num_threads");
  282. chunk_size = ns.get("chunk_size");
  283. chunk_interval = ns.get("chunk_interval");
  284. strmode = ns.get("mode");
  285. hot=ns.get("hotwords");
  286. System.out.println(srvPort);
  287. } catch (ArgumentParserException ex) {
  288. ex.getParser().handleError(ex);
  289. return;
  290. }
  291. FunasrWsClient.strChunkSize = chunk_size;
  292. FunasrWsClient.chunkInterval = chunk_interval;
  293. FunasrWsClient.wavPath = wavPath;
  294. FunasrWsClient.mode = strmode;
  295. FunasrWsClient.hotwords=hot;
  296. System.out.println(
  297. "serIp="
  298. + srvIp
  299. + ",srvPort="
  300. + srvPort
  301. + ",wavPath="
  302. + wavPath
  303. + ",strChunkSize"
  304. + strChunkSize);
  305. class ClientThread implements Runnable {
  306. String srvIp;
  307. String srvPort;
  308. ClientThread(String srvIp, String srvPort, String wavPath) {
  309. this.srvIp = srvIp;
  310. this.srvPort = srvPort;
  311. }
  312. public void run() {
  313. try {
  314. int RATE = 16000;
  315. String[] chunkList = strChunkSize.split(",");
  316. int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
  317. int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
  318. int stride =
  319. Integer.valueOf(
  320. 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
  321. System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
  322. System.out.println("CHUNK:" + CHUNK);
  323. System.out.println("stride:" + String.valueOf(stride));
  324. FunasrWsClient.sendChunkSize = CHUNK * 2;
  325. String wsAddress = "ws://" + srvIp + ":" + srvPort;
  326. FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
  327. c.connect();
  328. System.out.println("wsAddress:" + wsAddress);
  329. } catch (Exception e) {
  330. e.printStackTrace();
  331. System.out.println("e:" + e);
  332. }
  333. }
  334. }
  335. for (int i = 0; i < numThreads; i++) {
  336. System.out.println("Thread1 is running...");
  337. Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
  338. t.start();
  339. }
  340. }
  341. }