雾聪 2 tahun lalu
induk
melakukan
fcc497dd73

+ 193 - 0
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp

@@ -0,0 +1,193 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "audio.h"
+
+using namespace std;
+
+/**
+ * Tells whether `filename` ends with the extension `target`.
+ *
+ * The extension is everything after the last '.' in `filename`; `target` is
+ * given without the dot ("wav", not ".wav") and is compared case-sensitively.
+ * A name that contains no '.' never matches.
+ */
+bool is_target_file(const std::string& filename, const std::string& target) {
+    const std::size_t dot = filename.rfind('.');
+    if (dot == std::string::npos) {
+        return false;
+    }
+    // Compare in place rather than allocating a substring for the extension.
+    return filename.compare(dot + 1, std::string::npos, target) == 0;
+}
+
+/**
+ * Records a command-line option in `model_path` under `key`.
+ *
+ * Options the user did not supply are ignored, so the map only ever holds
+ * explicitly set values. `insert` is used, so an entry that already exists
+ * under `key` is left as it is. The option value is echoed to the INFO log.
+ */
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+    if (!value_arg.isSet()){
+        return;
+    }
+    model_path.insert({key, value_arg.getValue()});
+    LOG(INFO)<< key << " : " << value_arg.getValue();
+}
+
+/**
+ * Command-line driver for the 2pass recogniser.
+ *
+ * Steps:
+ *   1. parse the options and collect the ones that were set into `model_path`;
+ *   2. build the 2pass handle with FunTpassInit and log how long that took;
+ *   3. expand the wav-path option into a list of files (a ".scp" file is read
+ *      as "<id> <path>" lines, anything else is treated as a single audio file);
+ *   4. feed every file to FunTpassInferBuffer in chunks of 1600*2 bytes,
+ *      logging the online and the 2pass text as they arrive;
+ *   5. log the accumulated audio length, inference time and real-time factor.
+ *
+ * Returns 0 on success; failures terminate with a non-zero status.
+ */
+int main(int argc, char** argv)
+{
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
+    TCLAP::ValueArg<std::string>    offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+    TCLAP::ValueArg<std::string>    online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
+    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
+    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
+
+    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
+
+    cmd.add(offline_model_dir);
+    cmd.add(online_model_dir);
+    cmd.add(quantize);
+    cmd.add(vad_dir);
+    cmd.add(vad_quant);
+    cmd.add(punc_dir);
+    cmd.add(punc_quant);
+    cmd.add(wav_path);
+    cmd.parse(argc, argv);
+
+    // Only options that were explicitly set end up in the map.
+    std::map<std::string, std::string> model_path;
+    GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
+    GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(vad_dir, VAD_DIR, model_path);
+    GetValue(vad_quant, VAD_QUANT, model_path);
+    GetValue(punc_dir, PUNC_DIR, model_path);
+    GetValue(punc_quant, PUNC_QUANT, model_path);
+    GetValue(wav_path, WAV_PATH, model_path);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    int thread_num = 1;
+    FUNASR_HANDLE tpass_hanlde=FunTpassInit(model_path, thread_num);
+
+    if (!tpass_hanlde)
+    {
+        LOG(ERROR) << "FunTpassInit init failed";
+        exit(-1);
+    }
+
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+    // read wav_path
+    vector<string> wav_list;
+    vector<string> wav_ids;
+    string default_id = "wav_default_id";
+    string wav_path_ = model_path.at(WAV_PATH);
+
+    if(is_target_file(wav_path_, "scp")){
+        ifstream in(wav_path_);
+        if (!in.is_open()) {
+            // Log the path that was actually opened. Only WAV_PATH is stored in
+            // model_path, so looking up any other key with at() would throw here.
+            LOG(ERROR) << "Failed to open file: " << wav_path_ ;
+            FunTpassUninit(tpass_hanlde);
+            return -1;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            istringstream iss(line);
+            string column1, column2;
+            if (!(iss >> column1 >> column2)) {
+                // Blank lines are ignored silently; a line with an id but no
+                // path is reported and skipped instead of queuing an empty path.
+                if (!column1.empty()) {
+                    LOG(WARNING) << "Skip malformed scp line: " << line;
+                }
+                continue;
+            }
+            wav_list.emplace_back(column2);
+            wav_ids.emplace_back(column1);
+        }
+        in.close();
+    }else{
+        wav_list.emplace_back(wav_path_);
+        wav_ids.emplace_back(default_id);
+    }
+
+    // init online features
+    FunTpassOnlineInit(tpass_hanlde);
+    float snippet_time = 0.0f;
+    long taking_micros = 0;
+    for (size_t i = 0; i < wav_list.size(); i++) {
+        auto& wav_file = wav_list[i];
+        auto& wav_id = wav_ids[i];
+
+        // Pick the loader from the file extension; everything that is neither
+        // ".wav" nor ".pcm" goes through the ffmpeg loader.
+        int32_t sampling_rate_ = -1;
+        funasr::Audio audio(1);
+        if(is_target_file(wav_file, "wav")){
+            if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }else if(is_target_file(wav_file, "pcm")){
+            if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }else{
+            if (!audio.FfmpegLoad(wav_file.c_str(), true)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }
+        char* speech_buff = audio.GetSpeechChar();
+        // GetSpeechLen() is doubled to obtain the byte length of the char buffer.
+        int buff_len = audio.GetSpeechLen()*2;
+
+        int step = 1600*2;
+        bool is_final = false;
+
+        string online_res="";
+        string tpass_res="";
+        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+            // The last chunk takes whatever is left and is flagged as final.
+            if (sample_offset + step >= buff_len - 1) {
+                step = buff_len - sample_offset;
+                is_final = true;
+            } else {
+                is_final = false;
+            }
+            gettimeofday(&start, NULL);
+            FUNASR_RESULT result = FunTpassInferBuffer(tpass_hanlde, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
+            gettimeofday(&end, NULL);
+            seconds = (end.tv_sec - start.tv_sec);
+            taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+            if (result)
+            {
+                string online_msg = FunASRGetResult(result, 0);
+                online_res += online_msg;
+                if(online_msg != ""){
+                    LOG(INFO)<< wav_id <<" : "<<online_msg;
+                }
+                string tpass_msg = FunASRGetTpassResult(result, 0);
+                tpass_res += tpass_msg;
+                if(tpass_msg != ""){
+                    LOG(INFO)<< wav_id <<" 2pass results : "<<tpass_msg;
+                }
+                snippet_time += FunASRGetRetSnippetTime(result);
+                FunASRFreeResult(result);
+            }
+        }
+        LOG(INFO)<<"Final online results " << wav_id <<" : "<<online_res;
+        LOG(INFO)<<"Final 2pass  results " << wav_id <<" : "<<tpass_res;
+    }
+
+    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
+    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+    // The real-time factor is undefined when no audio was recognised.
+    if (snippet_time > 0) {
+        LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+    }
+    FunTpassUninit(tpass_hanlde);
+    return 0;
+}
+

+ 1 - 1
funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp

@@ -63,7 +63,7 @@ int main(int argc, char *argv[])
     struct timeval start, end;
     struct timeval start, end;
     gettimeofday(&start, NULL);
     gettimeofday(&start, NULL);
     int thread_num = 1;
     int thread_num = 1;
-    FUNASR_HANDLE asr_handle=FunASRInit(model_path, thread_num, 1);
+    FUNASR_HANDLE asr_handle=FunASRInit(model_path, thread_num, ASR_ONLINE);
 
 
     if (!asr_handle)
     if (!asr_handle)
     {
     {

+ 1 - 2
funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp

@@ -209,8 +209,7 @@ int main(int argc, char *argv[])
 
 
     struct timeval start, end;
     struct timeval start, end;
     gettimeofday(&start, NULL);
     gettimeofday(&start, NULL);
-    int online_mode = 1;
-    FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1, online_mode);
+    FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1, ASR_ONLINE);
 
 
     if (!asr_handle)
     if (!asr_handle)
     {
     {

+ 1 - 1
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp

@@ -159,7 +159,7 @@ int main(int argc, char *argv[])
         char* speech_buff = audio.GetSpeechChar();
         char* speech_buff = audio.GetSpeechChar();
         int buff_len = audio.GetSpeechLen()*2;
         int buff_len = audio.GetSpeechLen()*2;
 
 
-        int step = 3200;
+        int step = 1600*2;
         bool is_final = false;
         bool is_final = false;
 
 
         for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
         for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {

+ 29 - 3
funasr/runtime/onnxruntime/include/audio.h

@@ -5,6 +5,7 @@
 #include <stdint.h>
 #include <stdint.h>
 #include "vad-model.h"
 #include "vad-model.h"
 #include "offline-stream.h"
 #include "offline-stream.h"
+#include "com-define.h"
 
 
 #ifndef WAV_HEADER_SIZE
 #ifndef WAV_HEADER_SIZE
 #define WAV_HEADER_SIZE 44
 #define WAV_HEADER_SIZE 44
@@ -17,11 +18,13 @@ class AudioFrame {
   private:
   private:
     int start;
     int start;
     int end;
     int end;
-    int len;
+
 
 
   public:
   public:
     AudioFrame();
     AudioFrame();
     AudioFrame(int len);
     AudioFrame(int len);
+    AudioFrame(const AudioFrame &other);
+    AudioFrame(int start, int end, bool is_final);
 
 
     ~AudioFrame();
     ~AudioFrame();
     int SetStart(int val);
     int SetStart(int val);
@@ -29,6 +32,10 @@ class AudioFrame {
     int GetStart();
     int GetStart();
     int GetLen();
     int GetLen();
     int Disp();
     int Disp();
+    // 2pass
+    bool is_final = false;
+    float* data = nullptr;
+    int len;
 };
 };
 
 
 class Audio {
 class Audio {
@@ -38,10 +45,11 @@ class Audio {
     char* speech_char=nullptr;
     char* speech_char=nullptr;
     int speech_len;
     int speech_len;
     int speech_align_len;
     int speech_align_len;
-    int offset;
     float align_size;
     float align_size;
     int data_type;
     int data_type;
     queue<AudioFrame *> frame_queue;
     queue<AudioFrame *> frame_queue;
+    queue<AudioFrame *> asr_online_queue;
+    queue<AudioFrame *> asr_offline_queue;
 
 
   public:
   public:
     Audio(int data_type);
     Audio(int data_type);
@@ -58,15 +66,33 @@ class Audio {
     bool LoadOthers2Char(const char* filename);
     bool LoadOthers2Char(const char* filename);
     bool FfmpegLoad(const char *filename, bool copy2char=false);
     bool FfmpegLoad(const char *filename, bool copy2char=false);
     bool FfmpegLoad(const char* buf, int n_file_len);
     bool FfmpegLoad(const char* buf, int n_file_len);
-    int FetchChunck(float *&dout, int len);
+    int FetchChunck(AudioFrame *&frame);
+    int FetchTpass(AudioFrame *&frame);
     int Fetch(float *&dout, int &len, int &flag);
     int Fetch(float *&dout, int &len, int &flag);
     void Padding();
     void Padding();
     void Split(OfflineStream* offline_streamj);
     void Split(OfflineStream* offline_streamj);
     void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
     void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
+    void Split(VadModel* vad_obj, bool input_finished=true);
     float GetTimeLen();
     float GetTimeLen();
     int GetQueueSize() { return (int)frame_queue.size(); }
     int GetQueueSize() { return (int)frame_queue.size(); }
     char* GetSpeechChar(){return speech_char;}
     char* GetSpeechChar(){return speech_char;}
     int GetSpeechLen(){return speech_len;}
     int GetSpeechLen(){return speech_len;}
+
+    // 2pass
+    vector<float> all_samples;
+    int offset = 0;
+    int speech_start=-1, speech_end=0;
+    int speech_offline_start=-1;
+
+    int seg_sample = MODEL_SAMPLE_RATE/1000;
+    bool LoadPcmwavOnline(const char* buf, int n_file_len, int32_t* sampling_rate);
+    // Returns the 2pass streaming state to its initial values: the cached
+    // samples are dropped and the offset and speech markers are reset.
+    void ResetIndex(){
+      all_samples.clear();
+      offset = 0;
+      speech_offline_start=-1;
+      speech_end=0;
+      speech_start=-1;
+    }
 };
 };
 
 
 } // namespace funasr
 } // namespace funasr

+ 6 - 0
funasr/runtime/onnxruntime/include/com-define.h

@@ -13,6 +13,8 @@ namespace funasr {
 
 
 // parser option
 // parser option
 #define MODEL_DIR "model-dir"
 #define MODEL_DIR "model-dir"
+#define OFFLINE_MODEL_DIR "offline-model-dir"
+#define ONLINE_MODEL_DIR "online-model-dir"
 #define VAD_DIR "vad-dir"
 #define VAD_DIR "vad-dir"
 #define PUNC_DIR "punc-dir"
 #define PUNC_DIR "punc-dir"
 #define QUANTIZE "quantize"
 #define QUANTIZE "quantize"
@@ -77,6 +79,10 @@ namespace funasr {
 #define PARA_LFR_N 6
 #define PARA_LFR_N 6
 #endif
 #endif
 
 
+#ifndef ONLINE_STEP
+#define ONLINE_STEP 9600
+#endif
+
 // punc
 // punc
 #define UNK_CHAR "<unk>"
 #define UNK_CHAR "<unk>"
 #define TOKEN_LEN     20
 #define TOKEN_LEN     20

+ 15 - 1
funasr/runtime/onnxruntime/include/funasrruntime.h

@@ -46,6 +46,12 @@ typedef enum {
 	FUNASR_MODEL_PARAFORMER = 3,
 	FUNASR_MODEL_PARAFORMER = 3,
 }FUNASR_MODEL_TYPE;
 }FUNASR_MODEL_TYPE;
 
 
+// Kind of ASR model requested through the `type` argument of FunASRInit / CreateModel.
+typedef enum {
+	ASR_OFFLINE=0,  // default value of the `type` argument
+	ASR_ONLINE=1,   // passed by the online demo binaries in place of the former literal 1
+	ASR_TWO_PASS=2, // presumably the combined online + offline (2pass) model -- TODO confirm against CreateModel
+}ASR_TYPE;
+
 typedef enum {
 typedef enum {
 	PUNC_OFFLINE=0,
 	PUNC_OFFLINE=0,
 	PUNC_ONLINE=1,
 	PUNC_ONLINE=1,
@@ -54,7 +60,7 @@ typedef enum {
 typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
 typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
 
 
 // ASR
 // ASR
-_FUNASRAPI FUNASR_HANDLE  	FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, int mode=0);
+_FUNASRAPI FUNASR_HANDLE  	FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type=ASR_OFFLINE);
 _FUNASRAPI FUNASR_HANDLE  	FunASROnlineInit(FUNASR_HANDLE asr_handle);
 _FUNASRAPI FUNASR_HANDLE  	FunASROnlineInit(FUNASR_HANDLE asr_handle);
 // buffer
 // buffer
 _FUNASRAPI FUNASR_RESULT	FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
 _FUNASRAPI FUNASR_RESULT	FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
@@ -62,6 +68,7 @@ _FUNASRAPI FUNASR_RESULT	FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_
 _FUNASRAPI FUNASR_RESULT	FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
 _FUNASRAPI FUNASR_RESULT	FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
 
 
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
+_FUNASRAPI const char*	FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
 _FUNASRAPI const int	FunASRGetRetNumber(FUNASR_RESULT result);
 _FUNASRAPI const int	FunASRGetRetNumber(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
 _FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
@@ -95,6 +102,13 @@ _FUNASRAPI FUNASR_RESULT	FunOfflineInferBuffer(FUNASR_HANDLE handle, const char*
 _FUNASRAPI FUNASR_RESULT	FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
 _FUNASRAPI FUNASR_RESULT	FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
 _FUNASRAPI void				FunOfflineUninit(FUNASR_HANDLE handle);
 _FUNASRAPI void				FunOfflineUninit(FUNASR_HANDLE handle);
 
 
+//2passStream
+_FUNASRAPI FUNASR_HANDLE  	FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI void  	        FunTpassOnlineInit(FUNASR_HANDLE tpass_handle);
+// buffer
+_FUNASRAPI FUNASR_RESULT	FunTpassInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
+_FUNASRAPI void				FunTpassUninit(FUNASR_HANDLE handle);
+
 #ifdef __cplusplus 
 #ifdef __cplusplus 
 
 
 }
 }

+ 3 - 1
funasr/runtime/onnxruntime/include/model.h

@@ -4,6 +4,7 @@
 
 
 #include <string>
 #include <string>
 #include <map>
 #include <map>
+#include "funasrruntime.h"
 namespace funasr {
 namespace funasr {
 class Model {
 class Model {
   public:
   public:
@@ -11,11 +12,12 @@ class Model {
     virtual void Reset() = 0;
     virtual void Reset() = 0;
     virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
     virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
     virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
     virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
+    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
     virtual std::string Forward(float *din, int len, bool input_finished){return "";};
     virtual std::string Forward(float *din, int len, bool input_finished){return "";};
     virtual std::string Rescoring() = 0;
     virtual std::string Rescoring() = 0;
 };
 };
 
 
-Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, int mode=0);
+Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
 Model *CreateModel(void* asr_handle);
 Model *CreateModel(void* asr_handle);
 
 
 } // namespace funasr
 } // namespace funasr

+ 3 - 3
funasr/runtime/onnxruntime/include/offline-stream.h

@@ -14,9 +14,9 @@ class OfflineStream {
     OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
     OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
     ~OfflineStream(){};
     ~OfflineStream(){};
 
 
-    std::unique_ptr<VadModel> vad_handle;
-    std::unique_ptr<Model> asr_handle;
-    std::unique_ptr<PuncModel> punc_handle;
+    std::unique_ptr<VadModel> vad_handle= nullptr;
+    std::unique_ptr<Model> asr_handle= nullptr;
+    std::unique_ptr<PuncModel> punc_handle= nullptr;
     bool UseVad(){return use_vad;};
     bool UseVad(){return use_vad;};
     bool UsePunc(){return use_punc;}; 
     bool UsePunc(){return use_punc;}; 
     
     

+ 34 - 0
funasr/runtime/onnxruntime/include/tpass-stream.h

@@ -0,0 +1,34 @@
+#ifndef TPASS_STREAM_H
+#define TPASS_STREAM_H
+
+#include <memory>
+#include <string>
+#include <map>
+#include "model.h"
+#include "punc-model.h"
+#include "vad-model.h"
+
+namespace funasr {
+/**
+ * Owns the model handles used by the 2pass pipeline.
+ *
+ * All handles start out empty. Which of them are created, and how the
+ * use_vad / use_punc flags are set, is decided by the constructor, whose
+ * definition is not part of this header.
+ */
+class TpassStream {
+  public:
+    TpassStream(std::map<std::string, std::string>& model_path, int thread_num);
+    ~TpassStream(){}
+
+    std::unique_ptr<VadModel> vad_handle = nullptr;
+    std::unique_ptr<VadModel> vad_online_handle = nullptr;
+    std::unique_ptr<Model> asr_handle = nullptr;
+    std::unique_ptr<Model> asr_online_handle = nullptr;
+    std::unique_ptr<PuncModel> punc_online_handle = nullptr;
+
+    // Read-only views of the private flags below.
+    bool UseVad() const {return use_vad;}
+    bool UsePunc() const {return use_punc;}
+
+  private:
+    bool use_vad=false;
+    bool use_punc=false;
+};
+
+TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num=1);
+void CreateTpassOnlineStream(void* tpass_stream);
+} // namespace funasr
+#endif

+ 254 - 25
funasr/runtime/onnxruntime/src/audio.cpp

@@ -132,40 +132,54 @@ class AudioWindow {
     };
     };
 };
 };
 
 
-AudioFrame::AudioFrame(){};
+AudioFrame::AudioFrame(){}
 AudioFrame::AudioFrame(int len) : len(len)
 AudioFrame::AudioFrame(int len) : len(len)
 {
 {
     start = 0;
     start = 0;
-};
-AudioFrame::~AudioFrame(){};
+}
+// Copies the bookkeeping fields only. `data` is deliberately not duplicated:
+// the new frame keeps the nullptr from its in-class initializer, so two
+// frames never end up freeing the same buffer.
+AudioFrame::AudioFrame(const AudioFrame &other)
+    : start(other.start), end(other.end), is_final(other.is_final), len(other.len)
+{
+}
+// Builds a frame covering [start, end); its length is derived as end - start.
+AudioFrame::AudioFrame(int start, int end, bool is_final)
+    : start(start), end(end), is_final(is_final), len(end - start)
+{
+}
+// Releases the sample buffer attached to the frame, if any.
+AudioFrame::~AudioFrame(){
+    // free() accepts a null pointer and does nothing, so no guard is needed.
+    free(data);
+}
 int AudioFrame::SetStart(int val)
 int AudioFrame::SetStart(int val)
 {
 {
     start = val < 0 ? 0 : val;
     start = val < 0 ? 0 : val;
     return start;
     return start;
-};
+}
 
 
 int AudioFrame::SetEnd(int val)
 int AudioFrame::SetEnd(int val)
 {
 {
     end = val;
     end = val;
     len = end - start;
     len = end - start;
     return end;
     return end;
-};
+}
 
 
 int AudioFrame::GetStart()
 int AudioFrame::GetStart()
 {
 {
     return start;
     return start;
-};
+}
 
 
 int AudioFrame::GetLen()
 int AudioFrame::GetLen()
 {
 {
     return len;
     return len;
-};
+}
 
 
 int AudioFrame::Disp()
 int AudioFrame::Disp()
 {
 {
     LOG(ERROR) << "Not imp!!!!";
     LOG(ERROR) << "Not imp!!!!";
     return 0;
     return 0;
-};
+}
 
 
 Audio::Audio(int data_type) : data_type(data_type)
 Audio::Audio(int data_type) : data_type(data_type)
 {
 {
@@ -771,6 +785,55 @@ bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
         return false;
         return false;
 }
 }
 
 
+/**
+ * Loads one chunk of 16-bit PCM for the 2pass pipeline.
+ *
+ * The chunk replaces the previous contents of speech_buff / speech_data, is
+ * converted to float (divided by 32768 when data_type == 1), resampled when
+ * *sampling_rate differs from MODEL_SAMPLE_RATE, appended to all_samples, and
+ * announced to Split() by pushing a frame of speech_len onto frame_queue.
+ *
+ * @param buf            raw PCM bytes, read as int16_t samples
+ * @param n_buf_len      length of buf in bytes
+ * @param sampling_rate  sample rate of buf; must not be null
+ * @return true when the chunk was queued, false on invalid arguments or
+ *         allocation failure (nothing is queued in that case)
+ */
+bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_rate)
+{
+    if (buf == NULL || sampling_rate == NULL || n_buf_len < 0) {
+        return false;
+    }
+
+    // Reset each pointer after freeing it so a failure below, or a later
+    // call, can never free the same block twice.
+    if (speech_data != NULL) {
+        free(speech_data);
+        speech_data = NULL;
+    }
+    if (speech_buff != NULL) {
+        free(speech_buff);
+        speech_buff = NULL;
+    }
+    if (speech_char != NULL) {
+        free(speech_char);
+        speech_char = NULL;
+    }
+
+    speech_len = n_buf_len / 2;
+    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
+    if (speech_buff == NULL) {
+        speech_len = 0;
+        return false;
+    }
+    memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));
+
+    speech_data = (float*)malloc(sizeof(float) * speech_len);
+    if (speech_data == NULL) {
+        free(speech_buff);
+        speech_buff = NULL;
+        speech_len = 0;
+        return false;
+    }
+
+    float scale = 1;
+    if (data_type == 1) {
+        scale = 32768;
+    }
+
+    for (int32_t i = 0; i != speech_len; ++i) {
+        speech_data[i] = (float)speech_buff[i] / scale;
+    }
+
+    //resample
+    if(*sampling_rate != MODEL_SAMPLE_RATE){
+        WavResample(*sampling_rate, speech_data, speech_len);
+    }
+
+    // speech_data / speech_len are re-read here because WavResample runs first.
+    all_samples.insert(all_samples.end(), speech_data, speech_data + speech_len);
+
+    AudioFrame* frame = new AudioFrame(speech_len);
+    frame_queue.push(frame);
+    return true;
+}
+
 bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
 bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
 {
 {
     if (speech_data != NULL) {
     if (speech_data != NULL) {
@@ -879,24 +942,25 @@ bool Audio::LoadOthers2Char(const char* filename)
     return true;
     return true;
 }
 }
 
 
-int Audio::FetchChunck(float *&dout, int len)
+int Audio::FetchTpass(AudioFrame *&frame)
 {
 {
-    if (offset >= speech_align_len) {
-        dout = NULL;
-        return S_ERR;
-    } else if (offset == speech_align_len - len) {
-        dout = speech_data + offset;
-        offset = speech_align_len;
-        // 临时解决 
-        AudioFrame *frame = frame_queue.front();
-        frame_queue.pop();
-        delete frame;
+    if (asr_offline_queue.size() > 0) {
+        frame = asr_offline_queue.front();
+        asr_offline_queue.pop();
+        return 1;
+    } else {
+        return 0;
+    }
+}
 
 
-        return S_END;
+int Audio::FetchChunck(AudioFrame *&frame)
+{
+    if (asr_online_queue.size() > 0) {
+        frame = asr_online_queue.front();
+        asr_online_queue.pop();
+        return 1;
     } else {
     } else {
-        dout = speech_data + offset;
-        offset += len;
-        return S_MIDDLE;
+        return 0;
     }
     }
 }
 }
 
 
@@ -965,7 +1029,6 @@ void Audio::Split(OfflineStream* offline_stream)
 
 
     std::vector<float> pcm_data(speech_data, speech_data+sp_len);
     std::vector<float> pcm_data(speech_data, speech_data+sp_len);
     vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
     vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
-    int seg_sample = MODEL_SAMPLE_RATE/1000;
     for(vector<int> segment:vad_segments)
     for(vector<int> segment:vad_segments)
     {
     {
         frame = new AudioFrame();
         frame = new AudioFrame();
@@ -978,7 +1041,6 @@ void Audio::Split(OfflineStream* offline_stream)
     }
     }
 }
 }
 
 
