雾聪 2 лет назад
Родитель
Сommit
3372b13d24

+ 8 - 5
funasr/runtime/onnxruntime/CMakeLists.txt

@@ -7,6 +7,8 @@ option(ENABLE_GLOG "Whether to build glog" ON)
 # set(CMAKE_CXX_STANDARD 11)
 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
 
 include(TestBigEndian)
 test_big_endian(BIG_ENDIAN)
@@ -30,12 +32,13 @@ endif()
 include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
 
-add_subdirectory(third_party/yaml-cpp)
-add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
-add_subdirectory(src)
-
 if(ENABLE_GLOG)
     include_directories(${PROJECT_SOURCE_DIR}/third_party/glog)
     set(BUILD_TESTING OFF)
     add_subdirectory(third_party/glog)
-endif()
+endif()
+
+add_subdirectory(third_party/yaml-cpp)
+add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
+add_subdirectory(src)
+add_subdirectory(bin)

+ 16 - 0
funasr/runtime/onnxruntime/bin/CMakeLists.txt

@@ -0,0 +1,16 @@
+include_directories(${CMAKE_SOURCE_DIR}/include)
+
+add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
+target_link_libraries(funasr-onnx-offline PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
+target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
+
+add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp")
+target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
+target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
+target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)

+ 0 - 0
funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp → funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp


+ 0 - 0
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp → funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp


+ 1 - 1
funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp → funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp

@@ -125,7 +125,7 @@ int main(int argc, char *argv[])
     long taking_micros = 0;
     for(auto& wav_file : wav_list){
         gettimeofday(&start, NULL);
-        FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
+        FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);

+ 0 - 0
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp → funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp


+ 193 - 0
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp

@@ -0,0 +1,193 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <vector>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "audio.h"
+
+using namespace std;
+
+bool is_target_file(const std::string& filename, const std::string target) {
+    std::size_t pos = filename.find_last_of(".");
+    if (pos == std::string::npos) {
+        return false;
+    }
+    std::string extension = filename.substr(pos + 1);
+    return (extension == target);
+}
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+    if (value_arg.isSet()){
+        model_path.insert({key, value_arg.getValue()});
+        LOG(INFO)<< key << " : " << value_arg.getValue();
+    }
+}
+
+void print_segs(vector<vector<int>>* vec) {
+    if((*vec).size() == 0){
+        return;
+    }    
+    string seg_out="[";
+    for (int i = 0; i < vec->size(); i++) {
+        vector<int> inner_vec = (*vec)[i];
+        if(inner_vec.size() == 0){
+            continue;
+        }
+        seg_out += "[";
+        for (int j = 0; j < inner_vec.size(); j++) {
+            seg_out += to_string(inner_vec[j]);
+            if (j != inner_vec.size() - 1) {
+                seg_out += ",";
+            }
+        }
+        seg_out += "]";
+        if (i != vec->size() - 1) {
+            seg_out += ",";
+        }
+    }
+    seg_out += "]";
+    LOG(INFO)<<seg_out;
+}
+
+int main(int argc, char *argv[])
+{
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
+    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
+    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+
+    TCLAP::ValueArg<std::string>    wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
+
+    cmd.add(model_dir);
+    cmd.add(quantize);
+    cmd.add(wav_path);
+    cmd.parse(argc, argv);
+
+    std::map<std::string, std::string> model_path;
+    GetValue(model_dir, MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(wav_path, WAV_PATH, model_path);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    int thread_num = 1;
+    FUNASR_HANDLE vad_hanlde=FsmnVadInit(model_path, thread_num);
+
+    if (!vad_hanlde)
+    {
+        LOG(ERROR) << "FunVad init failed";
+        exit(-1);
+    }
+
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+    // read wav_path
+    vector<string> wav_list;
+    string wav_path_ = model_path.at(WAV_PATH);
+    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
+        wav_list.emplace_back(wav_path_);
+    }
+    else if(is_target_file(wav_path_, "scp")){
+        ifstream in(wav_path_);
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            istringstream iss(line);
+            string column1, column2;
+            iss >> column1 >> column2;
+            wav_list.emplace_back(column2); 
+        }
+        in.close();
+    }else{
+        LOG(ERROR)<<"Please check the wav extension!";
+        exit(-1);
+    }
+    // init online features
+    FUNASR_HANDLE online_hanlde=FsmnVadOnlineInit(vad_hanlde);
+    float snippet_time = 0.0f;
+    long taking_micros = 0;
+    for(auto& wav_file : wav_list){
+
+        int32_t sampling_rate_ = -1;
+        funasr::Audio audio(1);
+		if(is_target_file(wav_file.c_str(), "wav")){
+			int32_t sampling_rate_ = -1;
+			if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+				LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+		}else if(is_target_file(wav_file.c_str(), "pcm")){
+			if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+				LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+		}else{
+			LOG(ERROR)<<"Wrong wav extension";
+			exit(-1);
+		}
+        char* speech_buff = audio.GetSpeechChar();
+        int buff_len = audio.GetSpeechLen()*2;
+
+        int step = 3200;
+        bool is_final = false;
+
+        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+            if (sample_offset + step >= buff_len - 1) {
+                    step = buff_len - sample_offset;
+                    is_final = true;
+                } else {
+                    is_final = false;
+            }
+            gettimeofday(&start, NULL);
+            FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000);
+            gettimeofday(&end, NULL);
+            seconds = (end.tv_sec - start.tv_sec);
+            taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+            if (result)
+            {
+                vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
+                print_segs(vad_segments);
+                snippet_time += FsmnVadGetRetSnippetTime(result);
+                FsmnVadFreeResult(result);
+            }
+            else
+            {
+                LOG(ERROR) << ("No return data!\n");
+            }
+        }
+    }
+
+    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
+    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+    LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+    FsmnVadUninit(online_hanlde);
+    FsmnVadUninit(vad_hanlde);
+    return 0;
+}
+

