ct-transformer.cpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /**
  2. * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  3. * MIT License (https://opensource.org/licenses/MIT)
  4. */
  5. #include "precomp.h"
  6. namespace funasr {
  7. CTTransformer::CTTransformer()
  8. :env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
  9. {
  10. }
  11. void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
  12. session_options.SetIntraOpNumThreads(thread_num);
  13. session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
  14. session_options.DisableCpuMemArena();
  15. try{
  16. m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
  17. LOG(INFO) << "Successfully load model from " << punc_model;
  18. }
  19. catch (std::exception const &e) {
  20. LOG(ERROR) << "Error when load punc onnx model: " << e.what();
  21. exit(-1);
  22. }
  23. // read inputnames outputnames
  24. string strName;
  25. GetInputName(m_session.get(), strName);
  26. m_strInputNames.push_back(strName.c_str());
  27. GetInputName(m_session.get(), strName, 1);
  28. m_strInputNames.push_back(strName);
  29. GetOutputName(m_session.get(), strName);
  30. m_strOutputNames.push_back(strName);
  31. for (auto& item : m_strInputNames)
  32. m_szInputNames.push_back(item.c_str());
  33. for (auto& item : m_strOutputNames)
  34. m_szOutputNames.push_back(item.c_str());
  35. m_tokenizer.OpenYaml(punc_config.c_str());
  36. }
  37. CTTransformer::~CTTransformer()
  38. {
  39. }
/**
 * Tokenize the raw input, run the punctuation model over TOKEN_LEN-sized
 * windows, and return the input text with predicted punctuation inserted.
 *
 * Windows that do not end on a sentence boundary carry their unfinished
 * tail (RemainIDs/RemainStr) into the next window so punctuation is
 * predicted with fuller context.
 *
 * @param sz_input NUL-terminated input text (no punctuation)
 * @return punctuated text; empty string if tokenization yields no tokens
 */