-
 void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
 void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
 {
 {
     AudioFrame *frame;
     AudioFrame *frame;
@@ -993,4 +1055,171 @@ void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, boo
     vad_segments = vad_obj->Infer(pcm_data, input_finished);
     vad_segments = vad_obj->Infer(pcm_data, input_finished);
 }
 }
 
 
+// 2pass
+/**
+ * Runs the online VAD on the newest chunk and turns its output into frames
+ * for the two recognisers.
+ *
+ * The length of the newest chunk is taken from the front of frame_queue and
+ * the chunk itself from speech_data. speech_end, speech_start,
+ * speech_offline_start and the values returned by the VAD are all expressed
+ * in units of seg_sample samples (MODEL_SAMPLE_RATE/1000, i.e. one
+ * millisecond of audio); `offset` is the number of samples already erased
+ * from the front of all_samples.
+ *
+ * Frames pushed to asr_online_queue are chunks of ONLINE_STEP samples (the
+ * last one of a segment is shorter and flagged is_final); frames pushed to
+ * asr_offline_queue hold a whole segment. Each frame owns a malloc'ed copy
+ * of its samples.
+ *
+ * @param vad_obj         VAD model queried with the newest chunk
+ * @param input_finished  forwarded to vad_obj->Infer
+ */
+void Audio::Split(VadModel* vad_obj, bool input_finished)
+{
+    // front() on an empty queue is undefined; there is nothing to split.
+    if (frame_queue.empty()) {
+        return;
+    }
+    AudioFrame *frame = frame_queue.front();
+    frame_queue.pop();
+    const int sp_len = frame->GetLen();
+    delete frame;
+    frame = NULL;
+
+    std::vector<float> pcm_data(speech_data, speech_data+sp_len);
+    vector<std::vector<int>> vad_segments = vad_obj->Infer(pcm_data, input_finished);
+
+    speech_end += sp_len/seg_sample;
+    const int step = ONLINE_STEP;
+
+    // Copies `len` samples starting at absolute sample index `start` into a
+    // new frame and pushes it onto `dst`. When the copy cannot be made
+    // (range outside the cached samples, or allocation failure) an empty
+    // frame is pushed instead, so the is_final flag still reaches the consumer.
+    auto push_frame = [this](std::queue<AudioFrame *>& dst, int start, int len, bool is_final) {
+        const long long src = (long long)start - offset;
+        AudioFrame *out = NULL;
+        if (len > 0 && src >= 0 && src + len <= (long long)all_samples.size()) {
+            float *samples = (float*)malloc(sizeof(float) * len);
+            if (samples != NULL) {
+                memcpy(samples, all_samples.data()+src, len*sizeof(float));
+                out = new AudioFrame(len);
+                out->data = samples;
+            } else {
+                LOG(ERROR) << "Split: failed to allocate " << len << " samples";
+            }
+        } else if (len != 0) {
+            LOG(ERROR) << "Split: samples [" << start << ", " << (long long)start + len
+                       << ") are outside the cached audio";
+        }
+        if (out == NULL) {
+            out = new AudioFrame(0);
+        }
+        out->is_final = is_final;
+        dst.push(out);
+    };
+
+    // Emits one ONLINE_STEP-sized chunk for the online model once that many
+    // samples are buffered after speech_start, then advances speech_start.
+    auto emit_online_chunk = [&]() {
+        const int start = speech_start*seg_sample;
+        const int end = speech_end*seg_sample;
+        if (end - start >= step) {
+            push_frame(asr_online_queue, start, step, false);
+            speech_start += step/seg_sample;
+        }
+    };
+
+    if(vad_segments.empty()){
+        // No VAD event in this chunk: keep streaming if a segment is open.
+        if(speech_start != -1){
+            emit_online_chunk();
+        }
+    }else{
+        for(const auto& vad_segment: vad_segments){
+            if(vad_segment.size() < 2){
+                continue;
+            }
+            // -1 marks a boundary the VAD has not reported.
+            const int speech_start_i = vad_segment[0];
+            const int speech_end_i = vad_segment[1];
+
+            // [1, 100]: a complete segment, sent whole to both recognisers.
+            if(speech_start_i != -1 && speech_end_i != -1){
+                const int start = speech_start_i*seg_sample;
+                const int end = speech_end_i*seg_sample;
+
+                push_frame(asr_online_queue, start, end-start, true);
+                push_frame(asr_offline_queue, start, end-start, true);
+
+                speech_start = -1;
+                speech_offline_start = -1;
+            // [70, -1]: a segment opens; start streaming online chunks.
+            }else if(speech_start_i != -1){
+                speech_start = speech_start_i;
+                speech_offline_start = speech_start_i;
+                emit_online_chunk();
+            // [-1, 100]: the open segment closes.
+            }else if(speech_end_i != -1){
+                if(speech_start == -1 || speech_offline_start == -1){
+                    LOG(ERROR) <<"Vad start is null while vad end is available." ;
+                    exit(-1);
+                }
+
+                const int start = speech_start*seg_sample;
+                const int offline_start = speech_offline_start*seg_sample;
+                const int end = speech_end_i*seg_sample;
+                const int buff_len = end-start;
+
+                // The offline model receives the whole segment at once.
+                push_frame(asr_offline_queue, offline_start, end-offline_start, true);
+
+                if(buff_len > 0){
+                    // The online model receives what it has not seen yet, in
+                    // ONLINE_STEP chunks; the last chunk takes the remainder.
+                    int sample_offset = 0;
+                    while(sample_offset < buff_len){
+                        int chunk = step;
+                        bool is_final = false;
+                        if (sample_offset + step >= buff_len - 1) {
+                            chunk = buff_len - sample_offset;
+                            is_final = true;
+                        }
+                        push_frame(asr_online_queue, start+sample_offset, chunk, is_final);
+                        sample_offset += chunk;
+                    }
+                }else{
+                    // Nothing left to stream: send an empty final frame.
+                    push_frame(asr_online_queue, start, 0, true);
+                }
+                speech_start = -1;
+                speech_offline_start = -1;
+            }
+        }
+    }
+
+    // erase all_samples: keep at most vector_cache samples, counted from the
+    // end when no segment is open and from before its start when one is.
+    const int vector_cache = MODEL_SAMPLE_RATE*2;
+    int erase_num = 0;
+    if(speech_offline_start == -1){
+        erase_num = (int)all_samples.size() - vector_cache;
+    }else{
+        erase_num = speech_offline_start*seg_sample - offset - vector_cache;
+    }
+    erase_num = std::min(erase_num, (int)all_samples.size());
+    if(erase_num > 0){
+        all_samples.erase(all_samples.begin(), all_samples.begin()+erase_num);
+        offset += erase_num;
+    }
+}
+
 } // namespace funasr
 } // namespace funasr