+ 9 - 4
funasr/runtime/onnxruntime/include/audio.h

@@ -33,8 +33,9 @@ class AudioFrame {
 
 class Audio {
   private:
-    float *speech_data;
-    int16_t *speech_buff;
+    float *speech_data=nullptr;
+    int16_t *speech_buff=nullptr;
+    char* speech_char=nullptr;
     int speech_len;
     int speech_align_len;
     int offset;
@@ -47,18 +48,22 @@ class Audio {
     Audio(int data_type, int size);
     ~Audio();
     void Disp();
-    bool LoadWav(const char* filename, int32_t* sampling_rate);
     void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
     bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
+    bool LoadWav(const char* filename, int32_t* sampling_rate);
+    bool LoadWav2Char(const char* filename, int32_t* sampling_rate);
     bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
     bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
+    bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
     int FetchChunck(float *&dout, int len);
     int Fetch(float *&dout, int &len, int &flag);
     void Padding();
     void Split(OfflineStream* offline_streamj);
-    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
+    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
     float GetTimeLen();
     int GetQueueSize() { return (int)frame_queue.size(); }
+    char* GetSpeechChar(){return speech_char;}
+    int GetSpeechLen(){return speech_len;}
 };
 
 } // namespace funasr

+ 4 - 9
funasr/runtime/onnxruntime/include/funasrruntime.h

@@ -46,12 +46,6 @@ typedef enum {
 	FUNASR_MODEL_PARAFORMER = 3,
 }FUNASR_MODEL_TYPE;
 
-typedef enum
-{
- FSMN_VAD_OFFLINE=0,
- FSMN_VAD_ONLINE = 1,
-}FSMN_VAD_MODE;
-
 typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
 	
 // ASR
@@ -68,11 +62,12 @@ _FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
 _FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT result);
 
 // VAD
-_FUNASRAPI FUNASR_HANDLE  	FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode=FSMN_VAD_OFFLINE);
+_FUNASRAPI FUNASR_HANDLE  	FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE  	FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle);
 // buffer
-_FUNASRAPI FUNASR_RESULT	FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT	FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000);
 // file, support wav & pcm
-_FUNASRAPI FUNASR_RESULT	FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT	FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000);
 
 _FUNASRAPI std::vector<std::vector<int>>*	FsmnVadGetResult(FUNASR_RESULT result,int n_index);
 _FUNASRAPI void			 	FsmnVadFreeResult(FUNASR_RESULT result);

+ 2 - 7
funasr/runtime/onnxruntime/include/vad-model.h

@@ -12,14 +12,9 @@ class VadModel {
     virtual ~VadModel(){};
     virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
     virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
-    virtual void ReadModel(const char* vad_model)=0;
-    virtual void LoadConfigFromYaml(const char* filename)=0;
-    virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
-                    std::vector<float> &waves)=0;
-    virtual void LoadCmvn(const char *filename)=0;
-    virtual void InitCache()=0;
 };
 
-VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode);
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
+VadModel *CreateVadModel(void* fsmnvad_handle);
 } // namespace funasr
 #endif

+ 2 - 15
funasr/runtime/onnxruntime/src/CMakeLists.txt

@@ -1,11 +1,8 @@
 
 file(GLOB files1 "*.cpp")
-file(GLOB files2 "*.cc")
+set(files ${files1})
 
-set(files ${files1} ${files2})
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-
-add_library(funasr ${files})
+add_library(funasr SHARED ${files})
 
 if(WIN32)
     set(EXTRA_LIBS pthread yaml-cpp csrc glog)
@@ -24,13 +21,3 @@ endif()
 
 include_directories(${CMAKE_SOURCE_DIR}/include)
 target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
