paraformer.h 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. /**
  2. * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  3. * MIT License (https://opensource.org/licenses/MIT)
  4. */
  5. #pragma once
  6. #include "precomp.h"
  7. #include "fst/fstlib.h"
  8. #include "fst/symbol-table.h"
  9. #include "bias-lm.h"
  10. #include "phone-set.h"
  11. namespace funasr {
  12. class Paraformer : public Model {
  13. /**
  14. * Author: Speech Lab of DAMO Academy, Alibaba Group
  15. * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  16. * https://arxiv.org/pdf/2206.08317.pdf
  17. */
  18. private:
  19. Vocab* vocab = nullptr;
  20. Vocab* lm_vocab = nullptr;
  21. SegDict* seg_dict = nullptr;
  22. PhoneSet* phone_set_ = nullptr;
  23. //const float scale = 22.6274169979695;
  24. const float scale = 1.0;
  25. void LoadConfigFromYaml(const char* filename);
  26. void LoadOnlineConfigFromYaml(const char* filename);
  27. void LoadCmvn(const char *filename);
  28. void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
  29. std::shared_ptr<Ort::Session> hw_m_session = nullptr;
  30. Ort::Env hw_env_;
  31. Ort::SessionOptions hw_session_options;
  32. vector<string> hw_m_strInputNames, hw_m_strOutputNames;
  33. vector<const char*> hw_m_szInputNames;
  34. vector<const char*> hw_m_szOutputNames;
  35. bool use_hotword;
  36. public:
  37. Paraformer();
  38. ~Paraformer();
  39. void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
  40. // online
  41. void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
  42. // 2pass
  43. void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
  44. void InitHwCompiler(const std::string &hw_model, int thread_num);
  45. void InitSegDict(const std::string &seg_dict_model);
  46. std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
  47. void Reset();
  48. void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
  49. string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
  50. string GreedySearch( float* in, int n_len, int64_t token_nums,
  51. bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
  52. string Rescoring();
  53. string GetLang(){return language;};
  54. int GetAsrSampleRate() { return asr_sample_rate; };
  55. void StartUtterance();
  56. void EndUtterance();
  57. void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
  58. string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
  59. string FinalizeDecode(WfstDecoder* &wfst_decoder,
  60. bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
  61. Vocab* GetVocab();
  62. Vocab* GetLmVocab();
  63. PhoneSet* GetPhoneSet();
  64. knf::FbankOptions fbank_opts_;
  65. vector<float> means_list_;
  66. vector<float> vars_list_;
  67. int lfr_m = PARA_LFR_M;
  68. int lfr_n = PARA_LFR_N;
  69. // paraformer-offline
  70. std::shared_ptr<Ort::Session> m_session_ = nullptr;
  71. Ort::Env env_;
  72. Ort::SessionOptions session_options_;
  73. vector<string> m_strInputNames, m_strOutputNames;
  74. vector<const char*> m_szInputNames;
  75. vector<const char*> m_szOutputNames;
  76. std::string language="zh-cn";
  77. // paraformer-online
  78. std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
  79. std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
  80. vector<string> en_strInputNames, en_strOutputNames;
  81. vector<const char*> en_szInputNames_;
  82. vector<const char*> en_szOutputNames_;
  83. vector<string> de_strInputNames, de_strOutputNames;
  84. vector<const char*> de_szInputNames_;
  85. vector<const char*> de_szOutputNames_;
  86. // lm
  87. std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
  88. string window_type = "hamming";
  89. int frame_length = 25;
  90. int frame_shift = 10;
  91. int n_mels = 80;
  92. int encoder_size = 512;
  93. int fsmn_layers = 16;
  94. int fsmn_lorder = 10;
  95. int fsmn_dims = 512;
  96. float cif_threshold = 1.0;
  97. float tail_alphas = 0.45;
  98. int asr_sample_rate = MODEL_SAMPLE_RATE;
  99. };
  100. } // namespace funasr