+ 2 - 1
funasr/runtime/onnxruntime/src/commonfunc.h

@@ -5,7 +5,8 @@ namespace funasr {
 typedef struct
 typedef struct
 {
 {
     std::string msg;
     std::string msg;
-    float  snippet_time;
+    std::string tpass_msg;
+    float snippet_time;
 }FUNASR_RECOG_RESULT;
 }FUNASR_RECOG_RESULT;
 
 
 typedef struct
 typedef struct

+ 3 - 0
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp

@@ -175,6 +175,9 @@ void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
     vad_silence_duration_ = vad_silence_duration;
     vad_silence_duration_ = vad_silence_duration;
     vad_max_len_ = vad_max_len;
     vad_max_len_ = vad_max_len;
     vad_speech_noise_thres_ = vad_speech_noise_thres;
     vad_speech_noise_thres_ = vad_speech_noise_thres;
+
+    // 2pass
+    audio_handle = make_unique<Audio>(1);
 }
 }
 
 
 FsmnVadOnline::~FsmnVadOnline() {
 FsmnVadOnline::~FsmnVadOnline() {

+ 2 - 0
funasr/runtime/onnxruntime/src/fsmn-vad-online.h

@@ -21,6 +21,8 @@ public:
     std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
     std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
     void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
     void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
     void Reset();
     void Reset();
+    // 2pass
+    std::unique_ptr<Audio> audio_handle = nullptr;
 
 
 private:
 private:
     E2EVadModel vad_scorer = E2EVadModel();
     E2EVadModel vad_scorer = E2EVadModel();

+ 97 - 2
funasr/runtime/onnxruntime/src/funasrruntime.cpp

@@ -5,9 +5,9 @@ extern "C" {
 #endif
 #endif
 
 
 	// APIs for Init
 	// APIs for Init
-	_FUNASRAPI FUNASR_HANDLE  FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, int mode)
+	_FUNASRAPI FUNASR_HANDLE  FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
 	{
 	{
-		funasr::Model* mm = funasr::CreateModel(model_path, thread_num, mode);
+		funasr::Model* mm = funasr::CreateModel(model_path, thread_num, type);
 		return mm;
 		return mm;
 	}
 	}
 
 
@@ -41,6 +41,17 @@ extern "C" {
 		return mm;
 		return mm;
 	}
 	}
 
 
+	// Builds a 2pass stream from the model paths in `model_path` by calling
+	// funasr::CreateTpassStream, and returns it as an opaque handle.
+	// FunTpassUninit deletes a handle of this type.
+	_FUNASRAPI FUNASR_HANDLE  FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num)
+	{
+		return funasr::CreateTpassStream(model_path, thread_num);
+	}
+
+	// Creates the online objects of an existing 2pass stream by forwarding the
+	// handle to funasr::CreateTpassOnlineStream.
+	// A null handle is ignored, mirroring FunTpassUninit.
+	_FUNASRAPI void FunTpassOnlineInit(FUNASR_HANDLE tpass_handle)
+	{
+		if (!tpass_handle)
+			return;
+
+		funasr::CreateTpassOnlineStream(tpass_handle);
+	}
+
 	// APIs for ASR Infer
 	// APIs for ASR Infer
 	_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
 	_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
 	{
 	{
@@ -297,6 +308,71 @@ extern "C" {
 		return p_result;
 		return p_result;
 	}
 	}
 
 
+	// APIs for 2pass-stream Infer
+	//
+	// Feeds one buffer of PCM audio into the 2pass stream behind `handle`.
+	//   - sz_buf / n_len : the audio bytes, passed to Audio::LoadPcmwavOnline.
+	//   - input_finished : forwarded to Audio::Split; when true the audio index
+	//                      is reset after decoding.
+	//   - sampling_rate  : passed by address to Audio::LoadPcmwavOnline.
+	//   - wav_format     : only "pcm" / "PCM" is accepted.
+	//   - mode, fn_callback : accepted but not used by this function.
+	//
+	// Returns nullptr when any handle is missing, the format is not supported or
+	// the audio cannot be loaded. Otherwise returns a FUNASR_RECOG_RESULT
+	// allocated with `new`, where `msg` collects the output of asr_online_handle
+	// and `tpass_msg` collects the output of asr_handle. If the loaded audio has
+	// zero length, the result is returned with only `snippet_time` filled in.
+	_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
+	{
+		funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+		if (!tpass_stream)
+			return nullptr;
+
+		funasr::VadModel* vad_online_handle = (tpass_stream->vad_online_handle).get();
+		if (!vad_online_handle)
+			return nullptr;
+
+		// The audio cache lives on the online VAD object; it can still be null
+		// because the member defaults to nullptr.
+		funasr::Audio* audio = ((funasr::FsmnVadOnline*)vad_online_handle)->audio_handle.get();
+		if (!audio)
+			return nullptr;
+
+		funasr::Model* asr_online_handle = (tpass_stream->asr_online_handle).get();
+		if (!asr_online_handle)
+			return nullptr;
+
+		funasr::Model* asr_handle = (tpass_stream->asr_handle).get();
+		if (!asr_handle)
+			return nullptr;
+
+		// Only raw PCM is supported. Report the problem and fail this call
+		// instead of terminating the whole process.
+		if (wav_format != "pcm" && wav_format != "PCM") {
+			LOG(ERROR) <<"Wrong wav_format: " << wav_format ;
+			return nullptr;
+		}
+		if (!audio->LoadPcmwavOnline(sz_buf, n_len, &sampling_rate))
+			return nullptr;
+
+		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+		p_result->snippet_time = audio->GetTimeLen();
+		if(p_result->snippet_time == 0){
+			// Nothing to decode for an empty snippet.
+			return p_result;
+		}
+
+		audio->Split(vad_online_handle, input_finished);
+
+		funasr::AudioFrame* frame = NULL;
+
+		// First pass: decode every queued chunk with asr_online_handle.
+		while(audio->FetchChunck(frame) > 0){
+			p_result->msg += asr_online_handle->Forward(frame->data, frame->len, frame->is_final);
+			delete frame;
+			frame = NULL;
+		}
+
+		// Second pass: decode every queued segment with asr_handle.
+		while(audio->FetchTpass(frame) > 0){
+			p_result->tpass_msg += asr_handle->Forward(frame->data, frame->len, frame->is_final);
+			delete frame;
+			frame = NULL;
+		}
+
+		if(input_finished){
+			audio->ResetIndex();
+		}
+
+		return p_result;
+	}
+
 	_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
 	_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
 	{
 	{
 		if (!result)
 		if (!result)
@@ -332,6 +408,15 @@ extern "C" {
 		return p_result->msg.c_str();
 		return p_result->msg.c_str();
 	}
 	}
 
 
+	// Returns the C string held in the result's `tpass_msg` field, or nullptr
+	// when `result` is null. `n_index` is accepted but not used here.
+	_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
+	{
+		const auto* recog = static_cast<funasr::FUNASR_RECOG_RESULT*>(result);
+		return recog ? recog->tpass_msg.c_str() : nullptr;
+	}
+
 	_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
 	_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
 	{
 	{
 		funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
 		funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
@@ -420,6 +505,16 @@ extern "C" {
 		delete offline_stream;
 		delete offline_stream;
 	}
 	}
 
 
+	// Destroys the TpassStream behind `handle`. A null handle is a no-op,
+	// because deleting a null pointer does nothing.
+	_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle)
+	{
+		delete static_cast<funasr::TpassStream*>(handle);
+	}
+
 #ifdef __cplusplus 
 #ifdef __cplusplus 
 
 
 }
 }