-
-add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
-add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
-add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
-add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
-target_link_libraries(funasr-onnx-offline PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
-

+ 72 - 6
funasr/runtime/onnxruntime/src/audio.cpp

@@ -176,13 +176,13 @@ Audio::~Audio()
 {
     if (speech_buff != NULL) {
         free(speech_buff);
-        
     }
-
     if (speech_data != NULL) {
-        
         free(speech_data);
     }
+    if (speech_char != NULL) {
+        free(speech_char);
+    }
 }
 
 void Audio::Disp()
@@ -296,8 +296,47 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
         return false;
 }
 
-bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
+bool Audio::LoadWav2Char(const char *filename, int32_t* sampling_rate)
 {
+    WaveHeader header;
+    if (speech_char != NULL) {
+        free(speech_char);
+    }
+    offset = 0;
+    std::ifstream is(filename, std::ifstream::binary);
+    is.read(reinterpret_cast<char *>(&header), sizeof(header));
+    if(!is){
+        LOG(ERROR) << "Failed to read " << filename;
+        return false;
+    }
+    if (!header.Validate()) {
+        return false;
+    }
+    header.SeekToDataChunk(is);
+        if (!is) {
+            return false;
+    }
+    if (!header.Validate()) {
+        return false;
+    }
+    header.SeekToDataChunk(is);
+    if (!is) {
+        return false;
+    }
+    
+    *sampling_rate = header.sample_rate;
+    // header.subchunk2_size contains the number of bytes in the data.
+    // As we assume each sample contains two bytes, so it is divided by 2 here
+    speech_len = header.subchunk2_size / 2;
+    speech_char = (char *)malloc(header.subchunk2_size);
+    memset(speech_char, 0, header.subchunk2_size);
+    is.read(speech_char, header.subchunk2_size);
+
+    return true;
+}
+
+bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
+{ 
     WaveHeader header;
     if (speech_data != NULL) {
         free(speech_data);
@@ -441,6 +480,33 @@ bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
 
 }
 
+bool Audio::LoadPcmwav2Char(const char* filename, int32_t* sampling_rate)
+{
+    if (speech_char != NULL) {
+        free(speech_char);
+    }
+    offset = 0;
+
+    FILE* fp;
+    fp = fopen(filename, "rb");
+    if (fp == nullptr)
+	{
+        LOG(ERROR) << "Failed to read " << filename;
+        return false;
+	}
+    fseek(fp, 0, SEEK_END);
+    uint32_t n_file_len = ftell(fp);
+    fseek(fp, 0, SEEK_SET);
+
+    speech_len = (n_file_len) / 2;
+    speech_char = (char *)malloc(n_file_len);
+    memset(speech_char, 0, n_file_len);
+    fread(speech_char, sizeof(int16_t), n_file_len/2, fp);
+    fclose(fp);
+    
+    return true;
+}
+
 int Audio::FetchChunck(float *&dout, int len)
 {
     if (offset >= speech_align_len) {
@@ -541,7 +607,7 @@ void Audio::Split(OfflineStream* offline_stream)
 }
 
 
-void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
+void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
 {
     AudioFrame *frame;
 
@@ -552,7 +618,7 @@ void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
     frame = NULL;
 
     std::vector<float> pcm_data(speech_data, speech_data+sp_len);
-    vad_segments = vad_obj->Infer(pcm_data);
+    vad_segments = vad_obj->Infer(pcm_data, input_finished);
 }
 
 } // namespace funasr

+ 198 - 0
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp

@@ -0,0 +1,198 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
+#include <fstream>
+#include "precomp.h"
+
+namespace funasr {
+
+void FsmnVadOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+                               std::vector<float> &waves) {
+    knf::OnlineFbank fbank(fbank_opts_);
+    // cache merge
+    waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
+    int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+    // Send the audio after the last frame shift position to the cache
+    input_cache_.clear();
+    input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
+    if (frame_number == 0) {
+        return;
+    }
+    // Delete audio that haven't undergone fbank processing
+    waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
+
+    std::vector<float> buf(waves.size());
+    for (int32_t i = 0; i != waves.size(); ++i) {
+        buf[i] = waves[i] * 32768;
+    }
+    fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
+    // fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
+    int32_t frames = fbank.NumFramesReady();
+    for (int32_t i = 0; i != frames; ++i) {
+        const float *frame = fbank.GetFrame(i);
+        vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+        vad_feats.emplace_back(frame_vector);
+    }
+}
+
+void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &vad_feats,
+                                 vector<float> &waves, bool input_finished) {
+  FbankKaldi(sample_rate, vad_feats, waves);
+  // cache deal & online lfr,cmvn
+  if (vad_feats.size() > 0) {
+    if (!reserve_waveforms_.empty()) {
+      waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
+    }
+    if (lfr_splice_cache_.empty()) {
+      for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+        lfr_splice_cache_.emplace_back(vad_feats[0]);
+      }
+    }
+    if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
+      vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+      int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+      int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+      int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
+      int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
+      reserve_waveforms_.clear();
+      reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+                                waves.begin() + frame_from_waves * frame_shift_sample_length_);
+      int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+      waves.erase(waves.begin() + sample_length, waves.end());
+    } else {
+      reserve_waveforms_.clear();
+      reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+      lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
+    }
+  } else {
+    if (input_finished) {
+      if (!reserve_waveforms_.empty()) {
+        waves = reserve_waveforms_;
+      }
+      vad_feats = lfr_splice_cache_;
+      OnlineLfrCmvn(vad_feats, input_finished);
+    }
+  }
+  if(input_finished){
+      Reset();
+      ResetCache();
+  }
+}
+
+int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
+    vector<vector<float>> out_feats;
+    int T = vad_feats.size();
+    int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n);
+    int lfr_splice_frame_idxs = T_lrf;
+    vector<float> p;
+    for (int i = 0; i < T_lrf; i++) {
+        if (lfr_m <= T - i * lfr_n) {
+            for (int j = 0; j < lfr_m; j++) {
+                p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+            }
+            out_feats.emplace_back(p);
+            p.clear();
+        } else {
+            if (input_finished) {
+                int num_padding = lfr_m - (T - i * lfr_n);
+                for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) {
+                    p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+                }
+                for (int j = 0; j < num_padding; j++) {
+                    p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
+                }
+                out_feats.emplace_back(p);
+            } else {
+                lfr_splice_frame_idxs = i;
+                break;
+            }
+        }
+    }
+    lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
+    lfr_splice_cache_.clear();
+    lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
+
+    // Apply cmvn
+    for (auto &out_feat: out_feats) {
+        for (int j = 0; j < means_list_.size(); j++) {
+            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+        }
+    }
+    vad_feats = out_feats;
+    return lfr_splice_frame_idxs;
+}
+
+std::vector<std::vector<int>>
+FsmnVadOnline::Infer(std::vector<float> &waves, bool input_finished) {
+    std::vector<std::vector<float>> vad_feats;
+    std::vector<std::vector<float>> vad_probs;
+    ExtractFeats(vad_sample_rate_, vad_feats, waves, input_finished);
+    fsmnvad_handle_->Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
+
+    std::vector<std::vector<int>> vad_segments;
+    vad_segments = vad_scorer(vad_probs, waves, input_finished, true, vad_silence_duration_, vad_max_len_,
+                              vad_speech_noise_thres_, vad_sample_rate_);
+    return vad_segments;
+}
+
+void FsmnVadOnline::InitCache(){
+  std::vector<float> cache_feats(128 * 19 * 1, 0);
+  for (int i=0;i<4;i++){
+    in_cache_.emplace_back(cache_feats);
+  }
+};
+
+void FsmnVadOnline::Reset(){
+  in_cache_.clear();
+  InitCache();
+};
+
+void FsmnVadOnline::Test() {
+}
+
+void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
+                               Ort::Env &env,
+                               std::vector<const char *> &vad_in_names,
+                               std::vector<const char *> &vad_out_names,
+                               knf::FbankOptions &fbank_opts,
+                               std::vector<float> &means_list,
+                               std::vector<float> &vars_list,
+                               int vad_sample_rate,
+                               int vad_silence_duration,
+                               int vad_max_len,
+                               double vad_speech_noise_thres) {
+    vad_session_ = vad_session;
+    vad_in_names_ = vad_in_names;
+    vad_out_names_ = vad_out_names;
+    fbank_opts_ = fbank_opts;
+    means_list_ = means_list;
+    vars_list_ = vars_list;
+    vad_sample_rate_ = vad_sample_rate;
+    vad_silence_duration_ = vad_silence_duration;
+    vad_max_len_ = vad_max_len;
+    vad_speech_noise_thres_ = vad_speech_noise_thres;
+}
+
+FsmnVadOnline::~FsmnVadOnline() {
+}
+
+FsmnVadOnline::FsmnVadOnline(FsmnVad* fsmnvad_handle):fsmnvad_handle_(std::move(fsmnvad_handle)),session_options_{}{
+   InitCache();
+   InitOnline(fsmnvad_handle_->vad_session_,
+              fsmnvad_handle_->env_,
+              fsmnvad_handle_->vad_in_names_,
+              fsmnvad_handle_->vad_out_names_,
+              fsmnvad_handle_->fbank_opts_,
+              fsmnvad_handle_->means_list_,
+              fsmnvad_handle_->vars_list_,
+              fsmnvad_handle_->vad_sample_rate_,
+              fsmnvad_handle_->vad_silence_duration_,
+              fsmnvad_handle_->vad_max_len_,
+              fsmnvad_handle_->vad_speech_noise_thres_);
+}
+
+} // namespace funasr

