lyblsgo пре 2 година
родитељ
комит
b6d0ab4bfb

+ 0 - 1
funasr/runtime/onnxruntime/include/com-define.h

@@ -28,7 +28,6 @@
 // punc
 // punc
 #define PUNC_MODEL_FILE  "punc_model.onnx"
 #define PUNC_MODEL_FILE  "punc_model.onnx"
 #define PUNC_YAML_FILE "punc.yaml"
 #define PUNC_YAML_FILE "punc.yaml"
-
 #define UNK_CHAR "<unk>"
 #define UNK_CHAR "<unk>"
 
 
 #define  INPUT_NUM  2
 #define  INPUT_NUM  2

+ 1 - 8
funasr/runtime/onnxruntime/include/libfunasrapi.h

@@ -51,21 +51,14 @@ _FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* sz_model_dir, int thread_num, b
 
 
 // if not give a fn_callback ,it should be NULL 
 // if not give a fn_callback ,it should be NULL 
 _FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 _FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
 _FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 _FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 
 
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
-
-_FUNASRAPI const int		FunASRGetRetNumber(FUNASR_RESULT result);
-
+_FUNASRAPI const int	FunASRGetRetNumber(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
-
 _FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
 _FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
-
 _FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT result);
 _FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT result);
 
 
 #ifdef __cplusplus 
 #ifdef __cplusplus 

+ 0 - 2
funasr/runtime/onnxruntime/src/alignedmem.h

@@ -2,8 +2,6 @@
 #ifndef ALIGNEDMEM_H
 #ifndef ALIGNEDMEM_H
 #define ALIGNEDMEM_H
 #define ALIGNEDMEM_H
 
 
-
-
 extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
 extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
 extern void AlignedFree(void *p);
 extern void AlignedFree(void *p);
 
 

+ 0 - 2
funasr/runtime/onnxruntime/src/commonfunc.h

@@ -33,7 +33,6 @@ inline void GetInputName(Ort::Session* session, string& inputName,int nIndex=0)
         {
         {
             auto t = session->GetInputNameAllocated(nIndex, allocator);
             auto t = session->GetInputNameAllocated(nIndex, allocator);
             inputName = t.get();
             inputName = t.get();
-
         }
         }
     }
     }
 }
 }
@@ -45,7 +44,6 @@ inline void GetOutputName(Ort::Session* session, string& outputName, int nIndex
         {
         {
             auto t = session->GetOutputNameAllocated(nIndex, allocator);
             auto t = session->GetOutputNameAllocated(nIndex, allocator);
             outputName = t.get();
             outputName = t.get();
-
         }
         }
     }
     }
 }
 }

+ 0 - 1
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp

@@ -58,7 +58,6 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list,
         }else{
         }else{
             cout <<"No return data!";
             cout <<"No return data!";
         }
         }
-
     }
     }
     {
     {
         lock_guard<mutex> guard(mtx);
         lock_guard<mutex> guard(mtx);

+ 0 - 2
funasr/runtime/onnxruntime/src/model.cpp

@@ -3,8 +3,6 @@
 Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
 Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
 {
 {
     Model *mm;
     Model *mm;
-
     mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
     mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
-
     return mm;
     return mm;
 }
 }

+ 1 - 4
funasr/runtime/onnxruntime/src/online-feature.h

@@ -12,15 +12,12 @@ public:
 
 
   void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
   void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
 
 
-
 private:
 private:
   void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
   void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
-
   int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
   int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
-
+  
   static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
   static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
     int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
     int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
-
     if (frame_num >= 1 && sample_length >= frame_sample_length)
     if (frame_num >= 1 && sample_length >= frame_sample_length)
       return frame_num;
       return frame_num;
     else
     else

+ 3 - 3
funasr/runtime/onnxruntime/src/paraformer.cpp

@@ -143,14 +143,14 @@ void Paraformer::LoadCmvn(const char *filename)
     }
     }
 }
 }
 
 
-string Paraformer::GreedySearch(float * in, int n_len )
+string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
 {
 {
     vector<int> hyps;
     vector<int> hyps;
     int Tmax = n_len;
     int Tmax = n_len;
     for (int i = 0; i < Tmax; i++) {
     for (int i = 0; i < Tmax; i++) {
         int max_idx;
         int max_idx;
         float max_val;
         float max_val;
-        FindMax(in + i * 8404, 8404, max_val, max_idx);
+        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
         hyps.push_back(max_idx);
         hyps.push_back(max_idx);
     }
     }
 
 
@@ -238,7 +238,7 @@ string Paraformer::Forward(float* din, int len, int flag)
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
-        result = GreedySearch(floatData, *encoder_out_lens);
+        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
     }
     }
     catch (std::exception const &e)
     catch (std::exception const &e)
     {
     {

+ 6 - 1
funasr/runtime/onnxruntime/src/paraformer.h

@@ -9,6 +9,11 @@
 namespace paraformer {
 namespace paraformer {
 
 
     class Paraformer : public Model {
     class Paraformer : public Model {
+    /**
+     * Author: Speech Lab of DAMO Academy, Alibaba Group
+     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+     * https://arxiv.org/pdf/2206.08317.pdf
+    */
     private:
     private:
         //std::unique_ptr<knf::OnlineFbank> fbank_;
         //std::unique_ptr<knf::OnlineFbank> fbank_;
         knf::FbankOptions fbank_opts;
         knf::FbankOptions fbank_opts;
@@ -27,7 +32,7 @@ namespace paraformer {
         vector<float> ApplyLfr(const vector<float> &in);
         vector<float> ApplyLfr(const vector<float> &in);
         void ApplyCmvn(vector<float> *v);
         void ApplyCmvn(vector<float> *v);
 
 
-        string GreedySearch( float* in, int n_len);
+        string GreedySearch( float* in, int n_len, int64_t token_nums);
 
 
         std::shared_ptr<Ort::Session> m_session;
         std::shared_ptr<Ort::Session> m_session;
         Ort::Env env_;
         Ort::Env env_;