+ 7 - 3
funasr/runtime/onnxruntime/src/model.cpp

@@ -1,9 +1,10 @@
 #include "precomp.h"
 #include "precomp.h"
 
 
 namespace funasr {
 namespace funasr {
-Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num, int mode)
+Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
 {
 {
-    if(mode == 0){
+    // offline
+    if(type == ASR_OFFLINE){
         string am_model_path;
         string am_model_path;
         string am_cmvn_path;
         string am_cmvn_path;
         string am_config_path;
         string am_config_path;
@@ -19,7 +20,7 @@ Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_nu
         mm = new Paraformer();
         mm = new Paraformer();
         mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
         mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
         return mm;
         return mm;
-    }else if(mode == 1){
+    }else if(type == ASR_ONLINE){
         // online
         // online
         string en_model_path;
         string en_model_path;
         string de_model_path;
         string de_model_path;
@@ -39,6 +40,9 @@ Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_nu
         mm = new Paraformer();
         mm = new Paraformer();
         mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
         mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
         return mm;
         return mm;
+    }else{
+        LOG(ERROR)<<"Wrong ASR_TYPE : " << type;
+        exit(-1);
     }
     }
 }
 }
 
 

+ 37 - 6
funasr/runtime/onnxruntime/src/paraformer.cpp

@@ -33,7 +33,7 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
     session_options_.DisableCpuMemArena();
     session_options_.DisableCpuMemArena();
 
 
     try {
     try {
-        m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
+        m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
         LOG(INFO) << "Successfully load model from " << am_model;
         LOG(INFO) << "Successfully load model from " << am_model;
     } catch (std::exception const &e) {
     } catch (std::exception const &e) {
         LOG(ERROR) << "Error when load am onnx model: " << e.what();
         LOG(ERROR) << "Error when load am onnx model: " << e.what();
@@ -41,14 +41,14 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
     }
     }
 
 
     string strName;
     string strName;
-    GetInputName(m_session.get(), strName);
+    GetInputName(m_session_.get(), strName);
     m_strInputNames.push_back(strName.c_str());
     m_strInputNames.push_back(strName.c_str());
-    GetInputName(m_session.get(), strName,1);
+    GetInputName(m_session_.get(), strName,1);
     m_strInputNames.push_back(strName);
     m_strInputNames.push_back(strName);
     
     
-    GetOutputName(m_session.get(), strName);
+    GetOutputName(m_session_.get(), strName);
     m_strOutputNames.push_back(strName);
     m_strOutputNames.push_back(strName);
-    GetOutputName(m_session.get(), strName,1);
+    GetOutputName(m_session_.get(), strName,1);
     m_strOutputNames.push_back(strName);
     m_strOutputNames.push_back(strName);
 
 
     for (auto& item : m_strInputNames)
     for (auto& item : m_strInputNames)
@@ -136,6 +136,37 @@ void Paraformer::InitAsr(const std::string &en_model, const std::string &de_mode
     LoadCmvn(am_cmvn.c_str());
     LoadCmvn(am_cmvn.c_str());
 }
 }
 
 