+ 88 - 0
funasr/runtime/onnxruntime/src/fsmn-vad-online.h

@@ -0,0 +1,88 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
+#pragma once 
+#include "precomp.h"
+
+namespace funasr {
+class FsmnVadOnline : public VadModel {
+/**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ * https://arxiv.org/abs/1803.05030
+*/
+
+public:
+    explicit FsmnVadOnline(FsmnVad* fsmnvad_handle);
+    ~FsmnVadOnline();
+    void Test();
+    std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
+    void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
+    void Reset();
+
+private:
+    E2EVadModel vad_scorer = E2EVadModel();
+    // std::unique_ptr<FsmnVad> fsmnvad_handle_;
+    FsmnVad* fsmnvad_handle_ = nullptr;
+
+    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+                    std::vector<float> &waves);
+    int OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished);
+    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num){}
+    void InitCache();
+    void InitOnline(std::shared_ptr<Ort::Session> &vad_session,
+                    Ort::Env &env,
+                    std::vector<const char *> &vad_in_names,
+                    std::vector<const char *> &vad_out_names,
+                    knf::FbankOptions &fbank_opts,
+                    std::vector<float> &means_list,
+                    std::vector<float> &vars_list,
+                    int vad_sample_rate,
+                    int vad_silence_duration,
+                    int vad_max_len,
+                    double vad_speech_noise_thres);
+
+    static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
+        int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
+        if (frame_num >= 1 && sample_length >= frame_sample_length)
+            return frame_num;
+        else
+            return 0;
+    }
+    void ResetCache() {
+        reserve_waveforms_.clear();
+        input_cache_.clear();
+        lfr_splice_cache_.clear();
+    }
+
+    // from fsmnvad_handle_
+    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
+    Ort::Env env_;
+    Ort::SessionOptions session_options_;
+    std::vector<const char *> vad_in_names_;
+    std::vector<const char *> vad_out_names_;
+    knf::FbankOptions fbank_opts_;
+    std::vector<float> means_list_;
+    std::vector<float> vars_list_;
+
+    std::vector<std::vector<float>> in_cache_;
+    // The reserved waveforms by fbank
+    std::vector<float> reserve_waveforms_;
+    // waveforms reserved after last shift position
+    std::vector<float> input_cache_;
+    // lfr reserved cache
+    std::vector<std::vector<float>> lfr_splice_cache_;
+
+    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
+    int vad_silence_duration_ = VAD_SILENCE_DURATION;
+    int vad_max_len_ = VAD_MAX_LEN;
+    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
+    int lfr_m = VAD_LFR_M;
+    int lfr_n = VAD_LFR_N;
+    int frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
+    int frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
+};
+
+} // namespace funasr

