vocab.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. #include "vocab.h"
  2. #include <yaml-cpp/yaml.h>
  3. #include <glog/logging.h>
  4. #include <fstream>
  5. #include <iostream>
  6. #include <list>
  7. #include <sstream>
  8. #include <string>
  9. using namespace std;
  10. namespace funasr {
  11. Vocab::Vocab(const char *filename)
  12. {
  13. ifstream in(filename);
  14. LoadVocabFromYaml(filename);
  15. }
  16. Vocab::Vocab(const char *filename, const char *lex_file)
  17. {
  18. ifstream in(filename);
  19. LoadVocabFromYaml(filename);
  20. LoadLex(lex_file);
  21. }
  22. Vocab::~Vocab()
  23. {
  24. }
  25. void Vocab::LoadVocabFromYaml(const char* filename){
  26. YAML::Node config;
  27. try{
  28. config = YAML::LoadFile(filename);
  29. }catch(exception const &e){
  30. LOG(INFO) << "Error loading file, yaml file error or not exist.";
  31. exit(-1);
  32. }
  33. YAML::Node myList = config["token_list"];
  34. int i = 0;
  35. for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
  36. vocab.push_back(it->as<string>());
  37. token_id[it->as<string>()] = i;
  38. i ++;
  39. }
  40. }
  41. void Vocab::LoadLex(const char* filename){
  42. std::ifstream file(filename);
  43. std::string line;
  44. while (std::getline(file, line)) {
  45. std::string key, value;
  46. std::istringstream iss(line);
  47. std::getline(iss, key, '\t');
  48. std::getline(iss, value);
  49. if (!key.empty() && !value.empty()) {
  50. lex_map[key] = value;
  51. }
  52. }
  53. file.close();
  54. }
  55. string Vocab::Word2Lex(const std::string &word) const {
  56. auto it = lex_map.find(word);
  57. if (it != lex_map.end()) {
  58. return it->second;
  59. }
  60. return "";
  61. }
  62. int Vocab::GetIdByToken(const std::string &token) const {
  63. auto it = token_id.find(token);
  64. if (it != token_id.end()) {
  65. return it->second;
  66. }
  67. return -1;
  68. }
  69. void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
  70. {
  71. for (auto it = in.begin(); it != in.end(); it++) {
  72. string word = vocab[*it];
  73. preds.emplace_back(word);
  74. }
  75. }
  76. string Vocab::Vector2String(vector<int> in)
  77. {
  78. int i;
  79. stringstream ss;
  80. for (auto it = in.begin(); it != in.end(); it++) {
  81. ss << vocab[*it];
  82. }
  83. return ss.str();
  84. }
  // Decode the first three bytes of `str` as a 3-byte UTF-8 sequence
  // (1110xxxx 10xxxxxx 10xxxxxx) and return its code point, or 0 when the
  // bytes do not have that shape. The && chain stops at the first mismatch;
  // since c_str() is NUL-terminated (and NUL fails every check), shorter
  // strings are never read past their terminator.
  int Str2Int(string str)
  {
      const char *bytes = str.c_str();
      if ((bytes[0] & 0xf0) == 0xe0 &&
          (bytes[1] & 0xc0) == 0x80 &&
          (bytes[2] & 0xc0) == 0x80) {
          const int high = (bytes[0] & 0x0f) << 12;
          const int mid = (bytes[1] & 0x3f) << 6;
          const int low = bytes[2] & 0x3f;
          return high | mid | low;
      }
      return 0;
  }
  95. string Vocab::Id2String(int id) const
  96. {
  97. if (id < 0 || id >= vocab.size()) {
  98. LOG(INFO) << "Error vocabulary id, this id do not exit.";
  99. return "";
  100. } else {
  101. return vocab[id];
  102. }
  103. }
  104. bool Vocab::IsChinese(string ch)
  105. {
  106. if (ch.size() != 3) {
  107. return false;
  108. }
  109. int unicode = Str2Int(ch);
  110. if (unicode >= 19968 && unicode <= 40959) {
  111. return true;
  112. }
  113. return false;
  114. }
  115. string Vocab::WordFormat(std::string word)
  116. {
  117. if(word == "i"){
  118. return "I";
  119. }else if(word == "i'm"){
  120. return "I'm";
  121. }else if(word == "i've"){
  122. return "I've";
  123. }else if(word == "i'll"){
  124. return "I'll";
  125. }else{
  126. return word;
  127. }
  128. }
  129. string Vocab::Vector2StringV2(vector<int> in, std::string language)
  130. {
  131. int i;
  132. list<string> words;
  133. int is_pre_english = false;
  134. int pre_english_len = 0;
  135. int is_combining = false;
  136. std::string combine = "";
  137. std::string unicodeChar = "▁";
  138. for (i=0; i<in.size(); i++){
  139. string word = vocab[in[i]];
  140. // step1 space character skips
  141. if (word == "<s>" || word == "</s>" || word == "<unk>")
  142. continue;
  143. if (language == "en-bpe"){
  144. size_t found = word.find(unicodeChar);
  145. if(found != std::string::npos){
  146. if (combine != ""){
  147. combine = WordFormat(combine);
  148. if (words.size() != 0){
  149. combine = " " + combine;
  150. }
  151. words.push_back(combine);
  152. }
  153. combine = word.substr(3);
  154. }else{
  155. combine += word;
  156. }
  157. continue;
  158. }
  159. // step2 combie phoneme to full word
  160. {
  161. int sub_word = !(word.find("@@") == string::npos);
  162. // process word start and middle part
  163. if (sub_word) {
  164. // if badcase: lo@@ chinese
  165. if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
  166. word = word.erase(word.length() - 2) + " ";
  167. if (is_combining) {
  168. combine += word;
  169. is_combining = false;
  170. word = combine;
  171. combine = "";
  172. }
  173. }else{
  174. combine += word.erase(word.length() - 2);
  175. is_combining = true;
  176. continue;
  177. }
  178. }
  179. // process word end part
  180. else if (is_combining) {
  181. combine += word;
  182. is_combining = false;
  183. word = combine;
  184. combine = "";
  185. }
  186. }
  187. // step3 process english word deal with space , turn abbreviation to upper case
  188. {
  189. // input word is chinese, not need process
  190. if (IsChinese(word)) {
  191. words.push_back(word);
  192. is_pre_english = false;
  193. }
  194. // input word is english word
  195. else {
  196. // pre word is chinese
  197. if (!is_pre_english) {
  198. // word[0] = word[0] - 32;
  199. words.push_back(word);
  200. pre_english_len = word.size();
  201. }
  202. // pre word is english word
  203. else {
  204. // single letter turn to upper case
  205. // if (word.size() == 1) {
  206. // word[0] = word[0] - 32;
  207. // }
  208. if (pre_english_len > 1) {
  209. words.push_back(" ");
  210. words.push_back(word);
  211. pre_english_len = word.size();
  212. }
  213. else {
  214. if (word.size() > 1) {
  215. words.push_back(" ");
  216. }
  217. words.push_back(word);
  218. pre_english_len = word.size();
  219. }
  220. }
  221. is_pre_english = true;
  222. }
  223. }
  224. }
  225. if (language == "en-bpe" && combine != ""){
  226. combine = WordFormat(combine);
  227. if (words.size() != 0){
  228. combine = " " + combine;
  229. }
  230. words.push_back(combine);
  231. }
  232. stringstream ss;
  233. for (auto it = words.begin(); it != words.end(); it++) {
  234. ss << *it;
  235. }
  236. return ss.str();
  237. }
  238. int Vocab::Size() const
  239. {
  240. return vocab.size();
  241. }
  242. } // namespace funasr