+// 2pass
+//
+// Initialises both halves of the recogniser on one Paraformer object:
+//   1. the online part, through the five-argument InitAsr overload
+//      (encoder, decoder, cmvn, config);
+//   2. the offline part, by loading `am_model` into m_session_ and recording
+//      its first two input names and first two output names.
+// The process exits with a non-zero status if the offline model cannot be loaded.
+void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+    // online
+    InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);
+
+    // offline
+    try {
+        m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
+        LOG(INFO) << "Successfully load model from " << am_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am onnx model: " << e.what();
+        // A failed load is an error: do not report success to the parent process.
+        exit(-1);
+    }
+
+    string strName;
+    GetInputName(m_session_.get(), strName);
+    m_strInputNames.push_back(strName);
+    GetInputName(m_session_.get(), strName,1);
+    m_strInputNames.push_back(strName);
+    
+    GetOutputName(m_session_.get(), strName);
+    m_strOutputNames.push_back(strName);
+    GetOutputName(m_session_.get(), strName,1);
+    m_strOutputNames.push_back(strName);
+
+    // The C-string views are collected only after all names have been pushed.
+    for (auto& item : m_strInputNames)
+        m_szInputNames.push_back(item.c_str());
+    for (auto& item : m_strOutputNames)
+        m_szOutputNames.push_back(item.c_str());
+}
+
 void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
 void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
 
 
     YAML::Node config;
     YAML::Node config;