+ 28 - 23
funasr/runtime/onnxruntime/src/fsmn-vad.cpp

@@ -37,14 +37,14 @@ void FsmnVad::LoadConfigFromYaml(const char* filename){
         this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
         this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
 
-        fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
-        fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
-        fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
-        fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
-        fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
-        fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
-        fbank_opts.energy_floor = 0;
-        fbank_opts.mel_opts.debug_mel = false;
+        fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
+        fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
+        fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
+        fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
+        fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
+        fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
+        fbank_opts_.energy_floor = 0;
+        fbank_opts_.mel_opts.debug_mel = false;
     }catch(exception const &e){
         LOG(ERROR) << "Error when load argument from vad config YAML.";
         exit(-1);
@@ -55,6 +55,7 @@ void FsmnVad::ReadModel(const char* vad_model) {
     try {
         vad_session_ = std::make_shared<Ort::Session>(
                 env_, vad_model, session_options_);
+        LOG(INFO) << "Successfully load model from " << vad_model;
     } catch (std::exception const &e) {
         LOG(ERROR) << "Error when load vad onnx model: " << e.what();
         exit(0);
@@ -109,7 +110,9 @@ void FsmnVad::GetInputOutputInfo(
 
 void FsmnVad::Forward(
         const std::vector<std::vector<float>> &chunk_feats,
-        std::vector<std::vector<float>> *out_prob) {
+        std::vector<std::vector<float>> *out_prob,
+        std::vector<std::vector<float>> *in_cache,
+        bool is_final) {
     Ort::MemoryInfo memory_info =
             Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
 
@@ -132,9 +135,9 @@ void FsmnVad::Forward(
     // 4 caches
     // cache node {batch,128,19,1}
     const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
-    for (int i = 0; i < in_cache_.size(); i++) {
+    for (int i = 0; i < in_cache->size(); i++) {
       vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
-              memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
+              memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
     }
   
     // 4. Onnx infer
@@ -162,15 +165,17 @@ void FsmnVad::Forward(
     }
   
     // get 4 caches outputs,each size is 128*19
-    // for (int i = 1; i < 5; i++) {
-    //   float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
-    //   memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
-    // }
+    if(!is_final){
+        for (int i = 1; i < 5; i++) {
+        float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+        memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
+        }
+    }
 }
 
 void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                          std::vector<float> &waves) {
-    knf::OnlineFbank fbank(fbank_opts);
+    knf::OnlineFbank fbank(fbank_opts_);
 
     std::vector<float> buf(waves.size());
     for (int32_t i = 0; i != waves.size(); ++i) {
@@ -180,7 +185,7 @@ void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad
     int32_t frames = fbank.NumFramesReady();
     for (int32_t i = 0; i != frames; ++i) {
         const float *frame = fbank.GetFrame(i);
-        std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
+        std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
         vad_feats.emplace_back(frame_vector);
     }
 }
@@ -205,7 +210,7 @@ void FsmnVad::LoadCmvn(const char *filename)
                 vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
                 if (means_lines[0] == "<LearnRateCoef>") {
                     for (int j = 3; j < means_lines.size() - 1; j++) {
-                        means_list.push_back(stof(means_lines[j]));
+                        means_list_.push_back(stof(means_lines[j]));
                     }
                     continue;
                 }
@@ -216,8 +221,8 @@ void FsmnVad::LoadCmvn(const char *filename)
                 vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
                 if (vars_lines[0] == "<LearnRateCoef>") {
                     for (int j = 3; j < vars_lines.size() - 1; j++) {
-                        // vars_list.push_back(stof(vars_lines[j])*scale);
-                        vars_list.push_back(stof(vars_lines[j]));
+                        // vars_list_.push_back(stof(vars_lines[j])*scale);
+                        vars_list_.push_back(stof(vars_lines[j]));
                     }
                     continue;
                 }
@@ -263,8 +268,8 @@ void FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
     }
     // Apply cmvn
     for (auto &out_feat: out_feats) {
-        for (int j = 0; j < means_list.size(); j++) {
-            out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
+        for (int j = 0; j < means_list_.size(); j++) {
+            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
         }
     }
     vad_feats = out_feats;
@@ -276,7 +281,7 @@ FsmnVad::Infer(std::vector<float> &waves, bool input_finished) {
     std::vector<std::vector<float>> vad_probs;
     FbankKaldi(vad_sample_rate_, vad_feats, waves);
     LfrCmvn(vad_feats);
-    Forward(vad_feats, &vad_probs);
+    Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
 
     E2EVadModel vad_scorer = E2EVadModel();
     std::vector<std::vector<int>> vad_segments;

+ 23 - 22
funasr/runtime/onnxruntime/src/fsmn-vad.h

@@ -22,7 +22,30 @@ public:
     void Test();
     void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
     std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true);
+    void Forward(
+        const std::vector<std::vector<float>> &chunk_feats,
+        std::vector<std::vector<float>> *out_prob,
+        std::vector<std::vector<float>> *in_cache,
+        bool is_final);
     void Reset();
+    
+    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
+    Ort::Env env_;
+    Ort::SessionOptions session_options_;
+    std::vector<const char *> vad_in_names_;
+    std::vector<const char *> vad_out_names_;
+    std::vector<std::vector<float>> in_cache_;
+    
+    knf::FbankOptions fbank_opts_;
+    std::vector<float> means_list_;
+    std::vector<float> vars_list_;
+
+    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
+    int vad_silence_duration_ = VAD_SILENCE_DURATION;
+    int vad_max_len_ = VAD_MAX_LEN;
+    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
+    int lfr_m = VAD_LFR_M;
+    int lfr_n = VAD_LFR_N;
 
 private:
 
@@ -37,31 +60,9 @@ private:
                     std::vector<float> &waves);
 
     void LfrCmvn(std::vector<std::vector<float>> &vad_feats);
-
-    void Forward(
-            const std::vector<std::vector<float>> &chunk_feats,
-            std::vector<std::vector<float>> *out_prob);
-
     void LoadCmvn(const char *filename);
     void InitCache();
 
-    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
-    Ort::Env env_;
-    Ort::SessionOptions session_options_;
-    std::vector<const char *> vad_in_names_;
-    std::vector<const char *> vad_out_names_;
-    std::vector<std::vector<float>> in_cache_;
-    
-    knf::FbankOptions fbank_opts;
-    std::vector<float> means_list;
-    std::vector<float> vars_list;
-
-    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
-    int vad_silence_duration_ = VAD_SILENCE_DURATION;
-    int vad_max_len_ = VAD_MAX_LEN;
-    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
-    int lfr_m = VAD_LFR_M;
-    int lfr_n = VAD_LFR_N;
 };
 
 } // namespace funasr

+ 12 - 6
funasr/runtime/onnxruntime/src/funasrruntime.cpp

@@ -11,9 +11,15 @@ extern "C" {
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_HANDLE  FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode)
+	_FUNASRAPI FUNASR_HANDLE  FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num)
 	{
-		funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num, mode);
+		funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num);
+		return mm;
+	}
+
+	_FUNASRAPI FUNASR_HANDLE  FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle)
+	{
+		funasr::VadModel* mm = funasr::CreateVadModel(fsmnvad_handle);
 		return mm;
 	}
 
