// paraformer.cpp
#include "precomp.h"

#include <stdexcept>

using namespace std;
using namespace paraformer;
  4. Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
  5. :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
  6. string model_path;
  7. string cmvn_path;
  8. string config_path;
  9. // VAD model
  10. if(use_vad){
  11. string vad_path = PathAppend(path, "vad_model.onnx");
  12. string mvn_path = PathAppend(path, "vad.mvn");
  13. vad_handle = make_unique<FsmnVad>();
  14. vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
  15. }
  16. // PUNC model
  17. if(use_punc){
  18. punc_handle = make_unique<CTTransformer>(path, thread_num);
  19. }
  20. if(quantize)
  21. {
  22. model_path = PathAppend(path, "model_quant.onnx");
  23. }else{
  24. model_path = PathAppend(path, "model.onnx");
  25. }
  26. cmvn_path = PathAppend(path, "am.mvn");
  27. config_path = PathAppend(path, "config.yaml");
  28. // knf options
  29. fbank_opts.frame_opts.dither = 0;
  30. fbank_opts.mel_opts.num_bins = 80;
  31. fbank_opts.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
  32. fbank_opts.frame_opts.window_type = "hamming";
  33. fbank_opts.frame_opts.frame_shift_ms = 10;
  34. fbank_opts.frame_opts.frame_length_ms = 25;
  35. fbank_opts.energy_floor = 0;
  36. fbank_opts.mel_opts.debug_mel = false;
  37. // fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
  38. // session_options.SetInterOpNumThreads(1);
  39. session_options.SetIntraOpNumThreads(thread_num);
  40. session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
  41. // DisableCpuMemArena can improve performance
  42. session_options.DisableCpuMemArena();
  43. #ifdef _WIN32
  44. wstring wstrPath = strToWstr(model_path);
  45. m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
  46. #else
  47. m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
  48. #endif
  49. string strName;
  50. GetInputName(m_session.get(), strName);
  51. m_strInputNames.push_back(strName.c_str());
  52. GetInputName(m_session.get(), strName,1);
  53. m_strInputNames.push_back(strName);
  54. GetOutputName(m_session.get(), strName);
  55. m_strOutputNames.push_back(strName);
  56. GetOutputName(m_session.get(), strName,1);
  57. m_strOutputNames.push_back(strName);
  58. for (auto& item : m_strInputNames)
  59. m_szInputNames.push_back(item.c_str());
  60. for (auto& item : m_strOutputNames)
  61. m_szOutputNames.push_back(item.c_str());
  62. vocab = new Vocab(config_path.c_str());
  63. LoadCmvn(cmvn_path.c_str());
  64. }
  65. Paraformer::~Paraformer()
  66. {
  67. if(vocab)
  68. delete vocab;
  69. }
  70. void Paraformer::Reset()
  71. {
  72. }
  73. vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
  74. return vad_handle->Infer(pcm_data);
  75. }
  76. string Paraformer::AddPunc(const char* sz_input){
  77. return punc_handle->AddPunc(sz_input);
  78. }
  79. vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
  80. knf::OnlineFbank fbank_(fbank_opts);
  81. fbank_.AcceptWaveform(sample_rate, waves, len);
  82. //fbank_->InputFinished();
  83. int32_t frames = fbank_.NumFramesReady();
  84. int32_t feature_dim = fbank_opts.mel_opts.num_bins;
  85. vector<float> features(frames * feature_dim);
  86. float *p = features.data();
  87. for (int32_t i = 0; i != frames; ++i) {
  88. const float *f = fbank_.GetFrame(i);
  89. std::copy(f, f + feature_dim, p);
  90. p += feature_dim;
  91. }
  92. return features;
  93. }
  94. void Paraformer::LoadCmvn(const char *filename)
  95. {
  96. ifstream cmvn_stream(filename);
  97. string line;
  98. while (getline(cmvn_stream, line)) {
  99. istringstream iss(line);
  100. vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
  101. if (line_item[0] == "<AddShift>") {
  102. getline(cmvn_stream, line);
  103. istringstream means_lines_stream(line);
  104. vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
  105. if (means_lines[0] == "<LearnRateCoef>") {
  106. for (int j = 3; j < means_lines.size() - 1; j++) {
  107. means_list.push_back(stof(means_lines[j]));
  108. }
  109. continue;
  110. }
  111. }
  112. else if (line_item[0] == "<Rescale>") {
  113. getline(cmvn_stream, line);
  114. istringstream vars_lines_stream(line);
  115. vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
  116. if (vars_lines[0] == "<LearnRateCoef>") {
  117. for (int j = 3; j < vars_lines.size() - 1; j++) {
  118. vars_list.push_back(stof(vars_lines[j])*scale);
  119. }
  120. continue;
  121. }
  122. }
  123. }
  124. }
  125. string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
  126. {
  127. vector<int> hyps;
  128. int Tmax = n_len;
  129. for (int i = 0; i < Tmax; i++) {
  130. int max_idx;
  131. float max_val;
  132. FindMax(in + i * token_nums, token_nums, max_val, max_idx);
  133. hyps.push_back(max_idx);
  134. }
  135. return vocab->Vector2StringV2(hyps);
  136. }
  137. vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
  138. {
  139. int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
  140. int32_t in_num_frames = in.size() / in_feat_dim;
  141. int32_t out_num_frames =
  142. (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
  143. int32_t out_feat_dim = in_feat_dim * lfr_window_size;
  144. std::vector<float> out(out_num_frames * out_feat_dim);
  145. const float *p_in = in.data();
  146. float *p_out = out.data();
  147. for (int32_t i = 0; i != out_num_frames; ++i) {
  148. std::copy(p_in, p_in + out_feat_dim, p_out);
  149. p_out += out_feat_dim;
  150. p_in += lfr_window_shift * in_feat_dim;
  151. }
  152. return out;
  153. }
  154. void Paraformer::ApplyCmvn(std::vector<float> *v)
  155. {
  156. int32_t dim = means_list.size();
  157. int32_t num_frames = v->size() / dim;
  158. float *p = v->data();
  159. for (int32_t i = 0; i != num_frames; ++i) {
  160. for (int32_t k = 0; k != dim; ++k) {
  161. p[k] = (p[k] + means_list[k]) * vars_list[k];
  162. }
  163. p += dim;
  164. }
  165. }
  166. string Paraformer::Forward(float* din, int len, int flag)
  167. {
  168. int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
  169. std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
  170. wav_feats = ApplyLfr(wav_feats);
  171. ApplyCmvn(&wav_feats);
  172. int32_t feat_dim = lfr_window_size*in_feat_dim;
  173. int32_t num_frames = wav_feats.size() / feat_dim;
  174. #ifdef _WIN_X86
  175. Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  176. #else
  177. Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  178. #endif
  179. const int64_t input_shape_[3] = {1, num_frames, feat_dim};
  180. Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
  181. wav_feats.data(),
  182. wav_feats.size(),
  183. input_shape_,
  184. 3);
  185. const int64_t paraformer_length_shape[1] = {1};
  186. std::vector<int32_t> paraformer_length;
  187. paraformer_length.emplace_back(num_frames);
  188. Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
  189. m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
  190. std::vector<Ort::Value> input_onnx;
  191. input_onnx.emplace_back(std::move(onnx_feats));
  192. input_onnx.emplace_back(std::move(onnx_feats_len));
  193. string result;
  194. try {
  195. auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
  196. std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
  197. int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
  198. float* floatData = outputTensor[0].GetTensorMutableData<float>();
  199. auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
  200. result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
  201. }
  202. catch (std::exception const &e)
  203. {
  204. printf(e.what());
  205. }
  206. return result;
  207. }
  208. string Paraformer::ForwardChunk(float* din, int len, int flag)
  209. {
  210. printf("Not Imp!!!!!!\n");
  211. return "Hello";
  212. }
  213. string Paraformer::Rescoring()
  214. {
  215. printf("Not Imp!!!!!!\n");
  216. return "Hello";
  217. }