lyblsgo 2 лет назад
Родитель
Commit
35a2bfffdf
26 измененных файлов с 417 добавлено и 443 удалено
  1. 17 18
      funasr/runtime/onnxruntime/include/audio.h
  2. 12 15
      funasr/runtime/onnxruntime/include/libfunasrapi.h
  3. 7 7
      funasr/runtime/onnxruntime/include/model.h
  4. 2 2
      funasr/runtime/onnxruntime/src/alignedmem.cpp
  5. 2 2
      funasr/runtime/onnxruntime/src/alignedmem.h
  6. 29 29
      funasr/runtime/onnxruntime/src/audio.cpp
  7. 6 6
      funasr/runtime/onnxruntime/src/commonfunc.h
  8. 13 13
      funasr/runtime/onnxruntime/src/ct-transformer.cpp
  9. 6 0
      funasr/runtime/onnxruntime/src/fsmn-vad.h
  10. 14 21
      funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
  11. 9 9
      funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
  12. 89 89
      funasr/runtime/onnxruntime/src/libfunasrapi.cpp
  13. 2 2
      funasr/runtime/onnxruntime/src/model.cpp
  14. 5 5
      funasr/runtime/onnxruntime/src/online-feature.cpp
  15. 4 4
      funasr/runtime/onnxruntime/src/online-feature.h
  16. 42 42
      funasr/runtime/onnxruntime/src/paraformer.cpp
  17. 53 0
      funasr/runtime/onnxruntime/src/paraformer.h
  18. 0 54
      funasr/runtime/onnxruntime/src/paraformer_onnx.h
  19. 1 1
      funasr/runtime/onnxruntime/src/precomp.h
  20. 1 1
      funasr/runtime/onnxruntime/src/tensor.h
  21. 46 46
      funasr/runtime/onnxruntime/src/tokenizer.cpp
  22. 15 15
      funasr/runtime/onnxruntime/src/tokenizer.h
  23. 14 14
      funasr/runtime/onnxruntime/src/util.cpp
  24. 13 13
      funasr/runtime/onnxruntime/src/util.h
  25. 9 29
      funasr/runtime/onnxruntime/src/vocab.cpp
  26. 6 6
      funasr/runtime/onnxruntime/src/vocab.h

+ 17 - 18
funasr/runtime/onnxruntime/include/audio.h

@@ -23,11 +23,11 @@ class AudioFrame {
     AudioFrame(int len);
 
     ~AudioFrame();
-    int set_start(int val);
-    int set_end(int val);
-    int get_start();
-    int get_len();
-    int disp();
+    int SetStart(int val);
+    int SetEnd(int val);
+    int GetStart();
+    int GetLen();
+    int Disp();
 };
 
 class Audio {
@@ -45,19 +45,18 @@ class Audio {
     Audio(int data_type);
     Audio(int data_type, int size);
     ~Audio();
-    void disp();
-    bool loadwav(const char* filename, int32_t* sampling_rate);
-    void wavResample(int32_t sampling_rate, const float *waveform, int32_t n);
-    bool loadwav(const char* buf, int nLen, int32_t* sampling_rate);
-    bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate);
-    bool loadpcmwav(const char* filename, int32_t* sampling_rate);
-    int fetch_chunck(float *&dout, int len);
-    int fetch(float *&dout, int &len, int &flag);
-    void padding();
-    void split(Model* pRecogObj);
-    float get_time_len();
-
-    int get_queue_size() { return (int)frame_queue.size(); }
+    void Disp();
+    bool LoadWav(const char* filename, int32_t* sampling_rate);
+    void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
+    bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
+    bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
+    bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
+    int FetchChunck(float *&dout, int len);
+    int Fetch(float *&dout, int &len, int &flag);
+    void Padding();
+    void Split(Model* recog_obj);
+    float GetTimeLen();
+    int GetQueueSize() { return (int)frame_queue.size(); }
 };
 
 #endif

+ 12 - 15
funasr/runtime/onnxruntime/include/libfunasrapi.h

@@ -35,7 +35,6 @@ typedef enum
  RASRM_CTC_GREEDY_SEARCH=0,
  RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
  RASRM_ATTENSION_RESCORING = 2,
- 
 }FUNASR_MODE;
 
 typedef enum {
@@ -43,33 +42,31 @@ typedef enum {
 	FUNASR_MODEL_PADDLE_2 = 1,
 	FUNASR_MODEL_K2 = 2,
 	FUNASR_MODEL_PARAFORMER = 3,
-
 }FUNASR_MODEL_TYPE;
 
-typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
+typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
 	
 // APIs for funasr
-_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false, bool use_punc=false);
-
+_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* sz_model_dir, int thread_num, bool quantize=false, bool use_vad=false, bool use_punc=false);
 
-// if not give a fnCallback ,it should be NULL 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
+// if not give a fn_callback ,it should be NULL 
+_FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT Result,int nIndex);
+_FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
 
-_FUNASRAPI const int		FunASRGetRetNumber(FUNASR_RESULT Result);
+_FUNASRAPI const int		FunASRGetRetNumber(FUNASR_RESULT result);
 
-_FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT Result);
+_FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
 
-_FUNASRAPI void			FunASRUninit(FUNASR_HANDLE Handle);
+_FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
 
-_FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT Result);
+_FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT result);
 
 #ifdef __cplusplus 
 

+ 7 - 7
funasr/runtime/onnxruntime/include/model.h

@@ -7,13 +7,13 @@
 class Model {
   public:
     virtual ~Model(){};
-    virtual void reset() = 0;
-    virtual std::string forward_chunk(float *din, int len, int flag) = 0;
-    virtual std::string forward(float *din, int len, int flag) = 0;
-    virtual std::string rescoring() = 0;
-    virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0;
-    virtual std::string AddPunc(const char* szInput)=0;
+    virtual void Reset() = 0;
+    virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
+    virtual std::string Forward(float *din, int len, int flag) = 0;
+    virtual std::string Rescoring() = 0;
+    virtual std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data)=0;
+    virtual std::string AddPunc(const char* sz_input)=0;
 };
 
-Model *CreateModel(const char *path,int nThread=0,bool quantize=false, bool use_vad=false, bool use_punc=false);
+Model *CreateModel(const char *path,int thread_num=1,bool quantize=false, bool use_vad=false, bool use_punc=false);
 #endif

+ 2 - 2
funasr/runtime/onnxruntime/src/alignedmem.cpp

@@ -1,5 +1,5 @@
 #include "precomp.h"
-void *aligned_malloc(size_t alignment, size_t required_bytes)
+void *AlignedMalloc(size_t alignment, size_t required_bytes)
 {
     void *p1;  // original block
     void **p2; // aligned block
@@ -12,7 +12,7 @@ void *aligned_malloc(size_t alignment, size_t required_bytes)
     return p2;
 }
 
-void aligned_free(void *p)
+void AlignedFree(void *p)
 {
     free(((void **)p)[-1]);
 }

+ 2 - 2
funasr/runtime/onnxruntime/src/alignedmem.h

@@ -4,7 +4,7 @@
 
 
 
-extern void *aligned_malloc(size_t alignment, size_t required_bytes);
-extern void aligned_free(void *p);
+extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
+extern void AlignedFree(void *p);
 
 #endif

+ 29 - 29
funasr/runtime/onnxruntime/src/audio.cpp

@@ -128,30 +128,30 @@ AudioFrame::AudioFrame(int len) : len(len)
     start = 0;
 };
 AudioFrame::~AudioFrame(){};
-int AudioFrame::set_start(int val)
+int AudioFrame::SetStart(int val)
 {
     start = val < 0 ? 0 : val;
     return start;
 };
 
-int AudioFrame::set_end(int val)
+int AudioFrame::SetEnd(int val)
 {
     end = val;
     len = end - start;
     return end;
 };
 
-int AudioFrame::get_start()
+int AudioFrame::GetStart()
 {
     return start;
 };
 
-int AudioFrame::get_len()
+int AudioFrame::GetLen()
 {
     return len;
 };
 
-int AudioFrame::disp()
+int AudioFrame::Disp()
 {
     printf("not imp!!!!\n");
 
@@ -185,18 +185,18 @@ Audio::~Audio()
     }
 }
 
-void Audio::disp()
+void Audio::Disp()
 {
     printf("Audio time is %f s. len is %d\n", (float)speech_len / MODEL_SAMPLE_RATE,
            speech_len);
 }
 