@@ -96,7 +102,7 @@ extern "C" {
 	}
 
 	// APIs for VAD Infer
-	_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+	_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate)
 	{
 		funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
 		if (!vad_obj)
@@ -110,13 +116,13 @@ extern "C" {
 		p_result->snippet_time = audio.GetTimeLen();
 		
 		vector<std::vector<int>> vad_segments;
-		audio.Split(vad_obj, vad_segments);
+		audio.Split(vad_obj, vad_segments, input_finished);
 		p_result->segments = new vector<std::vector<int>>(vad_segments);
 
 		return p_result;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+	_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate)
 	{
 		funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
 		if (!vad_obj)
@@ -139,7 +145,7 @@ extern "C" {
 		p_result->snippet_time = audio.GetTimeLen();
 		
 		vector<std::vector<int>> vad_segments;
-		audio.Split(vad_obj, vad_segments);
+		audio.Split(vad_obj, vad_segments, true);
 		p_result->segments = new vector<std::vector<int>>(vad_segments);
 
 		return p_result;

+ 0 - 137
funasr/runtime/onnxruntime/src/online-feature.cpp

@@ -1,137 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- * MIT License  (https://opensource.org/licenses/MIT)
- * Contributed by zhuzizyf(China Telecom).
-*/
-
-#include "online-feature.h"
-#include <utility>
-
-namespace funasr {
-OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
-                             std::vector<std::vector<float>> cmvns)
-  : sample_rate_(sample_rate),
-    fbank_opts_(std::move(fbank_opts)),
-    lfr_m_(lfr_m),
-    lfr_n_(lfr_n),
-    cmvns_(std::move(cmvns)) {
-  frame_sample_length_ = sample_rate_ / 1000 * 25;;
-  frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
-}
-
-void OnlineFeature::ExtractFeats(vector<std::vector<float>> &vad_feats,
-                                 vector<float> waves, bool input_finished) {
-  input_finished_ = input_finished;
-  OnlineFbank(vad_feats, waves);
-  // cache deal & online lfr,cmvn
-  if (vad_feats.size() > 0) {
-    if (!reserve_waveforms_.empty()) {
-      waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
-    }
-    if (lfr_splice_cache_.empty()) {
-      for (int i = 0; i < (lfr_m_ - 1) / 2; i++) {
-        lfr_splice_cache_.emplace_back(vad_feats[0]);
-      }
-    }
-    if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m_) {
-      vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
-      int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
-      int minus_frame = reserve_waveforms_.empty() ? (lfr_m_ - 1) / 2 : 0;
-      int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats);
-      int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
-      reserve_waveforms_.clear();
-      reserve_waveforms_.insert(reserve_waveforms_.begin(),
-                                waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
-                                waves.begin() + frame_from_waves * frame_shift_sample_length_);
-      int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
-      waves.erase(waves.begin() + sample_length, waves.end());
-    } else {
-      reserve_waveforms_.clear();
-      reserve_waveforms_.insert(reserve_waveforms_.begin(),
-                                waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
-      lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
-    }
-
-  } else {
-    if (input_finished_) {
-      if (!reserve_waveforms_.empty()) {
-        waves = reserve_waveforms_;
-      }
-      vad_feats = lfr_splice_cache_;
-      OnlineLfrCmvn(vad_feats);
-      ResetCache();
-    }
-  }
-
-}
-
-int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
-  vector<vector<float>> out_feats;
-  int T = vad_feats.size();
-  int T_lrf = ceil((T - (lfr_m_ - 1) / 2) / lfr_n_);
-  int lfr_splice_frame_idxs = T_lrf;
-  vector<float> p;
-  for (int i = 0; i < T_lrf; i++) {
-    if (lfr_m_ <= T - i * lfr_n_) {
-      for (int j = 0; j < lfr_m_; j++) {
-        p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end());
-      }
-      out_feats.emplace_back(p);
-      p.clear();
-    } else {
-      if (input_finished_) {
-        int num_padding = lfr_m_ - (T - i * lfr_n_);
-        for (int j = 0; j < (vad_feats.size() - i * lfr_n_); j++) {
-          p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end());
-        }
-        for (int j = 0; j < num_padding; j++) {
-          p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
-        }
-        out_feats.emplace_back(p);
-      } else {
-        lfr_splice_frame_idxs = i;
-        break;
-      }
-    }
-  }
-  lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n_);
-  lfr_splice_cache_.clear();
-  lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
-
-  // Apply cmvn
-  for (auto &out_feat: out_feats) {
-    for (int j = 0; j < cmvns_[0].size(); j++) {
-      out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j];
-    }
-  }
-  vad_feats = out_feats;
-  return lfr_splice_frame_idxs;
-}
-
-void OnlineFeature::OnlineFbank(vector<std::vector<float>> &vad_feats,
-                                vector<float> &waves) {
-
-  knf::OnlineFbank fbank(fbank_opts_);
-  // cache merge
-  waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
-  int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
-  // Send the audio after the last frame shift position to the cache
-  input_cache_.clear();
-  input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
-  if (frame_number == 0) {
-    return;
-  }
-  // Delete audio that haven't undergone fbank processing
-  waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
-
-  fbank.AcceptWaveform(sample_rate_, &waves[0], waves.size());
-  int32_t frames = fbank.NumFramesReady();
-  for (int32_t i = 0; i != frames; ++i) {
-    const float *frame = fbank.GetFrame(i);
-    vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
-    vad_feats.emplace_back(frame_vector);
-  }
-
-}
-
-} // namespace funasr