@@ -332,7 +363,7 @@ string Paraformer::Forward(float* din, int len, bool input_finished)
 
 
     string result;
     string result;
     try {
     try {
-        auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
+        auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
 
 
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());

+ 5 - 2
funasr/runtime/onnxruntime/src/paraformer.h

@@ -30,6 +30,8 @@ namespace funasr {
         void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
         void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
         // online
         // online
         void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
         void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+        // 2pass
+        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
         void Reset();
         void Reset();
         vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
         vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
         string Forward(float* din, int len, bool input_finished=true);
         string Forward(float* din, int len, bool input_finished=true);
@@ -42,7 +44,8 @@ namespace funasr {
         int lfr_m = PARA_LFR_M;
         int lfr_m = PARA_LFR_M;
         int lfr_n = PARA_LFR_N;
         int lfr_n = PARA_LFR_N;
 
 
-        std::shared_ptr<Ort::Session> m_session = nullptr;
+        // paraformer-offline
+        std::shared_ptr<Ort::Session> m_session_ = nullptr;
         Ort::Env env_;
         Ort::Env env_;
         Ort::SessionOptions session_options_;
         Ort::SessionOptions session_options_;
 
 
@@ -50,7 +53,7 @@ namespace funasr {
         vector<const char*> m_szInputNames;
         vector<const char*> m_szInputNames;
         vector<const char*> m_szOutputNames;
         vector<const char*> m_szOutputNames;
 
 
-        //paraformer-online
+        // paraformer-online
         std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
         std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
         std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
         std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
         vector<string> en_strInputNames, en_strOutputNames;
         vector<string> en_strInputNames, en_strOutputNames;

+ 2 - 2
funasr/runtime/onnxruntime/src/precomp.h

@@ -33,19 +33,19 @@ using namespace std;
 #include "model.h"
 #include "model.h"
 #include "vad-model.h"
 #include "vad-model.h"
 #include "punc-model.h"
 #include "punc-model.h"
-#include "offline-stream.h"
 #include "tokenizer.h"
 #include "tokenizer.h"
 #include "ct-transformer.h"
 #include "ct-transformer.h"
 #include "ct-transformer-online.h"
 #include "ct-transformer-online.h"
 #include "e2e-vad.h"
 #include "e2e-vad.h"
 #include "fsmn-vad.h"
 #include "fsmn-vad.h"
-#include "fsmn-vad-online.h"
 #include "vocab.h"
 #include "vocab.h"
 #include "audio.h"
 #include "audio.h"
+#include "fsmn-vad-online.h"
 #include "tensor.h"
 #include "tensor.h"
 #include "util.h"
 #include "util.h"
 #include "resample.h"
 #include "resample.h"
 #include "paraformer.h"
 #include "paraformer.h"
 #include "paraformer-online.h"
 #include "paraformer-online.h"
 #include "offline-stream.h"
 #include "offline-stream.h"
+#include "tpass-stream.h"
 #include "funasrruntime.h"
 #include "funasrruntime.h"

+ 103 - 0
funasr/runtime/onnxruntime/src/tpass-stream.cpp

@@ -0,0 +1,103 @@
+#include "precomp.h"
+#include <unistd.h>
+
+namespace funasr {
+// Builds the shared models of a 2pass stream from `model_path`.
+//   - VAD  (optional): loaded when VAD_DIR is given and its model, cmvn and
+//     config files are all accessible; sets use_vad.
+//   - AM   (required): needs both OFFLINE_MODEL_DIR and ONLINE_MODEL_DIR,
+//     otherwise the process exits. The cmvn and config files are taken from
+//     the online model directory.
+//   - PUNC (optional): loaded when PUNC_DIR is given and its model and config
+//     files are accessible; sets use_punc.
+// A "*_QUANT" / QUANTIZE entry equal to "true" selects the quantized model files.
+TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thread_num)
+{
+    // Local helpers shared by the three sections below.
+    const auto has_key = [&model_path](const std::string& key) {
+        return model_path.find(key) != model_path.end();
+    };
+    const auto is_true = [&model_path, &has_key](const std::string& key) {
+        return has_key(key) && model_path.at(key) == "true";
+    };
+    const auto missing = [](const std::string& path) {
+        return access(path.c_str(), F_OK) != 0;
+    };
+
+    // VAD model
+    if(has_key(VAD_DIR)){
+        std::string& vad_dir = model_path.at(VAD_DIR);
+        string vad_model_path = is_true(VAD_QUANT)
+            ? PathAppend(vad_dir, QUANT_MODEL_NAME)
+            : PathAppend(vad_dir, MODEL_NAME);
+        string vad_cmvn_path = PathAppend(vad_dir, VAD_CMVN_NAME);
+        string vad_config_path = PathAppend(vad_dir, VAD_CONFIG_NAME);
+
+        if (missing(vad_model_path) || missing(vad_cmvn_path) || missing(vad_config_path))
+        {
+            LOG(INFO) << "VAD model file is not exist, skip load vad model.";
+        }else{
+            vad_handle = make_unique<FsmnVad>();
+            vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+            use_vad = true;
+        }
+    }
+
+    // AM model: both directories are mandatory for 2pass.
+    if(!has_key(OFFLINE_MODEL_DIR) || !has_key(ONLINE_MODEL_DIR)){
+        LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
+        exit(-1);
+    }
+    {
+        std::string& offline_dir = model_path.at(OFFLINE_MODEL_DIR);
+        std::string& online_dir = model_path.at(ONLINE_MODEL_DIR);
+        const bool quantized = is_true(QUANTIZE);
+
+        string am_model_path = quantized
+            ? PathAppend(offline_dir, QUANT_MODEL_NAME)
+            : PathAppend(offline_dir, MODEL_NAME);
+        string en_model_path = quantized
+            ? PathAppend(online_dir, QUANT_ENCODER_NAME)
+            : PathAppend(online_dir, ENCODER_NAME);
+        string de_model_path = quantized
+            ? PathAppend(online_dir, QUANT_DECODER_NAME)
+            : PathAppend(online_dir, DECODER_NAME);
+        string am_cmvn_path = PathAppend(online_dir, AM_CMVN_NAME);
+        string am_config_path = PathAppend(online_dir, AM_CONFIG_NAME);
+
+        asr_handle = make_unique<Paraformer>();
+        asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
+    }
+
+    // PUNC model
+    if(has_key(PUNC_DIR)){
+        std::string& punc_dir = model_path.at(PUNC_DIR);
+        string punc_model_path = is_true(PUNC_QUANT)
+            ? PathAppend(punc_dir, QUANT_MODEL_NAME)
+            : PathAppend(punc_dir, MODEL_NAME);
+        string punc_config_path = PathAppend(punc_dir, PUNC_CONFIG_NAME);
+
+        if (missing(punc_model_path) || missing(punc_config_path))
+        {
+            LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
+        }else{
+            punc_online_handle = make_unique<CTTransformerOnline>();
+            punc_online_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+            use_punc = true;
+        }
+    }
+}
+
+// Allocates a TpassStream with `new` from the given model paths and thread
+// count, and returns the raw pointer to the caller.
+TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num)
+{
+    return new TpassStream(model_path, thread_num);
+}
+
+// Creates the online objects of an existing TpassStream:
+//   - vad_online_handle, wrapping vad_handle, when a VAD model was loaded;
+//   - asr_online_handle, wrapping asr_handle (mandatory: the process exits
+//     if asr_handle is null).
+// `tpass_stream` must point to a TpassStream; a null pointer is reported and ignored.
+void CreateTpassOnlineStream(void* tpass_stream)
+{
+    funasr::TpassStream* tpass_obj = (funasr::TpassStream*)tpass_stream;
+    if(!tpass_obj){
+        // Guard against dereferencing a null handle below.
+        LOG(ERROR)<<"tpass_stream is null";
+        return;
+    }
+
+    if(tpass_obj->vad_handle){
+        tpass_obj->vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(tpass_obj->vad_handle).get());
+    }
+
+    if(tpass_obj->asr_handle){
+        tpass_obj->asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_obj->asr_handle).get());
+    }else{
+        LOG(ERROR)<<"asr_handle is null";
+        exit(-1);
+    }
+}
+
+} // namespace funasr