-float Audio::get_time_len()
+float Audio::GetTimeLen()
 {
     return (float)speech_len / MODEL_SAMPLE_RATE;
 }
 
-void Audio::wavResample(int32_t sampling_rate, const float *waveform,
+void Audio::WavResample(int32_t sampling_rate, const float *waveform,
                           int32_t n)
 {
     printf(
@@ -226,7 +226,7 @@ void Audio::wavResample(int32_t sampling_rate, const float *waveform,
     copy(samples.begin(), samples.end(), speech_data);
 }
 
-bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
+bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
 {
     WaveHeader header;
     if (speech_data != NULL) {
@@ -271,7 +271,7 @@ bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
 
         //resample
         if(*sampling_rate != MODEL_SAMPLE_RATE){
-            wavResample(*sampling_rate, speech_data, speech_len);
+            WavResample(*sampling_rate, speech_data, speech_len);
         }
 
         AudioFrame* frame = new AudioFrame(speech_len);
@@ -283,7 +283,7 @@ bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
         return false;
 }
 
-bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
+bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
 {
     WaveHeader header;
     if (speech_data != NULL) {
@@ -318,7 +318,7 @@ bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
         
         //resample
         if(*sampling_rate != MODEL_SAMPLE_RATE){
-            wavResample(*sampling_rate, speech_data, speech_len);
+            WavResample(*sampling_rate, speech_data, speech_len);
         }
 
         AudioFrame* frame = new AudioFrame(speech_len);
@@ -330,7 +330,7 @@ bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
         return false;
 }
 
-bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
+bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
 {
     if (speech_data != NULL) {
         free(speech_data);
@@ -340,7 +340,7 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
     }
     offset = 0;
 
-    speech_len = nBufLen / 2;
+    speech_len = n_buf_len / 2;
     speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
     if (speech_buff)
     {
@@ -361,7 +361,7 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
         
         //resample
         if(*sampling_rate != MODEL_SAMPLE_RATE){
-            wavResample(*sampling_rate, speech_data, speech_len);
+            WavResample(*sampling_rate, speech_data, speech_len);
         }
 
         AudioFrame* frame = new AudioFrame(speech_len);
@@ -373,7 +373,7 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
         return false;
 }
 
-bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
+bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
 {
     if (speech_data != NULL) {
         free(speech_data);
@@ -388,10 +388,10 @@ bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
     if (fp == nullptr)
         return false;
     fseek(fp, 0, SEEK_END);
-    uint32_t nFileLen = ftell(fp);
+    uint32_t n_file_len = ftell(fp);
     fseek(fp, 0, SEEK_SET);
 
-    speech_len = (nFileLen) / 2;
+    speech_len = (n_file_len) / 2;
     speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
     if (speech_buff)
     {
@@ -412,7 +412,7 @@ bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
 
         //resample
         if(*sampling_rate != MODEL_SAMPLE_RATE){
-            wavResample(*sampling_rate, speech_data, speech_len);
+            WavResample(*sampling_rate, speech_data, speech_len);
         }
 
         AudioFrame* frame = new AudioFrame(speech_len);
@@ -425,7 +425,7 @@ bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
 
 }
 
-int Audio::fetch_chunck(float *&dout, int len)
+int Audio::FetchChunck(float *&dout, int len)
 {
     if (offset >= speech_align_len) {
         dout = NULL;
@@ -446,14 +446,14 @@ int Audio::fetch_chunck(float *&dout, int len)
     }
 }
 
-int Audio::fetch(float *&dout, int &len, int &flag)
+int Audio::Fetch(float *&dout, int &len, int &flag)
 {
     if (frame_queue.size() > 0) {
         AudioFrame *frame = frame_queue.front();
         frame_queue.pop();
 
-        dout = speech_data + frame->get_start();
-        len = frame->get_len();
+        dout = speech_data + frame->GetStart();
+        len = frame->GetLen();
         delete frame;
         flag = S_END;
         return 1;
@@ -462,7 +462,7 @@ int Audio::fetch(float *&dout, int &len, int &flag)
     }
 }
 
-void Audio::padding()
+void Audio::Padding()
 {
     float num_samples = speech_len;
     float frame_length = 400;
@@ -499,26 +499,26 @@ void Audio::padding()
     delete frame;
 }
 
-void Audio::split(Model* pRecogObj)
+void Audio::Split(Model* recog_obj)
 {
     AudioFrame *frame;
 
     frame = frame_queue.front();
     frame_queue.pop();
-    int sp_len = frame->get_len();
+    int sp_len = frame->GetLen();
     delete frame;
     frame = NULL;
 
     std::vector<float> pcm_data(speech_data, speech_data+sp_len);
-    vector<std::vector<int>> vad_segments = pRecogObj->vad_seg(pcm_data);
+    vector<std::vector<int>> vad_segments = recog_obj->VadSeg(pcm_data);
     int seg_sample = MODEL_SAMPLE_RATE/1000;
     for(vector<int> segment:vad_segments)
     {
         frame = new AudioFrame();
         int start = segment[0]*seg_sample;
         int end = segment[1]*seg_sample;
-        frame->set_start(start);
-        frame->set_end(end);
+        frame->SetStart(start);
+        frame->SetEnd(end);
         frame_queue.push(frame);
         frame = NULL;
     }

+ 6 - 6
funasr/runtime/onnxruntime/src/commonfunc.h

@@ -10,23 +10,23 @@ typedef struct
 #ifdef _WIN32
 #include <codecvt>
 
-inline std::wstring string2wstring(const std::string& str, const std::string& locale)
+inline std::wstring String2wstring(const std::string& str, const std::string& locale)
 {
     typedef std::codecvt_byname<wchar_t, char, std::mbstate_t> F;
     std::wstring_convert<F> strCnv(new F(locale));
     return strCnv.from_bytes(str);
 }
 
-inline std::wstring  strToWstr(std::string str) {
+inline std::wstring  StrToWstr(std::string str) {
     if (str.length() == 0)
         return L"";
-    return  string2wstring(str, "zh-CN");
+    return  String2wstring(str, "zh-CN");
 
 }
 
 #endif
 
-inline void getInputName(Ort::Session* session, string& inputName,int nIndex=0) {
+inline void GetInputName(Ort::Session* session, string& inputName,int nIndex=0) {
     size_t numInputNodes = session->GetInputCount();
     if (numInputNodes > 0) {
         Ort::AllocatorWithDefaultOptions allocator;
@@ -38,7 +38,7 @@ inline void getInputName(Ort::Session* session, string& inputName,int nIndex=0)
     }
 }
 
-inline void getOutputName(Ort::Session* session, string& outputName, int nIndex = 0) {
+inline void GetOutputName(Ort::Session* session, string& outputName, int nIndex = 0) {
     size_t numOutputNodes = session->GetOutputCount();
     if (numOutputNodes > 0) {
         Ort::AllocatorWithDefaultOptions allocator;
@@ -51,6 +51,6 @@ inline void getOutputName(Ort::Session* session, string& outputName, int nIndex
 }
 
 template <class ForwardIterator>
-inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
+inline static size_t Argmax(ForwardIterator first, ForwardIterator last) {
     return std::distance(first, std::max_element(first, last));
 }

+ 13 - 13
funasr/runtime/onnxruntime/src/ct-transformer.cpp

@@ -7,8 +7,8 @@ CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
     session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
     session_options.DisableCpuMemArena();
 
-	string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
-	string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
+	string strModelPath = PathAppend(sz_model_dir, PUNC_MODEL_FILE);
+	string strYamlPath = PathAppend(sz_model_dir, PUNC_YAML_FILE);
 
     try{
 #ifdef _WIN32
@@ -24,12 +24,12 @@ CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
     }
     // read inputnames outputnamess
     string strName;
-    getInputName(m_session.get(), strName);
+    GetInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
-    getInputName(m_session.get(), strName, 1);
+    GetInputName(m_session.get(), strName, 1);
     m_strInputNames.push_back(strName);
     
-    getOutputName(m_session.get(), strName);
+    GetOutputName(m_session.get(), strName);
     m_strOutputNames.push_back(strName);
 
     for (auto& item : m_strInputNames)
@@ -77,12 +77,12 @@ string CTTransformer::AddPunc(const char* sz_input)
             nLastCommaIndex = -1;
             for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
             {
-                if (m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(PERIOD_INDEX) || m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(QUESTION_INDEX))
+                if (m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(PERIOD_INDEX) || m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(QUESTION_INDEX))
                 {
                     nSentEnd = nIndex;
                     break;
                 }
-                if (nLastCommaIndex < 0 && m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(COMMA_INDEX))
+                if (nLastCommaIndex < 0 && m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(COMMA_INDEX))
                 {
                     nLastCommaIndex = nIndex;
                 }
@@ -110,7 +110,7 @@ string CTTransformer::AddPunc(const char* sz_input)
 
             if (Punction[i] != NOTPUNC_INDEX) // �»���
             {
-                WordWithPunc.push_back(m_tokenizer.ID2Punc(Punction[i]));
+                WordWithPunc.push_back(m_tokenizer.Id2Punc(Punction[i]));
             }
         }
 
@@ -120,17 +120,17 @@ string CTTransformer::AddPunc(const char* sz_input)
         // last mini sentence
         if(nCurBatch == nTotalBatch - 1)
         {
-            if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(DUN_INDEX))
+            if (NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(DUN_INDEX))
             {
                 NewSentenceOut.assign(NewString.begin(), NewString.end() - 1);
-                NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
+                NewSentenceOut.push_back(m_tokenizer.Id2Punc(PERIOD_INDEX));
                 NewPuncOut.assign(NewPunctuation.begin(), NewPunctuation.end() - 1);
                 NewPuncOut.push_back(PERIOD_INDEX);
             }
-            else if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(QUESTION_INDEX))
+            else if (NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.Id2Punc(QUESTION_INDEX))
             {
                 NewSentenceOut = NewString;
-                NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
+                NewSentenceOut.push_back(m_tokenizer.Id2Punc(PERIOD_INDEX));
                 NewPuncOut = NewPunctuation;
                 NewPuncOut.push_back(PERIOD_INDEX);
             }
@@ -173,7 +173,7 @@ vector<int> CTTransformer::Infer(vector<int64_t> input_data)
 
         for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
         {
-            int index = argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
+            int index = Argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
             punction.push_back(index);
         }
     }

+ 6 - 0
funasr/runtime/onnxruntime/src/fsmn-vad.h

@@ -5,6 +5,12 @@
 #include "precomp.h"
 
 class FsmnVad {
+/**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ * https://arxiv.org/abs/1803.05030
+*/
+
 public:
     FsmnVad();
     void Test();

+ 14 - 21
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp

@@ -19,15 +19,8 @@ using namespace std;
 std::atomic<int> index(0);
 std::mutex mtx;
 
-void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list, 
+void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, 
             float* total_length, long* total_time, int core_id) {
-
-    // cpu_set_t cpuset;
-    // CPU_ZERO(&cpuset);
-    // CPU_SET(core_id, &cpuset);
-    // if(pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset) < 0){
-    //     perror("pthread_setaffinity_np");
-    // }
     
     struct timeval start, end;
     long seconds = 0;
@@ -37,7 +30,7 @@ void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list,
     // warm up
     for (size_t i = 0; i < 1; i++)
     {
-        FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[0].c_str(), RASR_NONE, NULL);
+        FUNASR_RESULT result=FunASRRecogFile(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL);
     }
 
     while (true) {
@@ -48,20 +41,20 @@ void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list,
         }
 
         gettimeofday(&start, NULL);
-        FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[i].c_str(), RASR_NONE, NULL);
+        FUNASR_RESULT result=FunASRRecogFile(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL);
 
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
         n_total_time += taking_micros;
 
-        if(Result){
-            string msg = FunASRGetResult(Result, 0);
+        if(result){
+            string msg = FunASRGetResult(result, 0);
             printf("Thread: %d Result: %s \n", this_thread::get_id(), msg.c_str());
 
-            float snippet_time = FunASRGetRetSnippetTime(Result);
+            float snippet_time = FunASRGetRetSnippetTime(result);
             n_total_length += snippet_time;
-            FunASRFreeResult(Result);
+            FunASRFreeResult(result);
         }else{
             cout <<"No return data!";
         }
@@ -109,11 +102,11 @@ int main(int argc, char *argv[])
     bool quantize = false;
     istringstream(argv[3]) >> boolalpha >> quantize;
     // thread num
-    int nThreadNum = 1;
-    nThreadNum = atoi(argv[4]);
+    int thread_num = 1;
+    thread_num = atoi(argv[4]);
 
-    FUNASR_HANDLE AsrHandle=FunASRInit(argv[1], 1, quantize);
-    if (!AsrHandle)
+    FUNASR_HANDLE asr_handle=FunASRInit(argv[1], 1, quantize);
+    if (!asr_handle)
     {
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
         exit(-1);
@@ -128,9 +121,9 @@ int main(int argc, char *argv[])
     long total_time = 0;
     std::vector<std::thread> threads;
 
-    for (int i = 0; i < nThreadNum; i++)
+    for (int i = 0; i < thread_num; i++)
     {
-        threads.emplace_back(thread(runReg, AsrHandle, wav_list, &total_length, &total_time, i));
+        threads.emplace_back(thread(runReg, asr_handle, wav_list, &total_length, &total_time, i));
     }
 
     for (auto& thread : threads)
@@ -143,6 +136,6 @@ int main(int argc, char *argv[])
     printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
     printf("speedup %05lf .\n", 1.0/((double)total_time/ (total_length*1000000)));
 
-    FunASRUninit(AsrHandle);
+    FunASRUninit(asr_handle);
     return 0;
 }

+ 9 - 9
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp

@@ -18,7 +18,7 @@ int main(int argc, char *argv[])
     }
     struct timeval start, end;
     gettimeofday(&start, NULL);
-    int nThreadNum = 1;
+    int thread_num = 1;
     // is quantize
     bool quantize = false;
     bool use_vad = false;
@@ -26,9 +26,9 @@ int main(int argc, char *argv[])
     istringstream(argv[3]) >> boolalpha >> quantize;
     istringstream(argv[4]) >> boolalpha >> use_vad;
     istringstream(argv[5]) >> boolalpha >> use_punc;
-    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad, use_punc);
+    FUNASR_HANDLE asr_hanlde=FunASRInit(argv[1], thread_num, quantize, use_vad, use_punc);
 
-    if (!AsrHanlde)
+    if (!asr_hanlde)
     {
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
         exit(-1);
@@ -40,17 +40,17 @@ int main(int argc, char *argv[])
     printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
 
     gettimeofday(&start, NULL);
-    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
+    FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
     gettimeofday(&end, NULL);
 
     float snippet_time = 0.0f;
-    if (Result)
+    if (result)
     {
-        string msg = FunASRGetResult(Result, 0);
+        string msg = FunASRGetResult(result, 0);
         setbuf(stdout, NULL);
         printf("Result: %s \n", msg.c_str());
-        snippet_time = FunASRGetRetSnippetTime(Result);
-        FunASRFreeResult(Result);
+        snippet_time = FunASRGetRetSnippetTime(result);
+        FunASRFreeResult(result);
     }
     else
     {
@@ -63,7 +63,7 @@ int main(int argc, char *argv[])
     printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
     printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
 
-    FunASRUninit(AsrHanlde);
+    FunASRUninit(asr_hanlde);
 
     return 0;
 }

+ 89 - 89
funasr/runtime/onnxruntime/src/libfunasrapi.cpp

@@ -5,196 +5,196 @@ extern "C" {
 #endif
 
 	// APIs for funasr
-	_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad, bool use_punc)
+	_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* sz_model_dir, int thread_num, bool quantize, bool use_vad, bool use_punc)
 	{
-		Model* mm = CreateModel(szModelDir, nThreadNum, quantize, use_vad, use_punc);
+		Model* mm = CreateModel(sz_model_dir, thread_num, quantize, use_vad, use_punc);
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
 	{
-		Model* pRecogObj = (Model*)handle;
-		if (!pRecogObj)
+		Model* recog_obj = (Model*)handle;
+		if (!recog_obj)
 			return nullptr;
 
 		int32_t sampling_rate = -1;
 		Audio audio(1);
-		if (!audio.loadwav(szBuf, nLen, &sampling_rate))
+		if (!audio.LoadWav(sz_buf, n_len, &sampling_rate))
 			return nullptr;
 		if(use_vad){
-			audio.split(pRecogObj);
+			audio.Split(recog_obj);
 		}
 
 		float* buff;
 		int len;
 		int flag=0;
-		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
-		pResult->snippet_time = audio.get_time_len();
-		int nStep = 0;
-		int nTotal = audio.get_queue_size();
-		while (audio.fetch(buff, len, flag) > 0) {
-			string msg = pRecogObj->forward(buff, len, flag);
-			pResult->msg += msg;
-			nStep++;
-			if (fnCallback)
-				fnCallback(nStep, nTotal);
+		FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
+		p_result->snippet_time = audio.GetTimeLen();
+		int n_step = 0;
+		int n_total = audio.GetQueueSize();
+		while (audio.Fetch(buff, len, flag) > 0) {
+			string msg = recog_obj->Forward(buff, len, flag);
+			p_result->msg += msg;
+			n_step++;
+			if (fn_callback)
+				fn_callback(n_step, n_total);
 		}
 		if(use_punc){
-			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
-			pResult->msg = punc_res;
+			string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
+			p_result->msg = punc_res;
 		}
 
-		return pResult;
+		return p_result;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
 	{
-		Model* pRecogObj = (Model*)handle;
-		if (!pRecogObj)
+		Model* recog_obj = (Model*)handle;
+		if (!recog_obj)
 			return nullptr;
 
 		Audio audio(1);
-		if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
+		if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
 			return nullptr;
 		if(use_vad){
-			audio.split(pRecogObj);
+			audio.Split(recog_obj);
 		}
 
 		float* buff;
 		int len;
 		int flag = 0;
-		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
-		pResult->snippet_time = audio.get_time_len();
-		int nStep = 0;
-		int nTotal = audio.get_queue_size();
-		while (audio.fetch(buff, len, flag) > 0) {
-			string msg = pRecogObj->forward(buff, len, flag);
-			pResult->msg += msg;
-			nStep++;
-			if (fnCallback)
-				fnCallback(nStep, nTotal);
+		FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
+		p_result->snippet_time = audio.GetTimeLen();
+		int n_step = 0;
+		int n_total = audio.GetQueueSize();
+		while (audio.Fetch(buff, len, flag) > 0) {
+			string msg = recog_obj->Forward(buff, len, flag);
+			p_result->msg += msg;
+			n_step++;
+			if (fn_callback)
+				fn_callback(n_step, n_total);
 		}
 		if(use_punc){
-			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
-			pResult->msg = punc_res;
+			string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
+			p_result->msg = punc_res;
 		}
 
-		return pResult;
+		return p_result;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
 	{
-		Model* pRecogObj = (Model*)handle;
-		if (!pRecogObj)
+		Model* recog_obj = (Model*)handle;
+		if (!recog_obj)
 			return nullptr;
 
 		Audio audio(1);
-		if (!audio.loadpcmwav(szFileName, &sampling_rate))
+		if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
 			return nullptr;
 		if(use_vad){
-			audio.split(pRecogObj);
+			audio.Split(recog_obj);
 		}
 
 		float* buff;
 		int len;
 		int flag = 0;
-		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
-		pResult->snippet_time = audio.get_time_len();
-		int nStep = 0;
-		int nTotal = audio.get_queue_size();
-		while (audio.fetch(buff, len, flag) > 0) {
-			string msg = pRecogObj->forward(buff, len, flag);
-			pResult->msg += msg;
-			nStep++;
-			if (fnCallback)
-				fnCallback(nStep, nTotal);
+		FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
+		p_result->snippet_time = audio.GetTimeLen();
+		int n_step = 0;
+		int n_total = audio.GetQueueSize();
+		while (audio.Fetch(buff, len, flag) > 0) {
+			string msg = recog_obj->Forward(buff, len, flag);
+			p_result->msg += msg;
+			n_step++;
+			if (fn_callback)
+				fn_callback(n_step, n_total);
 		}
 		if(use_punc){
-			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
-			pResult->msg = punc_res;
+			string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
+			p_result->msg = punc_res;
 		}
 
-		return pResult;
+		return p_result;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
 	{
-		Model* pRecogObj = (Model*)handle;
-		if (!pRecogObj)
+		Model* recog_obj = (Model*)handle;
+		if (!recog_obj)
 			return nullptr;
 		
 		int32_t sampling_rate = -1;
 		Audio audio(1);
-		if(!audio.loadwav(szWavfile, &sampling_rate))
+		if(!audio.LoadWav(sz_wavfile, &sampling_rate))
 			return nullptr;
 		if(use_vad){
-			audio.split(pRecogObj);
+			audio.Split(recog_obj);
 		}
 
 		float* buff;
 		int len;
 		int flag = 0;
-		int nStep = 0;
-		int nTotal = audio.get_queue_size();
-		FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
-		pResult->snippet_time = audio.get_time_len();
-		while (audio.fetch(buff, len, flag) > 0) {
-			string msg = pRecogObj->forward(buff, len, flag);
-			pResult->msg+= msg;
-			nStep++;
-			if (fnCallback)
-				fnCallback(nStep, nTotal);
+		int n_step = 0;
+		int n_total = audio.GetQueueSize();
+		FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
+		p_result->snippet_time = audio.GetTimeLen();
+		while (audio.Fetch(buff, len, flag) > 0) {
+			string msg = recog_obj->Forward(buff, len, flag);
+			p_result->msg+= msg;
+			n_step++;
+			if (fn_callback)
+				fn_callback(n_step, n_total);
 		}
 		if(use_punc){
-			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
-			pResult->msg = punc_res;
+			string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
+			p_result->msg = punc_res;
 		}
 	
-		return pResult;
+		return p_result;
 	}
 
-	_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result)
+	_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
 	{
-		if (!Result)
+		if (!result)
 			return 0;
 
 		return 1;
 	}
 
 
-	_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result)
+	_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result)
 	{
-		if (!Result)
+		if (!result)
 			return 0.0f;
 
-		return ((FUNASR_RECOG_RESULT*)Result)->snippet_time;
+		return ((FUNASR_RECOG_RESULT*)result)->snippet_time;
 	}
 
-	_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex)
+	_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index)
 	{
-		FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result;
-		if(!pResult)
+		FUNASR_RECOG_RESULT * p_result = (FUNASR_RECOG_RESULT*)result;
+		if(!p_result)
 			return nullptr;
 
-		return pResult->msg.c_str();
+		return p_result->msg.c_str();
 	}
 
-	_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result)
+	_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result)
 	{
-		if (Result)
+		if (result)
 		{
-			delete (FUNASR_RECOG_RESULT*)Result;
+			delete (FUNASR_RECOG_RESULT*)result;
 		}
 	}
 
 	_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
 	{
-		Model* pRecogObj = (Model*)handle;
+		Model* recog_obj = (Model*)handle;
 
-		if (!pRecogObj)
+		if (!recog_obj)
 			return;
 
-		delete pRecogObj;
+		delete recog_obj;
 	}
 
 #ifdef __cplusplus 

+ 2 - 2
funasr/runtime/onnxruntime/src/model.cpp

@@ -1,10 +1,10 @@
 #include "precomp.h"
 
-Model *CreateModel(const char *path, int nThread, bool quantize, bool use_vad, bool use_punc)
+Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
 {
     Model *mm;
 
-    mm = new paraformer::ModelImp(path, nThread, quantize, use_vad, use_punc);
+    mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
 
     return mm;
 }

+ 5 - 5
funasr/runtime/onnxruntime/src/online-feature.cpp

@@ -13,10 +13,10 @@ OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int
   frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
 }
 
-void OnlineFeature::extractFeats(vector<std::vector<float>> &vad_feats,
+void OnlineFeature::ExtractFeats(vector<std::vector<float>> &vad_feats,
                                  vector<float> waves, bool input_finished) {
   input_finished_ = input_finished;
-  onlineFbank(vad_feats, waves);
+  OnlineFbank(vad_feats, waves);
   // cache deal & online lfr,cmvn
   if (vad_feats.size() > 0) {
     if (!reserve_waveforms_.empty()) {
@@ -53,7 +53,7 @@ void OnlineFeature::extractFeats(vector<std::vector<float>> &vad_feats,
       }
       vad_feats = lfr_splice_cache_;
       OnlineLfrCmvn(vad_feats);
-      reset_cache();
+      ResetCache();
     }
   }
 
@@ -102,13 +102,13 @@ int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
   return lfr_splice_frame_idxs;
 }
 
-void OnlineFeature::onlineFbank(vector<std::vector<float>> &vad_feats,
+void OnlineFeature::OnlineFbank(vector<std::vector<float>> &vad_feats,
                                 vector<float> &waves) {
 
   knf::OnlineFbank fbank(fbank_opts_);
   // cache merge
   waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
-  int frame_number = compute_frame_num(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+  int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
   // Send the audio after the last frame shift position to the cache
   input_cache_.clear();
   input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());

+ 4 - 4
funasr/runtime/onnxruntime/src/online-feature.h

@@ -10,15 +10,15 @@ public:
   OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_,
                 std::vector<std::vector<float>> cmvns_);
 
-  void extractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
+  void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
 
 
 private:
-  void onlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
+  void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
 
   int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
 
-  static int compute_frame_num(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
+  static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
     int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
 
     if (frame_num >= 1 && sample_length >= frame_sample_length)
@@ -27,7 +27,7 @@ private:
       return 0;
   }
 
-  void reset_cache() {
+  void ResetCache() {
     reserve_waveforms_.clear();
     input_cache_.clear();
     lfr_splice_cache_.clear();

+ 42 - 42
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp → funasr/runtime/onnxruntime/src/paraformer.cpp

@@ -3,33 +3,33 @@
 using namespace std;
 using namespace paraformer;
 
-ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad, bool use_punc)
-:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
+Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
     string model_path;
     string cmvn_path;
     string config_path;
 
     // VAD model
     if(use_vad){
-        string vad_path = pathAppend(path, "vad_model.onnx");
-        string mvn_path = pathAppend(path, "vad.mvn");
-        vadHandle = make_unique<FsmnVad>();
-        vadHandle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
+        string vad_path = PathAppend(path, "vad_model.onnx");
+        string mvn_path = PathAppend(path, "vad.mvn");
+        vad_handle = make_unique<FsmnVad>();
+        vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
     }
 
     // PUNC model
     if(use_punc){
-        puncHandle = make_unique<CTTransformer>(path, nNumThread);
+        punc_handle = make_unique<CTTransformer>(path, thread_num);
     }
 
     if(quantize)
     {
-        model_path = pathAppend(path, "model_quant.onnx");
+        model_path = PathAppend(path, "model_quant.onnx");
     }else{
-        model_path = pathAppend(path, "model.onnx");
+        model_path = PathAppend(path, "model.onnx");
     }
-    cmvn_path = pathAppend(path, "am.mvn");
-    config_path = pathAppend(path, "config.yaml");
+    cmvn_path = PathAppend(path, "am.mvn");
+    config_path = PathAppend(path, "config.yaml");
 
     // knf options
     fbank_opts.frame_opts.dither = 0;
@@ -42,28 +42,28 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad,
     fbank_opts.mel_opts.debug_mel = false;
     // fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
 
-    // sessionOptions.SetInterOpNumThreads(1);
-    sessionOptions.SetIntraOpNumThreads(nNumThread);
-    sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+    // session_options.SetInterOpNumThreads(1);
+    session_options.SetIntraOpNumThreads(thread_num);
+    session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
     // DisableCpuMemArena can improve performance
-    sessionOptions.DisableCpuMemArena();
+    session_options.DisableCpuMemArena();
 
 #ifdef _WIN32
     wstring wstrPath = strToWstr(model_path);
-    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
 #else
-    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
 #endif
 
     string strName;
-    getInputName(m_session.get(), strName);
+    GetInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
-    getInputName(m_session.get(), strName,1);
+    GetInputName(m_session.get(), strName,1);
     m_strInputNames.push_back(strName);
     
-    getOutputName(m_session.get(), strName);
+    GetOutputName(m_session.get(), strName);
     m_strOutputNames.push_back(strName);
-    getOutputName(m_session.get(), strName,1);
+    GetOutputName(m_session.get(), strName,1);
     m_strOutputNames.push_back(strName);
 
     for (auto& item : m_strInputNames)
@@ -71,28 +71,28 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad,
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
     vocab = new Vocab(config_path.c_str());
-    load_cmvn(cmvn_path.c_str());
+    LoadCmvn(cmvn_path.c_str());
 }
 
-ModelImp::~ModelImp()
+Paraformer::~Paraformer()
 {
     if(vocab)
         delete vocab;
 }
 
-void ModelImp::reset()
+void Paraformer::Reset()
 {
 }
 
-vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
-    return vadHandle->Infer(pcm_data);
+vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
+    return vad_handle->Infer(pcm_data);
 }
 
-string ModelImp::AddPunc(const char* szInput){
-    return puncHandle->AddPunc(szInput);
+string Paraformer::AddPunc(const char* sz_input){
+    return punc_handle->AddPunc(sz_input);
 }
 
-vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
+vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
     knf::OnlineFbank fbank_(fbank_opts);
     fbank_.AcceptWaveform(sample_rate, waves, len);
     //fbank_->InputFinished();
@@ -110,7 +110,7 @@ vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int le
     return features;
 }
 
-void ModelImp::load_cmvn(const char *filename)
+void Paraformer::LoadCmvn(const char *filename)
 {
     ifstream cmvn_stream(filename);
     string line;
@@ -143,21 +143,21 @@ void ModelImp::load_cmvn(const char *filename)
     }
 }
 
-string ModelImp::greedy_search(float * in, int nLen )
+string Paraformer::GreedySearch(float * in, int n_len )
 {
     vector<int> hyps;
-    int Tmax = nLen;
+    int Tmax = n_len;
     for (int i = 0; i < Tmax; i++) {
         int max_idx;
         float max_val;
-        findmax(in + i * 8404, 8404, max_val, max_idx);
+        FindMax(in + i * 8404, 8404, max_val, max_idx);
         hyps.push_back(max_idx);
     }
 
-    return vocab->vector2stringV2(hyps);
+    return vocab->Vector2StringV2(hyps);
 }
 
-vector<float> ModelImp::ApplyLFR(const std::vector<float> &in) 
+vector<float> Paraformer::ApplyLfr(const std::vector<float> &in) 
 {
     int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
     int32_t in_num_frames = in.size() / in_feat_dim;
@@ -180,7 +180,7 @@ vector<float> ModelImp::ApplyLFR(const std::vector<float> &in)
     return out;
   }
 
-  void ModelImp::ApplyCMVN(std::vector<float> *v)
+  void Paraformer::ApplyCmvn(std::vector<float> *v)
   {
     int32_t dim = means_list.size();
     int32_t num_frames = v->size() / dim;
@@ -196,13 +196,13 @@ vector<float> ModelImp::ApplyLFR(const std::vector<float> &in)
     }
   }
 
-string ModelImp::forward(float* din, int len, int flag)
+string Paraformer::Forward(float* din, int len, int flag)
 {
 
     int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
     std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
-    wav_feats = ApplyLFR(wav_feats);
-    ApplyCMVN(&wav_feats);
+    wav_feats = ApplyLfr(wav_feats);
+    ApplyCmvn(&wav_feats);
 
     int32_t feat_dim = lfr_window_size*in_feat_dim;
     int32_t num_frames = wav_feats.size() / feat_dim;
@@ -238,7 +238,7 @@ string ModelImp::forward(float* din, int len, int flag)
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
-        result = greedy_search(floatData, *encoder_out_lens);
+        result = GreedySearch(floatData, *encoder_out_lens);
     }
     catch (std::exception const &e)
     {
@@ -248,14 +248,14 @@ string ModelImp::forward(float* din, int len, int flag)
     return result;
 }
 
-string ModelImp::forward_chunk(float* din, int len, int flag)
+string Paraformer::ForwardChunk(float* din, int len, int flag)
 {
 
     printf("Not Imp!!!!!!\n");
     return "Hello";
 }
 
-string ModelImp::rescoring()
+string Paraformer::Rescoring()
 {
     printf("Not Imp!!!!!!\n");
     return "Hello";

+ 53 - 0
funasr/runtime/onnxruntime/src/paraformer.h

@@ -0,0 +1,53 @@
+#pragma once
+
+
+#ifndef PARAFORMER_MODELIMP_H
+#define PARAFORMER_MODELIMP_H
+
+#include "precomp.h"
+
+namespace paraformer {
+
+    class Paraformer : public Model {
+    private:
+        //std::unique_ptr<knf::OnlineFbank> fbank_;
+        knf::FbankOptions fbank_opts;
+
+        std::unique_ptr<FsmnVad> vad_handle;
+        std::unique_ptr<CTTransformer> punc_handle;
+
+        Vocab* vocab;
+        vector<float> means_list;
+        vector<float> vars_list;
+        const float scale = 22.6274169979695;
+        int32_t lfr_window_size = 7;
+        int32_t lfr_window_shift = 6;
+
+        void LoadCmvn(const char *filename);
+        vector<float> ApplyLfr(const vector<float> &in);
+        void ApplyCmvn(vector<float> *v);
+
+        string GreedySearch( float* in, int n_len);
+
+        std::shared_ptr<Ort::Session> m_session;
+        Ort::Env env_;
+        Ort::SessionOptions session_options;
+
+        vector<string> m_strInputNames, m_strOutputNames;
+        vector<const char*> m_szInputNames;
+        vector<const char*> m_szOutputNames;
+
+    public:
+        Paraformer(const char* path, int thread_num=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
+        ~Paraformer();
+        void Reset();
+        vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
+        string ForwardChunk(float* din, int len, int flag);
+        string Forward(float* din, int len, int flag);
+        string Rescoring();
+        std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data);
+        string AddPunc(const char* sz_input);
+    };
+
+} // namespace paraformer
+#endif

+ 0 - 54
funasr/runtime/onnxruntime/src/paraformer_onnx.h

@@ -1,54 +0,0 @@
-#pragma once
-
-
-#ifndef PARAFORMER_MODELIMP_H
-#define PARAFORMER_MODELIMP_H
-
-#include "precomp.h"
-
-namespace paraformer {
-
-    class ModelImp : public Model {
-    private:
-        //std::unique_ptr<knf::OnlineFbank> fbank_;
-        knf::FbankOptions fbank_opts;
-
-        std::unique_ptr<FsmnVad> vadHandle;
-        std::unique_ptr<CTTransformer> puncHandle;
-
-        Vocab* vocab;
-        vector<float> means_list;
-        vector<float> vars_list;
-        const float scale = 22.6274169979695;
-        int32_t lfr_window_size = 7;
-        int32_t lfr_window_shift = 6;
-
-        void load_cmvn(const char *filename);
-        vector<float> ApplyLFR(const vector<float> &in);
-        void ApplyCMVN(vector<float> *v);
-
-        string greedy_search( float* in, int nLen);
-
-        std::shared_ptr<Ort::Session> m_session;
-        Ort::Env env_;
-        Ort::SessionOptions sessionOptions;
-
-        vector<string> m_strInputNames, m_strOutputNames;
-        vector<const char*> m_szInputNames;
-        vector<const char*> m_szOutputNames;
-
-    public:
-        ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
-        ~ModelImp();
-        void reset();
-        vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
-        string forward_chunk(float* din, int len, int flag);
-        string forward(float* din, int len, int flag);
-        string rescoring();
-        std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data);
-        string AddPunc(const char* szInput);
-
-    };
-
-} // namespace paraformer
-#endif

+ 1 - 1
funasr/runtime/onnxruntime/src/precomp.h

@@ -39,7 +39,7 @@ using namespace std;
 #include "util.h"
 #include "resample.h"
 #include "model.h"
-#include "paraformer_onnx.h"
+#include "paraformer.h"
 #include "libfunasrapi.h"
 
 using namespace paraformer;

+ 1 - 1
funasr/runtime/onnxruntime/src/tensor.h

@@ -71,7 +71,7 @@ template <typename T> void Tensor<T>::alloc_buff()
 {
     buff_size = size[0] * size[1] * size[2] * size[3];
     mem_size = buff_size;
-    buff = (T *)aligned_malloc(32, buff_size * sizeof(T));
+    buff = (T *)AlignedMalloc(32, buff_size * sizeof(T));
 }
 
 template <typename T> void Tensor<T>::free_buff()

+ 46 - 46
funasr/runtime/onnxruntime/src/tokenizer.cpp

@@ -1,26 +1,26 @@
  #include "precomp.h"
 
-CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
+CTokenizer::CTokenizer(const char* sz_yamlfile):m_ready(false)
 {
-	OpenYaml(szYmlFile);
+	OpenYaml(sz_yamlfile);
 }
 
-CTokenizer::CTokenizer():m_Ready(false)
+CTokenizer::CTokenizer():m_ready(false)
 {
 }
 
-void CTokenizer::read_yml(const YAML::Node& node) 
+void CTokenizer::ReadYaml(const YAML::Node& node) 
 {
 	if (node.IsMap()) 
 	{// node is a map
 		for (auto it = node.begin(); it != node.end(); ++it) 
 		{
-			read_yml(it->second);
+			ReadYaml(it->second);
 		}
 	}
 	if (node.IsSequence()) {// node is a sequence
 		for (size_t i = 0; i < node.size(); ++i) {
-			read_yml(node[i]);
+			ReadYaml(node[i]);
 		}
 	}
 	if (node.IsScalar()) {// node is a scalar
@@ -28,9 +28,9 @@ void CTokenizer::read_yml(const YAML::Node& node)
 	}
 }
 
-bool CTokenizer::OpenYaml(const char* szYmlFile)
+bool CTokenizer::OpenYaml(const char* sz_yamlfile)
 {
-	YAML::Node m_Config = YAML::LoadFile(szYmlFile);
+	YAML::Node m_Config = YAML::LoadFile(sz_yamlfile);
 	if (m_Config.IsNull())
 		return false;
 	try
@@ -42,8 +42,8 @@ bool CTokenizer::OpenYaml(const char* szYmlFile)
 			{
 				if (Tokens[i].IsScalar())
 				{
-					m_ID2Token.push_back(Tokens[i].as<string>());
-					m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
+					m_id2token.push_back(Tokens[i].as<string>());
+					m_token2id.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
 				}
 			}
 		}
@@ -54,8 +54,8 @@ bool CTokenizer::OpenYaml(const char* szYmlFile)
 			{
 				if (Puncs[i].IsScalar())
 				{ 
-					m_ID2Punc.push_back(Puncs[i].as<string>());
-					m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
+					m_id2punc.push_back(Puncs[i].as<string>());
+					m_punc2id.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
 				}
 			}
 		}
@@ -64,87 +64,87 @@ bool CTokenizer::OpenYaml(const char* szYmlFile)
 		std::cout << "read error!" << std::endl;
 		return  false;
 	}
-	m_Ready = true;
-	return m_Ready;
+	m_ready = true;
+	return m_ready;
 }
 
-vector<string> CTokenizer::ID2String(vector<int> Input)
+vector<string> CTokenizer::Id2String(vector<int> input)
 {
 	vector<string> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{
-		result.push_back(m_ID2Token[item]);
+		result.push_back(m_id2token[item]);
 	}
 	return result;
 }
 
-int CTokenizer::String2ID(string Input)
+int CTokenizer::String2Id(string input)
 {
 	int nID = 0; // <blank>
-	if (m_Token2ID.find(Input) != m_Token2ID.end())
-		nID=(m_Token2ID[Input]);
+	if (m_token2id.find(input) != m_token2id.end())
+		nID=(m_token2id[input]);
 	else
-		nID=(m_Token2ID[UNK_CHAR]);
+		nID=(m_token2id[UNK_CHAR]);
 	return nID;
 }
 
-vector<int> CTokenizer::String2IDs(vector<string> Input)
+vector<int> CTokenizer::String2Ids(vector<string> input)
 {
 	vector<int> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{	
 		transform(item.begin(), item.end(), item.begin(), ::tolower);
-		if (m_Token2ID.find(item) != m_Token2ID.end())
-			result.push_back(m_Token2ID[item]);
+		if (m_token2id.find(item) != m_token2id.end())
+			result.push_back(m_token2id[item]);
 		else
-			result.push_back(m_Token2ID[UNK_CHAR]);
+			result.push_back(m_token2id[UNK_CHAR]);
 	}
 	return result;
 }
 
-vector<string> CTokenizer::ID2Punc(vector<int> Input)
+vector<string> CTokenizer::Id2Punc(vector<int> input)
 {
 	vector<string> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{
-		result.push_back(m_ID2Punc[item]);
+		result.push_back(m_id2punc[item]);
 	}
 	return result;
 }
 
-string CTokenizer::ID2Punc(int nPuncID)
+string CTokenizer::Id2Punc(int n_punc_id)
 {
-	return m_ID2Punc[nPuncID];
+	return m_id2punc[n_punc_id];
 }
 
-vector<int> CTokenizer::Punc2IDs(vector<string> Input)
+vector<int> CTokenizer::Punc2Ids(vector<string> input)
 {
 	vector<int> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{
-		result.push_back(m_Punc2ID[item]);
+		result.push_back(m_punc2id[item]);
 	}
 	return result;
 }
 
-vector<string> CTokenizer::SplitChineseString(const string & strInfo)
+vector<string> CTokenizer::SplitChineseString(const string & str_info)
 {
 	vector<string> list;
-	int strSize = strInfo.size();
+	int strSize = str_info.size();
 	int i = 0;
 
 	while (i < strSize) {
 		int len = 1;
-		for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
+		for (int j = 0; j < 6 && (str_info[i] & (0x80 >> j)); j++) {
 			len = j + 1;
 		}
-		list.push_back(strInfo.substr(i, len));
+		list.push_back(str_info.substr(i, len));
 		i += len;
 	}
 	return list;
 }
 
-void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
+void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
 {
 	if (str == "")
 	{
@@ -161,10 +161,10 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
 	}
 }
 
- void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
+ void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
 {
 	vector<string>  strList;
-	strSplit(strInfo,' ', strList);
+	StrSplit(str_info,' ', strList);
 	string current_eng,current_chinese;
 	for (auto& item : strList)
 	{
@@ -178,7 +178,7 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
 				{
 					// for utf-8 chinese
 					auto chineseList = SplitChineseString(current_chinese);
-					strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
+					str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
 					current_chinese = "";
 				}
 				current_eng += ch;
@@ -187,7 +187,7 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
 			{
 				if (current_eng.size() > 0)
 				{
-					strOut.push_back(current_eng);
+					str_out.push_back(current_eng);
 					current_eng = "";
 				}
 				current_chinese += ch;
@@ -196,13 +196,13 @@ void CTokenizer::strSplit(const string& str, const char split, vector<string>& r
 		if (current_chinese.size() > 0)
 		{
 			auto chineseList = SplitChineseString(current_chinese);
-			strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
+			str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
 			current_chinese = "";
 		}
 		if (current_eng.size() > 0)
 		{
-			strOut.push_back(current_eng);
+			str_out.push_back(current_eng);
 		}
 	}
-	IDOut= String2IDs(strOut);
+	id_out= String2Ids(str_out);
 }

+ 15 - 15
funasr/runtime/onnxruntime/src/tokenizer.h

@@ -4,24 +4,24 @@
 class CTokenizer {
 private:
 
-	bool  m_Ready = false;
-	vector<string>   m_ID2Token,m_ID2Punc;
-	map<string, int>  m_Token2ID,m_Punc2ID;
+	bool  m_ready = false;
+	vector<string>   m_id2token,m_id2punc;
+	map<string, int>  m_token2id,m_punc2id;
 
 public:
 
-	CTokenizer(const char* szYmlFile);
+	CTokenizer(const char* sz_yamlfile);
 	CTokenizer();
-	bool OpenYaml(const char* szYmlFile);
-	void read_yml(const YAML::Node& node);
-	vector<string> ID2String(vector<int> Input);
-	vector<int> String2IDs(vector<string> Input);
-	int String2ID(string Input);
-	vector<string> ID2Punc(vector<int> Input);
-	string ID2Punc(int nPuncID);
-	vector<int> Punc2IDs(vector<string> Input);
-	vector<string> SplitChineseString(const string& strInfo);
-	void strSplit(const string& str, const char split, vector<string>& res);
-	void Tokenize(const char* strInfo, vector<string>& strOut, vector<int>& IDOut);
+	bool OpenYaml(const char* sz_yamlfile);
+	void ReadYaml(const YAML::Node& node);
+	vector<string> Id2String(vector<int> input);
+	vector<int> String2Ids(vector<string> input);
+	int String2Id(string input);
+	vector<string> Id2Punc(vector<int> input);
+	string Id2Punc(int n_punc_id);
+	vector<int> Punc2Ids(vector<string> input);
+	vector<string> SplitChineseString(const string& str_info);
+	void StrSplit(const string& str, const char split, vector<string>& res);
+	void Tokenize(const char* str_info, vector<string>& str_out, vector<int>& id_out);
 
 };

+ 14 - 14
funasr/runtime/onnxruntime/src/util.cpp

@@ -1,7 +1,7 @@
 
 #include "precomp.h"
 
-float *loadparams(const char *filename)
+float *LoadParams(const char *filename)
 {
 
     FILE *fp;
@@ -10,20 +10,20 @@ float *loadparams(const char *filename)
     uint32_t nFileLen = ftell(fp);
     fseek(fp, 0, SEEK_SET);
 
-    float *params_addr = (float *)aligned_malloc(32, nFileLen);
+    float *params_addr = (float *)AlignedMalloc(32, nFileLen);
     int n = fread(params_addr, 1, nFileLen, fp);
     fclose(fp);
 
     return params_addr;
 }
 
-int val_align(int val, int align)
+int ValAlign(int val, int align)
 {
     // Rounds val up to the nearest multiple of align and returns it as int.
     // The arithmetic is done in float (ceil of val / align, times align), so
     // the result is only exact while the operands are representable as float.
     // align == 0 is not guarded against.
     float tmp = ceil((float)val / (float)align) * (float)align;
     return (int)tmp;
 }
 
-void disp_params(float *din, int size)
+void DispParams(float *din, int size)
 {
     int i;
     for (i = 0; i < size; i++) {
@@ -39,7 +39,7 @@ void SaveDataFile(const char *filename, void *data, uint32_t len)
     fclose(fp);
 }
 
-void basic_norm(Tensor<float> *&din, float norm)
+void BasicNorm(Tensor<float> *&din, float norm)
 {
 
     int Tmax = din->size[2];
@@ -59,7 +59,7 @@ void basic_norm(Tensor<float> *&din, float norm)
     }
 }
 
-void findmax(float *din, int len, float &max_val, int &max_idx)
+void FindMax(float *din, int len, float &max_val, int &max_idx)
 {
     int i;
     max_val = -INFINITY;
@@ -72,7 +72,7 @@ void findmax(float *din, int len, float &max_val, int &max_idx)
     }
 }
 
-string pathAppend(const string &p1, const string &p2)
+string PathAppend(const string &p1, const string &p2)
 {
 
     char sep = '/';
@@ -89,7 +89,7 @@ string pathAppend(const string &p1, const string &p2)
         return (p1 + p2);
 }
 
-void relu(Tensor<float> *din)
+void Relu(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -98,7 +98,7 @@ void relu(Tensor<float> *din)
     }
 }
 
-void swish(Tensor<float> *din)
+void Swish(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -107,7 +107,7 @@ void swish(Tensor<float> *din)
     }
 }
 
-void sigmoid(Tensor<float> *din)
+void Sigmoid(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -116,7 +116,7 @@ void sigmoid(Tensor<float> *din)
     }
 }
 
-void doubleswish(Tensor<float> *din)
+void DoubleSwish(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -125,7 +125,7 @@ void doubleswish(Tensor<float> *din)
     }
 }
 
-void softmax(float *din, int mask, int len)
+void Softmax(float *din, int mask, int len)
 {
     float *tmp = (float *)malloc(mask * sizeof(float));
     int i;
@@ -149,7 +149,7 @@ void softmax(float *din, int mask, int len)
     }
 }
 
-void log_softmax(float *din, int len)
+void LogSoftmax(float *din, int len)
 {
     float *tmp = (float *)malloc(len * sizeof(float));
     int i;
@@ -164,7 +164,7 @@ void log_softmax(float *din, int len)
     free(tmp);
 }
 
-void glu(Tensor<float> *din, Tensor<float> *dout)
+void Glu(Tensor<float> *din, Tensor<float> *dout)
 {
     int mm = din->buff_size / 1024;
     int i, j;

+ 13 - 13
funasr/runtime/onnxruntime/src/util.h

@@ -5,26 +5,26 @@
 
 using namespace std;
 
-extern float *loadparams(const char *filename);
+extern float *LoadParams(const char *filename);
 
 // The helpers declared in this header are defined in util.cpp.
 // Relu / Swish / Sigmoid / DoubleSwish each take one Tensor<float> and
 // loop over din->buff_size elements.
 extern void SaveDataFile(const char *filename, void *data, uint32_t len);
-extern void relu(Tensor<float> *din);
-extern void swish(Tensor<float> *din);
-extern void sigmoid(Tensor<float> *din);
-extern void doubleswish(Tensor<float> *din);
+extern void Relu(Tensor<float> *din);
+extern void Swish(Tensor<float> *din);
+extern void Sigmoid(Tensor<float> *din);
+extern void DoubleSwish(Tensor<float> *din);
 
 // Softmax allocates a scratch buffer of `mask` floats with malloc.
-extern void softmax(float *din, int mask, int len);
+extern void Softmax(float *din, int mask, int len);
 
 // LogSoftmax allocates (and frees) a scratch buffer of `len` floats.
 // ValAlign rounds val up to a multiple of align using float arithmetic.
-extern void log_softmax(float *din, int len);
-extern int val_align(int val, int align);
-extern void disp_params(float *din, int size);
+extern void LogSoftmax(float *din, int len);
+extern int ValAlign(int val, int align);
+extern void DispParams(float *din, int size);
 
 // BasicNorm takes the tensor pointer by reference.
-extern void basic_norm(Tensor<float> *&din, float norm);
+extern void BasicNorm(Tensor<float> *&din, float norm);
 
 // FindMax reports its result through the max_val / max_idx reference
 // parameters; max_val is seeded with -INFINITY before the scan.
-extern void findmax(float *din, int len, float &max_val, int &max_idx);
+extern void FindMax(float *din, int len, float &max_val, int &max_idx);
 
-extern void glu(Tensor<float> *din, Tensor<float> *dout);
+extern void Glu(Tensor<float> *din, Tensor<float> *dout);
 
 // PathAppend joins two path fragments; the implementation uses '/' as the
 // separator character.
-string pathAppend(const string &p1, const string &p2);
+string PathAppend(const string &p1, const string &p2);
 
 #endif

+ 9 - 29
funasr/runtime/onnxruntime/src/vocab.cpp

@@ -12,13 +12,13 @@ using namespace std;
 // Builds the vocabulary from the YAML file at `filename` by delegating to
 // LoadVocabFromYaml(). The local ifstream `in` is opened on the same path
 // but is not read from anywhere in this constructor.
 Vocab::Vocab(const char *filename)
 {
     ifstream in(filename);
-    loadVocabFromYaml(filename);
+    LoadVocabFromYaml(filename);
 }
 // Empty destructor body: no explicit cleanup is performed here.
 Vocab::~Vocab()
 {
 }
 
-void Vocab::loadVocabFromYaml(const char* filename){
+void Vocab::LoadVocabFromYaml(const char* filename){
     YAML::Node config;
     try{
         config = YAML::LoadFile(filename);
@@ -26,72 +26,62 @@ void Vocab::loadVocabFromYaml(const char* filename){
         printf("error loading file, yaml file error or not exist.\n");
         exit(-1);
     }
-
     YAML::Node myList = config["token_list"];
     for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
         vocab.push_back(it->as<string>());
     }
 }
 
-string Vocab::vector2string(vector<int> in)
+string Vocab::Vector2String(vector<int> in)
 {
     // Concatenates vocab[id] for every id in `in`, in order and without any
     // separator, and returns the resulting string.
     // Ids are used to index `vocab` directly; no range check is performed.
     // NOTE(review): `i` is declared but not used in this function.
     int i;
     stringstream ss;
     for (auto it = in.begin(); it != in.end(); it++) {
         ss << vocab[*it];
     }
-
     return ss.str();
 }
 
-int str2int(string str)
+int Str2Int(string str)
 {
     // Decodes the first three bytes of `str` as one 3-byte UTF-8 sequence
     // (1110xxxx 10xxxxxx 10xxxxxx) and returns the resulting code point.
     // Returns 0 when the bytes do not match that bit pattern.
     // For strings shorter than three bytes the checks hit the terminating
     // NUL of c_str() first, and the short-circuit || stops there.
     const char *ch_array = str.c_str();
     if (((ch_array[0] & 0xf0) != 0xe0) || ((ch_array[1] & 0xc0) != 0x80) ||
         ((ch_array[2] & 0xc0) != 0x80))
         return 0;
-
     // Payload bits: 4 from the lead byte, 6 from each continuation byte.
     int val = ((ch_array[0] & 0x0f) << 12) | ((ch_array[1] & 0x3f) << 6) |
               (ch_array[2] & 0x3f);
     return val;
 }
 
-bool Vocab::isChinese(string ch)
+bool Vocab::IsChinese(string ch)
 {
     // Returns true only when `ch` is exactly three bytes long and decodes
     // (via Str2Int) to a code point in [19968, 40959], i.e. U+4E00..U+9FFF.
     // Anything else, including input Str2Int rejects with 0, yields false.
     if (ch.size() != 3) {
         return false;
     }
-
-    int unicode = str2int(ch);
+    int unicode = Str2Int(ch);
     if (unicode >= 19968 && unicode <= 40959) {
         return true;
     }
-
     return false;
 }
 
-string Vocab::vector2stringV2(vector<int> in)
+string Vocab::Vector2StringV2(vector<int> in)
 {
     int i;
     list<string> words;
-
     int is_pre_english = false;
     int pre_english_len = 0;
-
     int is_combining = false;
     string combine = "";
 
     for (auto it = in.begin(); it != in.end(); it++) {
         string word = vocab[*it];
-
         // step1: skip special tokens (<s>, </s>, <unk>)
         if (word == "<s>" || word == "</s>" || word == "<unk>")
             continue;
-
         // step2: combine sub-word pieces (marked with "@@") into a full word
         {
             int sub_word = !(word.find("@@") == string::npos);
-
             // process word start and middle part
             if (sub_word) {
                 combine += word.erase(word.length() - 2);
@@ -109,15 +99,13 @@ string Vocab::vector2stringV2(vector<int> in)
 
         // step3: handle spacing between English words and turn abbreviations into upper case
         {
-
             // input word is Chinese, no processing needed
-            if (isChinese(word)) {
+            if (IsChinese(word)) {
                 words.push_back(word);
                 is_pre_english = false;
             }
             // input word is english word
             else {
-
                 // pre word is chinese
                 if (!is_pre_english) {
                     word[0] = word[0] - 32;
@@ -125,10 +113,8 @@ string Vocab::vector2stringV2(vector<int> in)
                     pre_english_len = word.size();
 
                 }
-
                 // pre word is english word
                 else {
-
                     // single letter turn to upper case
                     if (word.size() == 1) {
                         word[0] = word[0] - 32;
@@ -147,17 +133,11 @@ string Vocab::vector2stringV2(vector<int> in)
                         pre_english_len = word.size();
                     }
                 }
-
                 is_pre_english = true;
-
             }
         }
     }
 
-    // for (auto it = words.begin(); it != words.end(); it++) {
-    //     cout << *it << endl;
-    // }
-
     stringstream ss;
     for (auto it = words.begin(); it != words.end(); it++) {
         ss << *it;
@@ -166,7 +146,7 @@ string Vocab::vector2stringV2(vector<int> in)
     return ss.str();
 }
 
-int Vocab::size()
+int Vocab::Size()
 {
     // Number of entries currently held in `vocab`, converted to int.
     return vocab.size();
 }

+ 6 - 6
funasr/runtime/onnxruntime/src/vocab.h

@@ -10,16 +10,16 @@ using namespace std;
 // Vocab: holds a list of token strings loaded from a YAML file and turns
 // sequences of token ids back into text.
 class Vocab {
   private:
     // Token strings, indexed by token id (Vector2String reads vocab[id]).
     vector<string> vocab;
     // IsChinese: true for a single 3-byte UTF-8 character in U+4E00..U+9FFF.
     // LoadVocabFromYaml: appends every entry of the YAML "token_list" node
     // to `vocab`.
     // NOTE(review): no definition of IsEnglish appears in the vocab.cpp
     // hunks shown with this change -- confirm it is defined somewhere.
-    bool isChinese(string ch);
-    bool isEnglish(string ch);
-    void loadVocabFromYaml(const char* filename);
+    bool IsChinese(string ch);
+    bool IsEnglish(string ch);
+    void LoadVocabFromYaml(const char* filename);
 
   public:
     // Loads the vocabulary from the YAML file at `filename`.
     Vocab(const char *filename);
     ~Vocab();
     // Size: number of vocabulary entries.
     // Vector2String: plain concatenation of vocab[id] for each id.
     // Vector2StringV2: skips <s>, </s> and <unk>, merges "@@" sub-word
     // pieces, and applies spacing / upper-casing rules to English words.
-    int size();
-    string vector2string(vector<int> in);
-    string vector2stringV2(vector<int> in);
+    int Size();
+    string Vector2String(vector<int> in);
+    string Vector2StringV2(vector<int> in);
 };
 
 #endif