+ 0 - 58
funasr/runtime/onnxruntime/src/online-feature.h

@@ -1,58 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- * MIT License  (https://opensource.org/licenses/MIT)
- * Contributed by zhuzizyf(China Telecom).
-*/
-#pragma once 
-#include <vector>
-#include "precomp.h"
-
-using namespace std;
-namespace funasr {
-class OnlineFeature {
-
-public:
-  OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_,
-                std::vector<std::vector<float>> cmvns_);
-
-  void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
-
-private:
-  void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
-  int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
-  
-  static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
-    int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
-    if (frame_num >= 1 && sample_length >= frame_sample_length)
-      return frame_num;
-    else
-      return 0;
-  }
-
-  void ResetCache() {
-    reserve_waveforms_.clear();
-    input_cache_.clear();
-    lfr_splice_cache_.clear();
-    input_finished_ = false;
-
-  }
-
-  knf::FbankOptions fbank_opts_;
-  // The reserved waveforms by fbank
-  std::vector<float> reserve_waveforms_;
-  // waveforms reserved after last shift position
-  std::vector<float> input_cache_;
-  // lfr reserved cache
-  std::vector<std::vector<float>> lfr_splice_cache_;
-  std::vector<std::vector<float>> cmvns_;
-
-  int sample_rate_ = 16000;
-  int frame_sample_length_ = sample_rate_ / 1000 * 25;;
-  int frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
-  int lfr_m_;
-  int lfr_n_;
-  bool input_finished_ = false;
-
-};
-
-} // namespace funasr