string CTTransformer::AddPunc(const char* sz_input)
{
    string strResult;
    vector<string> strOut;
    vector<int> InputData;
    m_tokenizer.Tokenize(sz_input, strOut, InputData);
    // Number of TOKEN_LEN-sized windows needed to cover all tokens.
    int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
    int nCurBatch = -1;
    int nSentEnd = -1, nLastCommaIndex = -1;
    vector<int32_t> RemainIDs;      // token ids carried over from the previous window
    vector<string> RemainStr;       // token strings carried over from the previous window
    vector<int> NewPunctuation;     // accumulated punctuation ids for all processed tokens
    vector<string> NewString;       // accumulated output pieces (tokens + punctuation)
    vector<string> NewSentenceOut;
    vector<int> NewPuncOut;
    int nDiff = 0;
    for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
    {
        // nDiff shortens the last window when fewer than TOKEN_LEN tokens remain.
        nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
        vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
        // Prepend the unfinished tail of the previous window: RemainIDs + InputIDs.
        InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end());
        InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr + InputStr
        // One punctuation class id per token in this window.
        auto Punction = Infer(InputIDs);
        nCurBatch = i / TOKEN_LEN;
        if (nCurBatch < nTotalBatch - 1) // not the last mini-sentence
        {
            // Scan backwards (skipping the final position) for the latest
            // sentence end (period/question mark); also remember the latest
            // comma as a fallback split point.
            nSentEnd = -1;
            nLastCommaIndex = -1;
            for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
            {
                if (m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(PERIOD_INDEX) || m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(QUESTION_INDEX))
                {
                    nSentEnd = nIndex;
                    break;
                }
                if (nLastCommaIndex < 0 && m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(COMMA_INDEX))
                {
                    nLastCommaIndex = nIndex;
                }
            }
            // No sentence end found and the carry would grow too large:
            // promote the last comma to a period and split there.
            if (nSentEnd < 0 && InputStr.size() > CACHE_POP_TRIGGER_LIMIT && nLastCommaIndex > 0)
            {
                nSentEnd = nLastCommaIndex;
                Punction[nSentEnd] = PERIOD_INDEX;
            }
            // Everything after the split point is carried into the next window;
            // this window keeps only the completed mini-sentence.
            RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end());
            RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end());
            InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1);
            Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1);
        }
        NewPunctuation.insert(NewPunctuation.end(), Punction.begin(), Punction.end());
        vector<string> WordWithPunc;
        // NOTE: this inner `i` shadows the outer window index `i`.
        for (int i = 0; i < InputStr.size(); i++)
        {
            // if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))
            // (disabled original that looked at the FOLLOWING token instead; the
            // original comment here was garbled mojibake — likely "space between
            // adjacent English words".)
            // High bit clear on the first byte => ASCII token; insert a space
            // between two consecutive ASCII (e.g. English) tokens.
            if (i > 0 && !(InputStr[i-1][0] & 0x80) && !(InputStr[i][0] & 0x80))
            {
                InputStr[i] = " " + InputStr[i];
            }
            WordWithPunc.push_back(InputStr[i]);
            // Append the predicted punctuation mark unless the model predicted
            // the "no punctuation" placeholder. (Original comment was garbled
            // mojibake — presumably named the placeholder symbol.)
            if (Punction[i] != NOTPUNC_INDEX)
            {
                WordWithPunc.push_back(m_tokenizer.Id2Punc(Punction[i]));
            }
        }
        NewString.insert(NewString.end(), WordWithPunc.begin(), WordWithPunc.end()); // new_mini_sentence += "".join(words_with_punc)
        NewSentenceOut = NewString;
        NewPuncOut = NewPunctuation;
        // Last mini-sentence: force the output to end with a terminal mark.
        if(nCurBatch == nTotalBatch - 1)
        {
            // Trailing comma / dun-comma is replaced with a period ...
            if (NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(DUN_INDEX))
            {
                NewSentenceOut.assign(NewString.begin(), NewString.end() - 1);
                NewSentenceOut.push_back(m_tokenizer.Id2Punc(PERIOD_INDEX));
                NewPuncOut.assign(NewPunctuation.begin(), NewPunctuation.end() - 1);
                NewPuncOut.push_back(PERIOD_INDEX);
            }
            // ... and anything other than a period/question mark gets one appended.
            else if (NewString[NewString.size() - 1] != m_tokenizer.Id2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] != m_tokenizer.Id2Punc(QUESTION_INDEX))
            {
                NewSentenceOut = NewString;
                NewSentenceOut.push_back(m_tokenizer.Id2Punc(PERIOD_INDEX));
                NewPuncOut = NewPunctuation;
                NewPuncOut.push_back(PERIOD_INDEX);
            }
        }
    }
    // Concatenate all output pieces into the final punctuated string.
    for (auto& item : NewSentenceOut)
        strResult += item;
    return strResult;
}
  132. vector<int> CTTransformer::Infer(vector<int32_t> input_data)
  133. {
  134. Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  135. vector<int> punction;
  136. std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
  137. Ort::Value onnx_input = Ort::Value::CreateTensor<int32_t>(
  138. m_memoryInfo,
  139. input_data.data(),
  140. input_data.size(),
  141. input_shape_.data(),
  142. input_shape_.size());
  143. std::array<int32_t,1> text_lengths{ (int32_t)input_data.size() };
  144. std::array<int64_t,1> text_lengths_dim{ 1 };
  145. Ort::Value onnx_text_lengths = Ort::Value::CreateTensor(
  146. m_memoryInfo,
  147. text_lengths.data(),
  148. text_lengths.size() * sizeof(int32_t),
  149. text_lengths_dim.data(),
  150. text_lengths_dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
  151. std::vector<Ort::Value> input_onnx;
  152. input_onnx.emplace_back(std::move(onnx_input));
  153. input_onnx.emplace_back(std::move(onnx_text_lengths));
  154. try {
  155. auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
  156. std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
  157. int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
  158. float * floatData = outputTensor[0].GetTensorMutableData<float>();
  159. for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
  160. {
  161. int index = Argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
  162. punction.push_back(index);
  163. }
  164. }
  165. catch (std::exception const &e)
  166. {
  167. LOG(ERROR) << "Error when run punc onnx forword: " << (e.what());
  168. }
  169. return punction;
  170. }
  171. } // namespace funasr