雾聪 2 года назад
Родитель
Commit
125855d55d

+ 11 - 10
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp

@@ -32,10 +32,8 @@ bool is_target_file(const std::string& filename, const std::string target) {
 
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
 {
-    if (value_arg.isSet()){
-        model_path.insert({key, value_arg.getValue()});
-        LOG(INFO)<< key << " : " << value_arg.getValue();
-    }
+    model_path.insert({key, value_arg.getValue()});
+    LOG(INFO)<< key << " : " << value_arg.getValue();
 }
 
 int main(int argc, char** argv)
@@ -52,6 +50,7 @@ int main(int argc, char** argv)
     TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
+    TCLAP::ValueArg<std::int32_t>   onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
 
     TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
 
@@ -64,6 +63,7 @@ int main(int argc, char** argv)
     cmd.add(punc_quant);
     cmd.add(wav_path);
     cmd.add(asr_mode);
+    cmd.add(onnx_thread);
     cmd.parse(argc, argv);
 
     std::map<std::string, std::string> model_path;
@@ -79,7 +79,7 @@ int main(int argc, char** argv)
 
     struct timeval start, end;
     gettimeofday(&start, NULL);
-    int thread_num = 1;
+    int thread_num = onnx_thread.getValue();
     int asr_mode_ = -1;
     if(model_path[ASR_MODE] == "offline"){
         asr_mode_ = 0;
@@ -132,14 +132,15 @@ int main(int argc, char** argv)
     }
 
     // init online features
-    FunTpassOnlineInit(tpass_hanlde);
+    std::vector<int> chunk_size = {5,10,5};
+    FunTpassOnlineInit(tpass_hanlde, chunk_size);
     float snippet_time = 0.0f;
     long taking_micros = 0;
     for (int i = 0; i < wav_list.size(); i++) {
         auto& wav_file = wav_list[i];
         auto& wav_id = wav_ids[i];
 
-        int32_t sampling_rate_ = -1;
+        int32_t sampling_rate_ = 16000;
         funasr::Audio audio(1);
 		if(is_target_file(wav_file.c_str(), "wav")){
 			if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
@@ -174,7 +175,7 @@ int main(int argc, char** argv)
                     is_final = false;
             }
             gettimeofday(&start, NULL);
-            FUNASR_RESULT result = FunTpassInferBuffer(tpass_hanlde, speech_buff+sample_offset, step, punc_cache, is_final, 16000, "pcm", (ASR_TYPE)asr_mode_);
+            FUNASR_RESULT result = FunTpassInferBuffer(tpass_hanlde, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
             gettimeofday(&end, NULL);
             seconds = (end.tv_sec - start.tv_sec);
             taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -196,10 +197,10 @@ int main(int argc, char** argv)
             }
         }
         if(asr_mode_ != 0){
-            LOG(INFO) << wav_id <<"Final online  results "<<" : "<<online_res;
+            LOG(INFO) << wav_id << " Final online  results "<<" : "<<online_res;
         }
         if(asr_mode_ != 1){
-            LOG(INFO) << wav_id << "Final offline results " <<" : "<<tpass_res;
+            LOG(INFO) << wav_id << " Final offline results " <<" : "<<tpass_res;
         }
     }
  

+ 1 - 1
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp

@@ -159,7 +159,7 @@ int main(int argc, char *argv[])
         char* speech_buff = audio.GetSpeechChar();
         int buff_len = audio.GetSpeechLen()*2;
 
-        int step = 1600*2;
+        int step = 800*2;
         bool is_final = false;
 
         for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {

+ 1 - 1
funasr/runtime/onnxruntime/include/audio.h

@@ -72,7 +72,7 @@ class Audio {
     void Padding();
     void Split(OfflineStream* offline_streamj);
     void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
-    void Split(VadModel* vad_obj, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
+    void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
     float GetTimeLen();
     int GetQueueSize() { return (int)frame_queue.size(); }
     char* GetSpeechChar(){return speech_char;}

+ 2 - 2
funasr/runtime/onnxruntime/include/funasrruntime.h

@@ -61,7 +61,7 @@ typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps
 
 // ASR
 _FUNASRAPI FUNASR_HANDLE  	FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type=ASR_OFFLINE);
-_FUNASRAPI FUNASR_HANDLE  	FunASROnlineInit(FUNASR_HANDLE asr_handle);
+_FUNASRAPI FUNASR_HANDLE  	FunASROnlineInit(FUNASR_HANDLE asr_handle, std::vector<int> chunk_size={5,10,5});
 // buffer
 _FUNASRAPI FUNASR_RESULT	FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
 // file, support wav & pcm
@@ -104,7 +104,7 @@ _FUNASRAPI void				FunOfflineUninit(FUNASR_HANDLE handle);
 
 //2passStream
 _FUNASRAPI FUNASR_HANDLE  	FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num);
-_FUNASRAPI void  	        FunTpassOnlineInit(FUNASR_HANDLE tpass_handle);
+_FUNASRAPI void  	        FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size={5,10,5});
 // buffer
 _FUNASRAPI FUNASR_RESULT	FunTpassInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS);
 _FUNASRAPI void				FunTpassUninit(FUNASR_HANDLE handle);