+ 2 - 2
funasr/runtime/onnxruntime/src/paraformer.h

@@ -18,7 +18,7 @@ namespace funasr {
         //std::unique_ptr<knf::OnlineFbank> fbank_;
         knf::FbankOptions fbank_opts;
 
-        Vocab* vocab;
+        Vocab* vocab = nullptr;
         vector<float> means_list;
         vector<float> vars_list;
         const float scale = 22.6274169979695;
@@ -30,7 +30,7 @@ namespace funasr {
         void ApplyCmvn(vector<float> *v);
         string GreedySearch( float* in, int n_len, int64_t token_nums);
 
-        std::shared_ptr<Ort::Session> m_session;
+        std::shared_ptr<Ort::Session> m_session = nullptr;
         Ort::Env env_;
         Ort::SessionOptions session_options;
 

+ 2 - 1
funasr/runtime/onnxruntime/src/precomp.h

@@ -36,8 +36,9 @@ using namespace std;
 #include "offline-stream.h"
 #include "tokenizer.h"
 #include "ct-transformer.h"
-#include "fsmn-vad.h"
 #include "e2e-vad.h"
+#include "fsmn-vad.h"
+#include "fsmn-vad-online.h"
 #include "vocab.h"
 #include "audio.h"
 #include "tensor.h"

+ 9 - 6
funasr/runtime/onnxruntime/src/vad-model.cpp

@@ -1,14 +1,10 @@
 #include "precomp.h"
 
 namespace funasr {
-VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode)
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
 {
     VadModel *mm;
-    if(mode == FSMN_VAD_OFFLINE){
-        mm = new FsmnVad();
-    }else{
-        LOG(ERROR)<<"Online fsmn vad not imp!";
-    }
+    mm = new FsmnVad();
 
     string vad_model_path;
     string vad_cmvn_path;
@@ -25,4 +21,11 @@ VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thr
     return mm;
 }
 
+VadModel *CreateVadModel(void* fsmnvad_handle)
+{
+    VadModel *mm;
+    mm = new FsmnVadOnline((FsmnVad*)fsmnvad_handle);
+    return mm;
+}
+
 } // namespace funasr