util.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077
  1. #include "precomp.h"
  2. namespace funasr {
  3. float *LoadParams(const char *filename)
  4. {
  5. FILE *fp;
  6. fp = fopen(filename, "rb");
  7. fseek(fp, 0, SEEK_END);
  8. uint32_t nFileLen = ftell(fp);
  9. fseek(fp, 0, SEEK_SET);
  10. float *params_addr = (float *)AlignedMalloc(32, nFileLen);
  11. int n = fread(params_addr, 1, nFileLen, fp);
  12. fclose(fp);
  13. return params_addr;
  14. }
  // Rounds `val` up to the nearest multiple of `align`, i.e. returns
  // ceil(val / align) * align.
  //
  // Uses integer arithmetic only: the previous float round-trip lost precision
  // for values above 2^24 and could return a result smaller than `val`.
  // `align == 0` has no meaningful answer, so `val` is returned unchanged.
  int ValAlign(int val, int align)
  {
      if (align == 0) {
          return val;
      }
      int quotient = val / align;          // truncates toward zero
      const int remainder = val % align;
      // Truncation equals the ceiling unless the exact quotient is positive and
      // fractional, which is the case when remainder and divisor share a sign.
      if (remainder != 0 && ((remainder > 0) == (align > 0))) {
          ++quotient;
      }
      return quotient * align;
  }
  // Prints the first `size` floats of `din` to stdout with "%f ", followed by
  // a single newline.
  void DispParams(float *din, int size)
  {
      int pos = 0;
      while (pos < size) {
          printf("%f ", din[pos]);
          ++pos;
      }
      printf("\n");
  }
  28. void SaveDataFile(const char *filename, void *data, uint32_t len)
  29. {
  30. FILE *fp;
  31. fp = fopen(filename, "wb+");
  32. fwrite(data, 1, len, fp);
  33. fclose(fp);
  34. }
  35. void BasicNorm(Tensor<float> *&din, float norm)
  36. {
  37. int Tmax = din->size[2];
  38. int i, j;
  39. for (i = 0; i < Tmax; i++) {
  40. float sum = 0;
  41. for (j = 0; j < 512; j++) {
  42. int ii = i * 512 + j;
  43. sum += din->buff[ii] * din->buff[ii];
  44. }
  45. float mean = sqrt(sum / 512 + norm);
  46. for (j = 0; j < 512; j++) {
  47. int ii = i * 512 + j;
  48. din->buff[ii] = din->buff[ii] / mean;
  49. }
  50. }
  51. }
  // Finds the largest of the first `len` values of `din`.
  //
  // On return `max_val` holds that value and `max_idx` its index; on ties the
  // earliest index wins. When no element compares greater than -infinity
  // (e.g. len <= 0) the outputs are -INFINITY and -1.
  void FindMax(float *din, int len, float &max_val, int &max_idx)
  {
      max_idx = -1;
      max_val = -INFINITY;
      int pos = 0;
      while (pos < len) {
          if (din[pos] > max_val) {
              max_idx = pos;
              max_val = din[pos];
          }
          ++pos;
      }
  }
  // Joins two path fragments with the platform separator ('\\' when _WIN32 is
  // defined, '/' otherwise).
  //
  // A separator is inserted only when `p1` does not already end with one.
  // An empty `p1` yields `p2` unchanged; the original indexed p1[length()-1]
  // in that case, which is out of bounds.
  std::string PathAppend(const std::string &p1, const std::string &p2)
  {
  #ifdef _WIN32
      const char sep = '\\';
  #else
      const char sep = '/';
  #endif
      if (p1.empty()) {
          return p2;
      }
      if (p1[p1.length() - 1] == sep) {
          return p1 + p2;
      }
      return p1 + sep + p2;
  }
  77. void Relu(Tensor<float> *din)
  78. {
  79. int i;
  80. for (i = 0; i < din->buff_size; i++) {
  81. float val = din->buff[i];
  82. din->buff[i] = val < 0 ? 0 : val;
  83. }
  84. }
  85. void Swish(Tensor<float> *din)
  86. {
  87. int i;
  88. for (i = 0; i < din->buff_size; i++) {
  89. float val = din->buff[i];
  90. din->buff[i] = val / (1 + exp(-val));
  91. }
  92. }
  93. void Sigmoid(Tensor<float> *din)
  94. {
  95. int i;
  96. for (i = 0; i < din->buff_size; i++) {
  97. float val = din->buff[i];
  98. din->buff[i] = 1 / (1 + exp(-val));
  99. }
  100. }
  101. void DoubleSwish(Tensor<float> *din)
  102. {
  103. int i;
  104. for (i = 0; i < din->buff_size; i++) {
  105. float val = din->buff[i];
  106. din->buff[i] = val / (1 + exp(-val + 1));
  107. }
  108. }
  // Masked softmax, computed in place.
  //
  // The first `mask` entries of `din` are replaced by their softmax (the
  // maximum is subtracted before exponentiation); entries from `mask` up to
  // `len` are set to 0.
  //
  // Works directly in `din`, so there is no temporary buffer whose allocation
  // could fail. A negative `mask` is treated as 0 so the zero-fill never
  // writes before the start of the buffer.
  void Softmax(float *din, int mask, int len)
  {
      const int valid = mask < 0 ? 0 : mask;

      float max = -INFINITY;
      for (int i = 0; i < valid; i++) {
          max = max < din[i] ? din[i] : max;
      }

      float sum = 0;
      for (int i = 0; i < valid; i++) {
          din[i] = exp(din[i] - max);
          sum += din[i];
      }
      for (int i = 0; i < valid; i++) {
          din[i] = din[i] / sum;
      }

      // Everything past the mask carries no probability mass.
      for (int i = valid; i < len; i++) {
          din[i] = 0;
      }
  }
  // Replaces the first `len` entries of `din` with their log-softmax.
  //
  // Evaluated as x - max - log(sum(exp(x - max))): subtracting the maximum
  // keeps exp() from overflowing, which previously turned large inputs into
  // inf/inf = NaN. No temporary buffer is allocated. Does nothing when `din`
  // is null or `len` is not positive.
  void LogSoftmax(float *din, int len)
  {
      if (din == nullptr || len <= 0) {
          return;
      }

      float peak = din[0];
      for (int i = 1; i < len; i++) {
          peak = peak < din[i] ? din[i] : peak;
      }

      float sum = 0;
      for (int i = 0; i < len; i++) {
          sum += exp(din[i] - peak);
      }

      const float log_sum = log(sum);
      for (int i = 0; i < len; i++) {
          din[i] = din[i] - peak - log_sum;
      }
  }
  144. void Glu(Tensor<float> *din, Tensor<float> *dout)
  145. {
  146. int mm = din->buff_size / 1024;
  147. int i, j;
  148. for (i = 0; i < mm; i++) {
  149. for (j = 0; j < 512; j++) {
  150. int in_off = i * 1024 + j;
  151. int out_off = i * 512 + j;
  152. float a = din->buff[in_off];
  153. float b = din->buff[in_off + 512];
  154. dout->buff[out_off] = a / (1 + exp(-b));
  155. }
  156. }
  157. }
  // Returns true when the text after the last '.' in `filename` equals
  // `target` exactly; returns false when `filename` contains no '.'.
  bool is_target_file(const std::string& filename, const std::string target) {
      const std::size_t dot = filename.rfind('.');
      if (dot == std::string::npos) {
          return false;
      }
      // Compare the suffix in place instead of materialising it.
      return filename.compare(dot + 1, std::string::npos, target) == 0;
  }
  166. void KeepChineseCharacterAndSplit(const std::string &input_str,
  167. std::vector<std::string> &chinese_characters) {
  168. chinese_characters.resize(0);
  169. std::vector<U16CHAR_T> u16_buf;
  170. u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
  171. U16CHAR_T* pu16 = u16_buf.data();
  172. U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
  173. size_t ilen = input_str.size();
  174. size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
  175. for (size_t i = 0; i < len; i++) {
  176. if (EncodeConverter::IsChineseCharacter(pu16[i])) {
  177. U8CHAR_T u8buf[4];
  178. size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
  179. u8buf[n] = '\0';
  180. chinese_characters.push_back((const char*)u8buf);
  181. }
  182. }
  183. }
  184. void SplitChiEngCharacters(const std::string &input_str,
  185. std::vector<std::string> &characters) {
  186. characters.resize(0);
  187. std::string eng_word = "";
  188. U16CHAR_T space = 0x0020;
  189. std::vector<U16CHAR_T> u16_buf;
  190. u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
  191. U16CHAR_T* pu16 = u16_buf.data();
  192. U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
  193. size_t ilen = input_str.size();
  194. size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
  195. for (size_t i = 0; i < len; i++) {
  196. if (EncodeConverter::IsChineseCharacter(pu16[i])) {
  197. if(!eng_word.empty()){
  198. characters.push_back(eng_word);
  199. eng_word = "";
  200. }
  201. U8CHAR_T u8buf[4];
  202. size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
  203. u8buf[n] = '\0';
  204. characters.push_back((const char*)u8buf);
  205. } else if (pu16[i] == space){
  206. if(!eng_word.empty()){
  207. characters.push_back(eng_word);
  208. eng_word = "";
  209. }
  210. }else{
  211. U8CHAR_T u8buf[4];
  212. size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
  213. u8buf[n] = '\0';
  214. eng_word += (const char*)u8buf;
  215. }
  216. }
  217. if(!eng_word.empty()){
  218. characters.push_back(eng_word);
  219. eng_word = "";
  220. }
  221. }
  222. // Timestamp Smooth
  223. void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word){
  224. if(!TimestampIsPunctuation(str_word)){
  225. alignment_str1.push_front(str_word);
  226. }
  227. }
  // Returns true when every UTF-8 character of `str` is one of the sentence
  // punctuation marks listed in `punctuation`. An empty string returns true,
  // as it did before.
  //
  // The check is done per UTF-8 character, not per byte: the previous
  // byte-wise test accepted any character assembled from bytes that occur in
  // the punctuation set (for example E3 81 82, which shares its bytes with
  // the ideographic marks). Malformed or truncated UTF-8 returns false.
  bool TimestampIsPunctuation(const std::string& str) {
      const std::string punctuation = u8",。?、,?";
      // const std::string punctuation = u8",。?、,.?";
      const size_t total = str.size();
      size_t pos = 0;
      while (pos < total) {
          const unsigned char lead = static_cast<unsigned char>(str[pos]);
          size_t width = 1;
          if (lead >= 0xF0) {
              width = 4;
          } else if (lead >= 0xE0) {
              width = 3;
          } else if (lead >= 0xC0) {
              width = 2;
          } else if (lead >= 0x80) {
              return false;  // continuation byte with no lead byte
          }
          if (width > total - pos) {
              return false;  // sequence cut off at the end of the string
          }
          for (size_t k = 1; k < width; ++k) {
              if ((static_cast<unsigned char>(str[pos + k]) & 0xC0) != 0x80) {
                  return false;
              }
          }
          // A complete UTF-8 sequence can only match on a character boundary,
          // so a substring search is an exact per-character lookup.
          if (punctuation.find(str.data() + pos, 0, width) == std::string::npos) {
              return false;
          }
          pos += width;
      }
      return true;
  }
  238. vector<vector<int>> ParseTimestamps(const std::string& str) {
  239. vector<vector<int>> timestamps;
  240. std::istringstream ss(str);
  241. std::string segment;
  242. // skip first'['
  243. ss.ignore(1);
  244. while (std::getline(ss, segment, ']')) {
  245. std::istringstream segmentStream(segment);
  246. std::string number;
  247. vector<int> ts;
  248. // skip'['
  249. segmentStream.ignore(1);
  250. while (std::getline(segmentStream, number, ',')) {
  251. ts.push_back(std::stoi(number));
  252. }
  253. if(ts.size() != 2){
  254. LOG(ERROR) << "ParseTimestamps Failed";
  255. timestamps.clear();
  256. return timestamps;
  257. }
  258. timestamps.push_back(ts);
  259. ss.ignore(1);
  260. }
  261. return timestamps;
  262. }
  263. bool TimestampIsDigit(U16CHAR_T &u16) {
  264. return u16 >= L'0' && u16 <= L'9';
  265. }
  266. bool TimestampIsAlpha(U16CHAR_T &u16) {
  267. return (u16 >= L'A' && u16 <= L'Z') || (u16 >= L'a' && u16 <= L'z');
  268. }
  269. bool TimestampIsPunctuation(U16CHAR_T &u16) {
  270. // (& ' -) in the dict
  271. if (u16 == 0x26 || u16 == 0x27 || u16 == 0x2D){
  272. return false;
  273. }
  274. return (u16 >= 0x21 && u16 <= 0x2F) // 标准ASCII标点
  275. || (u16 >= 0x3A && u16 <= 0x40) // 标准ASCII标点
  276. || (u16 >= 0x5B && u16 <= 0x60) // 标准ASCII标点
  277. || (u16 >= 0x7B && u16 <= 0x7E) // 标准ASCII标点
  278. || (u16 >= 0x2000 && u16 <= 0x206F) // 常用的Unicode标点
  279. || (u16 >= 0x3000 && u16 <= 0x303F); // CJK符号和标点
  280. }
  281. void TimestampSplitChiEngCharacters(const std::string &input_str,
  282. std::vector<std::string> &characters) {
  283. characters.resize(0);
  284. std::string eng_word = "";
  285. U16CHAR_T space = 0x0020;
  286. std::vector<U16CHAR_T> u16_buf;
  287. u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
  288. U16CHAR_T* pu16 = u16_buf.data();
  289. U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
  290. size_t ilen = input_str.size();
  291. size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
  292. for (size_t i = 0; i < len; i++) {
  293. if (EncodeConverter::IsChineseCharacter(pu16[i])) {
  294. if(!eng_word.empty()){
  295. characters.push_back(eng_word);
  296. eng_word = "";
  297. }
  298. U8CHAR_T u8buf[4];
  299. size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
  300. u8buf[n] = '\0';
  301. characters.push_back((const char*)u8buf);
  302. } else if (TimestampIsDigit(pu16[i]) || TimestampIsPunctuation(pu16[i])){
  303. if(!eng_word.empty()){
  304. characters.push_back(eng_word);
  305. eng_word = "";
  306. }
  307. U8CHAR_T u8buf[4];
  308. size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
  309. u8buf[n] = '\0';
  310. characters.push_back((const char*)u8buf);
  311. } else if (pu16[i] == space){
  312. if(!eng_word.empty()){
  313. characters.push_back(eng_word);
  314. eng_word = "";
  315. }
  316. }else{
  317. U8CHAR_T u8buf[4];
  318. size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
  319. u8buf[n] = '\0';
  320. eng_word += (const char*)u8buf;
  321. }
  322. }
  323. if(!eng_word.empty()){
  324. characters.push_back(eng_word);
  325. eng_word = "";
  326. }
  327. }
  // Serialises a list of integer lists as "[[a,b],[c,d]]" with no spaces.
  //
  // For an empty outer list the result is "" when `out_empty` is true and
  // "[]" otherwise.
  std::string VectorToString(const std::vector<std::vector<int>>& vec, bool out_empty) {
      if (vec.empty()) {
          return out_empty ? "" : "[]";
      }

      std::ostringstream out;
      out << "[";
      bool first_row = true;
      for (const std::vector<int> &row : vec) {
          // Separator goes in front of every row except the first.
          if (!first_row) {
              out << ",";
          }
          first_row = false;

          out << "[";
          for (size_t col = 0; col < row.size(); ++col) {
              if (col != 0) {
                  out << ",";
              }
              out << row[col];
          }
          out << "]";
      }
      out << "]";
      return out.str();
  }
  354. std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time){
  355. vector<vector<int>> timestamps_out;
  356. std::string timestamps_str = "";
  357. // process string to vector<string>
  358. std::vector<std::string> characters;
  359. funasr::TimestampSplitChiEngCharacters(text, characters);
  360. std::vector<std::string> characters_itn;
  361. funasr::TimestampSplitChiEngCharacters(text_itn, characters_itn);
  362. //convert string to vector<vector<int>>
  363. vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
  364. if (timestamps.size() == 0){
  365. LOG(ERROR) << "Timestamp Smooth Failed: Length of timestamp is zero";
  366. return timestamps_str;
  367. }
  368. // edit distance
  369. int m = characters.size();
  370. int n = characters_itn.size();
  371. std::vector<std::vector<int>> dp(m + 1, std::vector<int>(n + 1, 0));
  372. // init
  373. for (int i = 0; i <= m; ++i) {
  374. dp[i][0] = i;
  375. }
  376. for (int j = 0; j <= n; ++j) {
  377. dp[0][j] = j;
  378. }
  379. // dp
  380. for (int i = 1; i <= m; ++i) {
  381. for (int j = 1; j <= n; ++j) {
  382. if (characters[i - 1] == characters_itn[j - 1]) {
  383. dp[i][j] = dp[i - 1][j - 1];
  384. } else {
  385. dp[i][j] = std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}) + 1;
  386. }
  387. }
  388. }
  389. // backtrack
  390. std::deque<string> alignment_str1, alignment_str2;
  391. int i = m, j = n;
  392. while (i > 0 || j > 0) {
  393. if (i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1]) {
  394. funasr::TimestampAdd(alignment_str1, characters[i - 1]);
  395. funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
  396. i -= 1;
  397. j -= 1;
  398. } else if (i > 0 && dp[i][j] == dp[i - 1][j] + 1) {
  399. funasr::TimestampAdd(alignment_str1, characters[i - 1]);
  400. alignment_str2.push_front("");
  401. i -= 1;
  402. } else if (j > 0 && dp[i][j] == dp[i][j - 1] + 1) {
  403. alignment_str1.push_front("");
  404. funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
  405. j -= 1;
  406. } else{
  407. funasr::TimestampAdd(alignment_str1, characters[i - 1]);
  408. funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
  409. i -= 1;
  410. j -= 1;
  411. }
  412. }
  413. // smooth
  414. int itn_count = 0;
  415. int idx_tp = 0;
  416. int idx_itn = 0;
  417. vector<vector<int>> timestamps_tmp;
  418. for(int index = 0; index < alignment_str1.size(); index++){
  419. if (alignment_str1[index] == alignment_str2[index]){
  420. bool subsidy = false;
  421. if (itn_count > 0 && timestamps_tmp.size() == 0){
  422. if(idx_tp >= timestamps.size()){
  423. LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
  424. return timestamps_str;
  425. }
  426. timestamps_tmp.push_back(timestamps[idx_tp]);
  427. subsidy = true;
  428. itn_count++;
  429. }
  430. if (timestamps_tmp.size() > 0){
  431. if (itn_count > 0){
  432. int begin = timestamps_tmp[0][0];
  433. int end = timestamps_tmp.back()[1];
  434. int total_time = end - begin;
  435. int interval = total_time / itn_count;
  436. for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
  437. vector<int> ts;
  438. ts.push_back(begin + interval*idx_cnt);
  439. if(idx_cnt == itn_count-1){
  440. ts.push_back(end);
  441. }else {
  442. ts.push_back(begin + interval*(idx_cnt + 1));
  443. }
  444. timestamps_out.push_back(ts);
  445. }
  446. }
  447. timestamps_tmp.clear();
  448. }
  449. if(!subsidy){
  450. if(idx_tp >= timestamps.size()){
  451. LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
  452. return timestamps_str;
  453. }
  454. timestamps_out.push_back(timestamps[idx_tp]);
  455. }
  456. idx_tp++;
  457. itn_count = 0;
  458. }else{
  459. if (!alignment_str1[index].empty()){
  460. if(idx_tp >= timestamps.size()){
  461. LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
  462. return timestamps_str;
  463. }
  464. timestamps_tmp.push_back(timestamps[idx_tp]);
  465. idx_tp++;
  466. }
  467. if (!alignment_str2[index].empty()){
  468. itn_count++;
  469. }
  470. }
  471. // count length of itn
  472. if (!alignment_str2[index].empty()){
  473. idx_itn++;
  474. }
  475. }
  476. {
  477. if (itn_count > 0 && timestamps_tmp.size() == 0){
  478. if (timestamps_out.size() > 0){
  479. timestamps_tmp.push_back(timestamps_out.back());
  480. itn_count++;
  481. timestamps_out.pop_back();
  482. } else{
  483. LOG(ERROR) << "Timestamp Smooth Failed: Last itn has no timestamp.";
  484. return timestamps_str;
  485. }
  486. }
  487. if (timestamps_tmp.size() > 0){
  488. if (itn_count > 0){
  489. int begin = timestamps_tmp[0][0];
  490. int end = timestamps_tmp.back()[1];
  491. int total_time = end - begin;
  492. int interval = total_time / itn_count;
  493. for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
  494. vector<int> ts;
  495. ts.push_back(begin + interval*idx_cnt);
  496. if(idx_cnt == itn_count-1){
  497. ts.push_back(end);
  498. }else {
  499. ts.push_back(begin + interval*(idx_cnt + 1));
  500. }
  501. timestamps_out.push_back(ts);
  502. }
  503. }
  504. timestamps_tmp.clear();
  505. }
  506. }
  507. if(timestamps_out.size() != idx_itn){
  508. LOG(ERROR) << "Timestamp Smooth Failed: Timestamp length does not matched.";
  509. return timestamps_str;
  510. }
  511. timestamps_str = VectorToString(timestamps_out);
  512. return timestamps_str;
  513. }
  514. std::string TimestampSentence(std::string &text, std::string &str_time){
  515. std::vector<std::string> characters;
  516. funasr::TimestampSplitChiEngCharacters(text, characters);
  517. vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
  518. int idx_str = 0, idx_ts = 0;
  519. int start = -1, end = -1;
  520. std::string text_seg = "";
  521. std::string ts_sentences = "";
  522. std::string ts_sent = "";
  523. vector<vector<int>> ts_seg;
  524. while(idx_str < characters.size()){
  525. if (TimestampIsPunctuation(characters[idx_str])){
  526. if(ts_seg.size() >0){
  527. if (ts_seg[0].size() == 2){
  528. start = ts_seg[0][0];
  529. }
  530. if (ts_seg[ts_seg.size()-1].size() == 2){
  531. end = ts_seg[ts_seg.size()-1][1];
  532. }
  533. }
  534. // format
  535. ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
  536. ts_sent += "\"punc\":\"" + characters[idx_str] + "\",";
  537. ts_sent += "\"start\":" + to_string(start) + ",";
  538. ts_sent += "\"end\":" + to_string(end) + ",";
  539. ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
  540. if (idx_str == characters.size()-1){
  541. ts_sentences += ts_sent;
  542. } else{
  543. ts_sentences += ts_sent + ",";
  544. }
  545. // clear
  546. text_seg = "";
  547. ts_sent = "";
  548. start = 0;
  549. end = 0;
  550. ts_seg.clear();
  551. } else if(idx_ts < timestamps.size()) {
  552. if (text_seg.empty()){
  553. text_seg = characters[idx_str];
  554. }else{
  555. text_seg += " " + characters[idx_str];
  556. }
  557. ts_seg.push_back(timestamps[idx_ts]);
  558. idx_ts++;
  559. }
  560. idx_str++;
  561. }
  562. // for none punc results
  563. if(ts_seg.size() >0){
  564. if (ts_seg[0].size() == 2){
  565. start = ts_seg[0][0];
  566. }
  567. if (ts_seg[ts_seg.size()-1].size() == 2){
  568. end = ts_seg[ts_seg.size()-1][1];
  569. }
  570. // format
  571. ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
  572. ts_sent += "\"punc\":\"\",";
  573. ts_sent += "\"start\":" + to_string(start) + ",";
  574. ts_sent += "\"end\":" + to_string(end) + ",";
  575. ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
  576. ts_sentences += ts_sent;
  577. }
  578. return "[" +ts_sentences + "]";
  579. }
  // Splits `s` at every occurrence of `delim`.
  //
  // Empty fields between or before delimiters are kept; an empty input, or
  // the empty field after a trailing delimiter, produces no element.
  std::vector<std::string> split(const std::string &s, char delim) {
      std::vector<std::string> pieces;
      std::string::size_type from = 0;
      while (from < s.size()) {
          const std::string::size_type hit = s.find(delim, from);
          if (hit == std::string::npos) {
              pieces.push_back(s.substr(from));
              break;
          }
          pieces.push_back(s.substr(from, hit - from));
          from = hit + 1;
      }
      return pieces;
  }
  // Prints `name` followed by ':' on its own line, then each row of `mat` on
  // one line with every element followed by a space. Output goes to std::cout.
  template<typename T>
  void PrintMat(const std::vector<std::vector<T>> &mat, const std::string &name) {
      std::cout << name << ":" << std::endl;
      for (size_t row = 0; row < mat.size(); ++row) {
          for (size_t col = 0; col < mat[row].size(); ++col) {
              std::cout << mat[row][col] << " ";
          }
          std::cout << std::endl;
      }
  }
  // Splits `input` into UTF-8 characters and appends each one to `output`.
  //
  // The length of a character is taken from its lead byte (1 to 6 bytes); a
  // byte below 0xC0 counts as a single-byte character. `output` is not
  // cleared first, and the return value is its total size afterwards.
  //
  // A sequence cut off at the end of the input is emitted with the bytes that
  // are available. Previously the index jumped past length(), the `!=` loop
  // test stayed true and substr() threw std::out_of_range.
  size_t Utf8ToCharset(const std::string &input, std::vector<std::string> &output) {
      const size_t total = input.length();
      size_t i = 0;
      while (i < total) {
          const unsigned char byte = static_cast<unsigned char>(input[i]);
          size_t len = 1;
          if (byte >= 0xFC) {  // length 6
              len = 6;
          } else if (byte >= 0xF8) {
              len = 5;
          } else if (byte >= 0xF0) {
              len = 4;
          } else if (byte >= 0xE0) {
              len = 3;
          } else if (byte >= 0xC0) {
              len = 2;
          }
          // Never read past the end of a truncated sequence.
          if (len > total - i) {
              len = total - i;
          }
          output.push_back(input.substr(i, len));
          i += len;
      }
      return output.size();
  }
  // Decodes the leading three-byte UTF-8 sequence of `str` into its code
  // point. Returns 0 when the first three bytes do not follow the pattern
  // 1110xxxx 10xxxxxx 10xxxxxx.
  int Str2IntFunc(std::string str)
  {
      const char *bytes = str.c_str();
      // Evaluated left to right with short-circuiting, so a later byte is
      // only inspected once the previous one has been validated.
      const bool three_byte_form = ((bytes[0] & 0xf0) == 0xe0) &&
                                   ((bytes[1] & 0xc0) == 0x80) &&
                                   ((bytes[2] & 0xc0) == 0x80);
      if (!three_byte_form) {
          return 0;
      }
      const int high = (bytes[0] & 0x0f) << 12;
      const int mid = (bytes[1] & 0x3f) << 6;
      const int low = bytes[2] & 0x3f;
      return high | mid | low;
  }
  630. bool IsChinese(string ch)
  631. {
  632. if (ch.size() != 3) {
  633. return false;
  634. }
  635. int unicode = Str2IntFunc(ch);
  636. if (unicode >= 19968 && unicode <= 40959) {
  637. return true;
  638. }
  639. return false;
  640. }
  641. string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> &timestamp_list){
  642. std::vector<std::vector<float>> timestamp_merge;
  643. int i;
  644. list<string> words;
  645. int is_pre_english = false;
  646. int pre_english_len = 0;
  647. int is_combining = false;
  648. string combine = "";
  649. float begin=-1;
  650. for (i=0; i<raw_char.size(); i++){
  651. string word = raw_char[i];
  652. // step1 space character skips
  653. if (word == "<s>" || word == "</s>" || word == "<unk>")
  654. continue;
  655. // step2 combie phoneme to full word
  656. {
  657. int sub_word = !(word.find("@@") == string::npos);
  658. // process word start and middle part
  659. if (sub_word) {
  660. // if badcase: lo@@ chinese
  661. if (i == raw_char.size()-1 || i<raw_char.size()-1 && IsChinese(raw_char[i+1])){
  662. word = word.erase(word.length() - 2) + " ";
  663. if (is_combining) {
  664. combine += word;
  665. is_combining = false;
  666. word = combine;
  667. combine = "";
  668. }
  669. }else{
  670. combine += word.erase(word.length() - 2);
  671. if(!is_combining){
  672. begin = timestamp_list[i][0];
  673. }
  674. is_combining = true;
  675. continue;
  676. }
  677. }
  678. // process word end part
  679. else if (is_combining) {
  680. combine += word;
  681. is_combining = false;
  682. word = combine;
  683. combine = "";
  684. }
  685. }
  686. // step3 process english word deal with space , turn abbreviation to upper case
  687. {
  688. // input word is chinese, not need process
  689. if (IsChinese(word)) {
  690. words.push_back(word);
  691. timestamp_merge.emplace_back(timestamp_list[i]);
  692. is_pre_english = false;
  693. }
  694. // input word is english word
  695. else {
  696. // pre word is chinese
  697. if (!is_pre_english) {
  698. // word[0] = word[0] - 32;
  699. words.push_back(word);
  700. begin = (begin==-1)?timestamp_list[i][0]:begin;
  701. std::vector<float> vec = {begin, timestamp_list[i][1]};
  702. timestamp_merge.emplace_back(vec);
  703. begin = -1;
  704. pre_english_len = word.size();
  705. }
  706. // pre word is english word
  707. else {
  708. // single letter turn to upper case
  709. // if (word.size() == 1) {
  710. // word[0] = word[0] - 32;
  711. // }
  712. if (pre_english_len > 1) {
  713. words.push_back(" ");
  714. words.push_back(word);
  715. begin = (begin==-1)?timestamp_list[i][0]:begin;
  716. std::vector<float> vec = {begin, timestamp_list[i][1]};
  717. timestamp_merge.emplace_back(vec);
  718. begin = -1;
  719. pre_english_len = word.size();
  720. }
  721. else {
  722. // if (word.size() > 1) {
  723. // words.push_back(" ");
  724. // }
  725. words.push_back(" ");
  726. words.push_back(word);
  727. begin = (begin==-1)?timestamp_list[i][0]:begin;
  728. std::vector<float> vec = {begin, timestamp_list[i][1]};
  729. timestamp_merge.emplace_back(vec);
  730. begin = -1;
  731. pre_english_len = word.size();
  732. }
  733. }
  734. is_pre_english = true;
  735. }
  736. }
  737. }
  738. string stamp_str="";
  739. for (i=0; i<timestamp_merge.size(); i++) {
  740. stamp_str += std::to_string(timestamp_merge[i][0]);
  741. stamp_str += ", ";
  742. stamp_str += std::to_string(timestamp_merge[i][1]);
  743. if(i!=timestamp_merge.size()-1){
  744. stamp_str += ",";
  745. }
  746. }
  747. stringstream ss;
  748. for (auto it = words.begin(); it != words.end(); it++) {
  749. ss << *it;
  750. }
  751. return ss.str()+" | "+stamp_str;
  752. }
  753. void TimestampOnnx( std::vector<float>& us_alphas,
  754. std::vector<float> us_cif_peak,
  755. std::vector<string>& char_list,
  756. std::string &res_str,
  757. std::vector<std::vector<float>> &timestamp_vec,
  758. float begin_time,
  759. float total_offset){
  760. if (char_list.empty()) {
  761. return ;
  762. }
  763. const float START_END_THRESHOLD = 5.0;
  764. const float MAX_TOKEN_DURATION = 30.0;
  765. const float TIME_RATE = 10.0 * 6 / 1000 / 3;
  766. // 3 times upsampled, cif_peak is flattened into a 1D array
  767. std::vector<float> cif_peak = us_cif_peak;
  768. int num_frames = cif_peak.size();
  769. if (char_list.back() == "</s>") {
  770. char_list.pop_back();
  771. }
  772. if (char_list.empty()) {
  773. return ;
  774. }
  775. vector<vector<float>> timestamp_list;
  776. vector<string> new_char_list;
  777. vector<float> fire_place;
  778. // for bicif model trained with large data, cif2 actually fires when a character starts
  779. // so treat the frames between two peaks as the duration of the former token
  780. for (int i = 0; i < num_frames; i++) {
  781. if (cif_peak[i] > 1.0 - 1e-4) {
  782. fire_place.push_back(i + total_offset);
  783. }
  784. }
  785. int num_peak = fire_place.size();
  786. if(num_peak != (int)char_list.size() + 1){
  787. float sum = std::accumulate(us_alphas.begin(), us_alphas.end(), 0.0f);
  788. float scale = sum/((int)char_list.size() + 1);
  789. if(scale == 0){
  790. return;
  791. }
  792. cif_peak.clear();
  793. sum = 0.0;
  794. for(auto &alpha:us_alphas){
  795. alpha = alpha/scale;
  796. sum += alpha;
  797. cif_peak.emplace_back(sum);
  798. if(sum>=1.0 - 1e-4){
  799. sum -=(1.0 - 1e-4);
  800. }
  801. }
  802. fire_place.clear();
  803. for (int i = 0; i < num_frames; i++) {
  804. if (cif_peak[i] > 1.0 - 1e-4) {
  805. fire_place.push_back(i + total_offset);
  806. }
  807. }
  808. }
  809. num_peak = fire_place.size();
  810. if(fire_place.size() == 0){
  811. return;
  812. }
  813. // begin silence
  814. if (fire_place[0] > START_END_THRESHOLD) {
  815. new_char_list.push_back("<sil>");
  816. timestamp_list.push_back({0.0, fire_place[0] * TIME_RATE});
  817. }
  818. // tokens timestamp
  819. for (int i = 0; i < num_peak - 1; i++) {
  820. new_char_list.push_back(char_list[i]);
  821. if (i == num_peak - 2 || MAX_TOKEN_DURATION < 0 || fire_place[i + 1] - fire_place[i] < MAX_TOKEN_DURATION) {
  822. timestamp_list.push_back({fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE});
  823. } else {
  824. // cut the duration to token and sil of the 0-weight frames last long
  825. float _split = fire_place[i] + MAX_TOKEN_DURATION;
  826. timestamp_list.push_back({fire_place[i] * TIME_RATE, _split * TIME_RATE});
  827. timestamp_list.push_back({_split * TIME_RATE, fire_place[i + 1] * TIME_RATE});
  828. new_char_list.push_back("<sil>");
  829. }
  830. }
  831. // tail token and end silence
  832. if(timestamp_list.size()==0){
  833. LOG(ERROR)<<"timestamp_list's size is 0!";
  834. return;
  835. }
  836. if (num_frames - fire_place.back() > START_END_THRESHOLD) {
  837. float _end = (num_frames + fire_place.back()) / 2.0;
  838. timestamp_list.back()[1] = _end * TIME_RATE;
  839. timestamp_list.push_back({_end * TIME_RATE, num_frames * TIME_RATE});
  840. new_char_list.push_back("<sil>");
  841. } else {
  842. timestamp_list.back()[1] = num_frames * TIME_RATE;
  843. }
  844. if (begin_time) { // add offset time in model with vad
  845. for (auto& timestamp : timestamp_list) {
  846. timestamp[0] += begin_time / 1000.0;
  847. timestamp[1] += begin_time / 1000.0;
  848. }
  849. }
  850. assert(new_char_list.size() == timestamp_list.size());
  851. for (int i = 0; i < (int)new_char_list.size(); i++) {
  852. res_str += new_char_list[i] + " " + to_string(timestamp_list[i][0]) + " " + to_string(timestamp_list[i][1]) + ";";
  853. }
  854. for (int i = 0; i < (int)new_char_list.size(); i++) {
  855. if(new_char_list[i] != "<sil>"){
  856. timestamp_vec.push_back(timestamp_list[i]);
  857. }
  858. }
  859. }
  // Returns true when the text after the last '.' in `filename` equals
  // `target` exactly; returns false when `filename` contains no '.'.
  bool IsTargetFile(const std::string& filename, const std::string target) {
      const std::size_t dot = filename.rfind('.');
      if (dot == std::string::npos) {
          return false;
      }
      // Compare the suffix in place instead of materialising it.
      return filename.compare(dot + 1, std::string::npos, target) == 0;
  }
  // Removes leading and trailing whitespace (space, \t, \n, \r, \f, \v) from
  // `*str` in place. A string made only of whitespace becomes empty.
  void Trim(std::string *str) {
      const char *blanks = " \t\n\r\f\v";
      const std::string::size_type last_kept = str->find_last_not_of(blanks);
      if (last_kept == std::string::npos) {
          // Nothing but whitespace (or already empty).
          str->clear();
          return;
      }
      str->erase(last_kept + 1);
      const std::string::size_type first_kept = str->find_first_not_of(blanks);
      if (first_kept != std::string::npos) {
          str->erase(0, first_kept);
      }
  }
  // Splits `full` at every character contained in `delim` and stores the
  // pieces in `*out` (cleared first).
  //
  // With `omit_empty_strings` set, empty pieces are dropped; this covers
  // adjacent delimiters as well as a delimiter at the very start or end.
  // Without it every piece is kept, so an empty input yields one empty string.
  void SplitStringToVector(const std::string &full, const char *delim,
                           bool omit_empty_strings,
                           std::vector<std::string> *out) {
      out->clear();
      const size_t total = full.size();
      size_t piece_start = 0;
      for (;;) {
          const size_t hit = full.find_first_of(delim, piece_start);
          // A piece is empty when the delimiter sits right at its start, or
          // when the previous delimiter was the last character of the input.
          const bool empty_piece = (hit == piece_start) || (piece_start == total);
          if (!(omit_empty_strings && empty_piece)) {
              out->push_back(full.substr(piece_start, hit - piece_start));
          }
          if (hit == std::string::npos) {
              break;
          }
          piece_start = hit + 1;
      }
  }
  892. void ExtractHws(string hws_file, unordered_map<string, int> &hws_map)
  893. {
  894. if(hws_file.empty()){
  895. return;
  896. }
  897. std::string line;
  898. std::ifstream ifs_hws(hws_file.c_str());
  899. if(!ifs_hws.is_open()){
  900. LOG(ERROR) << "Unable to open hotwords file: " << hws_file
  901. << ". If you have not set hotwords, please ignore this message.";
  902. return;
  903. }
  904. LOG(INFO) << "hotwords: ";
  905. while (getline(ifs_hws, line)) {
  906. Trim(&line);
  907. if (line.empty()) {
  908. continue;
  909. }
  910. float score = 1.0f;
  911. std::vector<std::string> text;
  912. SplitStringToVector(line, " ", true, &text);
  913. if (text.size() > 1) {
  914. try{
  915. score = std::stof(text[text.size() - 1]);
  916. }catch (std::exception const &e)
  917. {
  918. LOG(ERROR)<<e.what();
  919. continue;
  920. }
  921. } else {
  922. continue;
  923. }
  924. std::string hotword = "";
  925. for (size_t i = 0; i < text.size()-1; ++i) {
  926. hotword = hotword + text[i];
  927. if(i != text.size()-2){
  928. hotword = hotword + " ";
  929. }
  930. }
  931. LOG(INFO) << hotword << " : " << score;
  932. hws_map.emplace(hotword, score);
  933. }
  934. ifs_hws.close();
  935. }
  936. void ExtractHws(string hws_file, unordered_map<string, int> &hws_map, string& nn_hotwords_)
  937. {
  938. if(hws_file.empty()){
  939. return;
  940. }
  941. std::string line;
  942. std::ifstream ifs_hws(hws_file.c_str());
  943. if(!ifs_hws.is_open()){
  944. LOG(ERROR) << "Unable to open hotwords file: " << hws_file
  945. << ". If you have not set hotwords, please ignore this message.";
  946. return;
  947. }
  948. LOG(INFO) << "hotwords: ";
  949. while (getline(ifs_hws, line)) {
  950. Trim(&line);
  951. if (line.empty()) {
  952. continue;
  953. }
  954. float score = 1.0f;
  955. std::vector<std::string> text;
  956. SplitStringToVector(line, " ", true, &text);
  957. if (text.size() > 1) {
  958. try{
  959. score = std::stof(text[text.size() - 1]);
  960. }catch (std::exception const &e)
  961. {
  962. LOG(ERROR)<<e.what();
  963. continue;
  964. }
  965. } else {
  966. continue;
  967. }
  968. std::string hotword = "";
  969. for (size_t i = 0; i < text.size()-1; ++i) {
  970. hotword = hotword + text[i];
  971. if(i != text.size()-2){
  972. hotword = hotword + " ";
  973. }
  974. }
  975. nn_hotwords_ += " " + hotword;
  976. LOG(INFO) << hotword << " : " << score;
  977. hws_map.emplace(hotword, score);
  978. }
  979. ifs_hws.close();
  980. }
  981. void SmoothTimestamps(std::string &str_punc, std::string &str_itn, std::string &str_timetamp){
  982. return;
  983. }
  984. } // namespace funasr