+ 1 - 1
funasr/runtime/onnxruntime/include/model.h

@@ -18,7 +18,7 @@ class Model {
 };
 
 Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
-Model *CreateModel(void* asr_handle);
+Model *CreateModel(void* asr_handle, std::vector<int> chunk_size);
 
 } // namespace funasr
 #endif

+ 1 - 1
funasr/runtime/onnxruntime/include/tpass-stream.h

@@ -29,6 +29,6 @@ class TpassStream {
 };
 
 TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num=1);
-void CreateTpassOnlineStream(void* tpass_stream);
+void CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size);
 } // namespace funasr
 #endif

+ 4 - 4
funasr/runtime/onnxruntime/src/audio.cpp

@@ -1056,7 +1056,7 @@ void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, boo
 }
 
 // 2pass
-void Audio::Split(VadModel* vad_obj, bool input_finished, ASR_TYPE asr_mode)
+void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYPE asr_mode)
 {
     AudioFrame *frame;
 
@@ -1075,7 +1075,7 @@ void Audio::Split(VadModel* vad_obj, bool input_finished, ASR_TYPE asr_mode)
             int start = speech_start*seg_sample;
             int end = speech_end*seg_sample;
             int buff_len = end-start;
-            int step = ONLINE_STEP;
+            int step = chunk_len;
 
             if(asr_mode != ASR_OFFLINE){
                 if(buff_len >= step){
@@ -1131,7 +1131,7 @@ void Audio::Split(VadModel* vad_obj, bool input_finished, ASR_TYPE asr_mode)
                 int start = speech_start*seg_sample;
                 int end = speech_end*seg_sample;
                 int buff_len = end-start;
-                int step = ONLINE_STEP;
+                int step = chunk_len;
 
                 if(asr_mode != ASR_OFFLINE){
                     if(buff_len >= step){
@@ -1154,7 +1154,7 @@ void Audio::Split(VadModel* vad_obj, bool input_finished, ASR_TYPE asr_mode)
                 int offline_start = speech_offline_start*seg_sample;
                 int end = speech_end_i*seg_sample;
                 int buff_len = end-start;
-                int step = ONLINE_STEP;
+                int step = chunk_len;
 
                 if(asr_mode != ASR_ONLINE){
                     frame = new AudioFrame(end-offline_start);

+ 5 - 4
funasr/runtime/onnxruntime/src/ct-transformer-online.cpp

@@ -181,11 +181,12 @@ vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSiz
         text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
 
     //vad_mask
-    vector<float> arVadMask,arSubMask;
+    // vector<float> arVadMask,arSubMask;
+    vector<float> arVadMask;
     int nTextLength = input_data.size();
 
     VadMask(nTextLength, nCacheSize, arVadMask);
-    Triangle(nTextLength, arSubMask);
+    // Triangle(nTextLength, arSubMask);
     std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength };
     Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>(
         m_memoryInfo,
@@ -198,8 +199,8 @@ vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSiz
     std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength };
     Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>(
         m_memoryInfo,
-        arSubMask.data(),
-        arSubMask.size() ,
+        arVadMask.data(),
+        arVadMask.size(),
         SubMask_Dim.data(),
         SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
 

+ 1 - 1
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp

@@ -55,7 +55,7 @@ void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &
       int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
       int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
       int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
-      int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
+      int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
       reserve_waveforms_.clear();
       reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                 waves.begin() + reserve_frame_idx * frame_shift_sample_length_,

+ 20 - 8
funasr/runtime/onnxruntime/src/funasrruntime.cpp

@@ -11,9 +11,9 @@ extern "C" {
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_HANDLE  FunASROnlineInit(FUNASR_HANDLE asr_hanlde)
+	_FUNASRAPI FUNASR_HANDLE  FunASROnlineInit(FUNASR_HANDLE asr_hanlde, std::vector<int> chunk_size)
 	{
-		funasr::Model* mm = funasr::CreateModel(asr_hanlde);
+		funasr::Model* mm = funasr::CreateModel(asr_hanlde, chunk_size);
 		return mm;
 	}
 
@@ -47,9 +47,9 @@ extern "C" {
 		return mm;
 	}
 
-	_FUNASRAPI void FunTpassOnlineInit(FUNASR_HANDLE tpass_handle)
+	_FUNASRAPI void FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size)
 	{
-		funasr::CreateTpassOnlineStream(tpass_handle);
+		funasr::CreateTpassOnlineStream(tpass_handle, chunk_size);
 	}
 
 	// APIs for ASR Infer
@@ -324,6 +324,7 @@ extern "C" {
 		funasr::Model* asr_online_handle = (tpass_stream->asr_online_handle).get();
 		if (!asr_online_handle)
 			return nullptr;
+		int chunk_len = ((funasr::ParaformerOnline*)asr_online_handle)->chunk_len;
 		
 		funasr::Model* asr_handle = (tpass_stream->asr_handle).get();
 		if (!asr_handle)
@@ -348,14 +349,25 @@ extern "C" {
 		if(p_result->snippet_time == 0){
 			return p_result;
 		}
-
-		audio->Split(vad_online_handle, input_finished, mode);
+		
+		audio->Split(vad_online_handle, chunk_len, input_finished, mode);
 
 		funasr::AudioFrame* frame = NULL;
 		while(audio->FetchChunck(frame) > 0){
 			string msg = asr_online_handle->Forward(frame->data, frame->len, frame->is_final);
-			string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[0]);
-			p_result->msg += msg_punc;
+			if(mode == ASR_ONLINE){
+				((funasr::ParaformerOnline*)asr_online_handle)->online_res += msg;
+				if(frame->is_final){
+					string online_msg = ((funasr::ParaformerOnline*)asr_online_handle)->online_res;
+					string msg_punc = punc_online_handle->AddPunc(online_msg.c_str(), punc_cache[0]);
+					p_result->tpass_msg = msg_punc;
+					((funasr::ParaformerOnline*)asr_online_handle)->online_res = "";
+				}else{
+					p_result->msg += msg;
+				}
+			}else if(mode == ASR_TWO_PASS && !(frame->is_final)){
+				p_result->msg += msg;
+			}
 			if(frame != NULL){
 				delete frame;
 				frame = NULL;

+ 2 - 2
funasr/runtime/onnxruntime/src/model.cpp

@@ -46,10 +46,10 @@ Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_nu
     }
 }
 
-Model *CreateModel(void* asr_handle)
+Model *CreateModel(void* asr_handle, std::vector<int> chunk_size)
 {
     Model* mm;
-    mm = new ParaformerOnline((Paraformer*)asr_handle);
+    mm = new ParaformerOnline((Paraformer*)asr_handle, chunk_size);
     return mm;
 }
 

+ 4 - 3
funasr/runtime/onnxruntime/src/paraformer-online.cpp

@@ -9,8 +9,8 @@ using namespace std;
 
 namespace funasr {
 
-ParaformerOnline::ParaformerOnline(Paraformer* para_handle)
-:para_handle_(std::move(para_handle)),session_options_{}{
+ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
+:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
     InitOnline(
         para_handle_->fbank_opts_,
         para_handle_->encoder_session_,
@@ -61,6 +61,7 @@ void ParaformerOnline::InitOnline(
     for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
         fsmn_init_cache_.emplace_back(0);
     }
+    chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
 }
 
 void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
@@ -109,7 +110,7 @@ void ParaformerOnline::ExtractFeats(float sample_rate, vector<std::vector<float>
         int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
         int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
         int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
-        int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
+        int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
         reserve_waveforms_.clear();
         reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                     waves.begin() + reserve_frame_idx * frame_shift_sample_length_,

+ 4 - 1
funasr/runtime/onnxruntime/src/paraformer-online.h

@@ -92,7 +92,7 @@ namespace funasr {
         double sqrt_factor;
 
     public:
-        ParaformerOnline(Paraformer* para_handle);
+        ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size);
         ~ParaformerOnline();
         void Reset();
         void ResetCache();
@@ -103,6 +103,9 @@ namespace funasr {
         string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
         string Forward(float* din, int len, bool input_finished);
         string Rescoring();
+        // 2pass
+        std::string online_res;
+        int chunk_len;
     };
 
 } // namespace funasr

+ 2 - 2
funasr/runtime/onnxruntime/src/tpass-stream.cpp

@@ -85,7 +85,7 @@ TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, i
     return mm;
 }
 
-void CreateTpassOnlineStream(void* tpass_stream)
+void CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size)
 {
     funasr::TpassStream* tpass_obj = (funasr::TpassStream*)tpass_stream;
     if(tpass_obj->vad_handle){
@@ -93,7 +93,7 @@ void CreateTpassOnlineStream(void* tpass_stream)
     }
 
     if(tpass_obj->asr_handle){
-        tpass_obj->asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_obj->asr_handle).get());
+        tpass_obj->asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_obj->asr_handle).get(), chunk_size);
     }else{
         LOG(ERROR)<<"asr_handle is null";
         exit(-1);