Przeglądaj źródła

8k (#1174) (#1175)

* Adaptive 8K

* fix FfmpegLoad 8k

Co-authored-by: cdevelop <cdevelop@qq.com>
Yabin Li 2 lat temu
rodzic
commit
a0048dc766

+ 3 - 2
runtime/onnxruntime/include/audio.h

@@ -52,10 +52,11 @@ class Audio {
     queue<AudioFrame *> frame_queue;
     queue<AudioFrame *> asr_online_queue;
     queue<AudioFrame *> asr_offline_queue;
-
+    int dest_sample_rate;
   public:
     Audio(int data_type);
-    Audio(int data_type, int size);
+    Audio(int model_sample_rate,int data_type);
+    Audio(int model_sample_rate,int data_type, int size);
     ~Audio();
     void Disp();
     void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);

+ 2 - 0
runtime/onnxruntime/include/model.h

@@ -23,6 +23,8 @@ class Model {
     virtual void InitSegDict(const std::string &seg_dict_model){};
     virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
     virtual std::string GetLang(){return "";};
+    virtual int GetAsrSampleRate() = 0;
+
 };
 
 Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);

+ 1 - 0
runtime/onnxruntime/include/vad-model.h

@@ -12,6 +12,7 @@ class VadModel {
     virtual ~VadModel(){};
     virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
     virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
+    virtual int GetVadSampleRate() = 0;
 };
 
 VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);

+ 27 - 17
runtime/onnxruntime/src/audio.cpp

@@ -193,18 +193,28 @@ int AudioFrame::Disp()
     return 0;
 }
 
-Audio::Audio(int data_type) : data_type(data_type)
+Audio::Audio(int data_type) : dest_sample_rate(MODEL_SAMPLE_RATE), data_type(data_type)
 {
     speech_buff = NULL;
     speech_data = NULL;
     align_size = 1360;
+    seg_sample = dest_sample_rate / 1000;
 }
 
-Audio::Audio(int data_type, int size) : data_type(data_type)
+Audio::Audio(int model_sample_rate, int data_type) : dest_sample_rate(model_sample_rate), data_type(data_type)
+{
+    speech_buff = NULL;
+    speech_data = NULL;
+    align_size = 1360;
+    seg_sample = dest_sample_rate / 1000;
+}
+
+Audio::Audio(int model_sample_rate, int data_type, int size) : dest_sample_rate(model_sample_rate), data_type(data_type)
 {
     speech_buff = NULL;
     speech_data = NULL;
     align_size = (float)size;
+    seg_sample = dest_sample_rate / 1000;
 }
 
 Audio::~Audio()
@@ -222,12 +232,12 @@ Audio::~Audio()
 
 void Audio::Disp()
 {
-    LOG(INFO) << "Audio time is " << (float)speech_len / MODEL_SAMPLE_RATE << " s. len is " << speech_len;
+    LOG(INFO) << "Audio time is " << (float)speech_len / dest_sample_rate << " s. len is " << speech_len;
 }
 
 float Audio::GetTimeLen()
 {
-    return (float)speech_len / MODEL_SAMPLE_RATE;
+    return (float)speech_len / dest_sample_rate;
 }
 
 void Audio::WavResample(int32_t sampling_rate, const float *waveform,
@@ -237,13 +247,13 @@ void Audio::WavResample(int32_t sampling_rate, const float *waveform,
               << "   in_sample_rate: "<< sampling_rate << "\n"
               << "   output_sample_rate: " << static_cast<int32_t>(MODEL_SAMPLE_RATE);
     float min_freq =
-        std::min<int32_t>(sampling_rate, MODEL_SAMPLE_RATE);
+        std::min<int32_t>(sampling_rate, dest_sample_rate);
     float lowpass_cutoff = 0.99 * 0.5 * min_freq;
 
     int32_t lowpass_filter_width = 6;
 
     auto resampler = std::make_unique<LinearResample>(
-          sampling_rate, MODEL_SAMPLE_RATE, lowpass_cutoff, lowpass_filter_width);
+          sampling_rate, dest_sample_rate, lowpass_cutoff, lowpass_filter_width);
     std::vector<float> samples;
     resampler->Resample(waveform, n, true, &samples);
     //reset speech_data
@@ -311,7 +321,7 @@ bool Audio::FfmpegLoad(const char *filename, bool copy2char){
         nullptr, // allocate a new context
         AV_CH_LAYOUT_MONO, // output channel layout (stereo)
         AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
-        16000, // output sample rate (same as input)
+        dest_sample_rate, // output sample rate (same as input)
         av_get_default_channel_layout(codecContext->channels), // input channel layout
         codecContext->sample_fmt, // input sample format
         codecContext->sample_rate, // input sample rate
@@ -347,7 +357,7 @@ bool Audio::FfmpegLoad(const char *filename, bool copy2char){
                     int in_samples = frame->nb_samples;
                     uint8_t **in_data = frame->extended_data;
                     int out_samples = av_rescale_rnd(in_samples,
-                                                    16000,
+                                                    dest_sample_rate,
                                                     codecContext->sample_rate,
                                                     AV_ROUND_DOWN);
                     
@@ -494,7 +504,7 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
         nullptr, // allocate a new context
         AV_CH_LAYOUT_MONO, // output channel layout (stereo)
         AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
-        16000, // output sample rate (same as input)
+        dest_sample_rate, // output sample rate (same as input)
         av_get_default_channel_layout(codecContext->channels), // input channel layout
         codecContext->sample_fmt, // input sample format
         codecContext->sample_rate, // input sample rate
@@ -532,7 +542,7 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
                     int in_samples = frame->nb_samples;
                     uint8_t **in_data = frame->extended_data;
                     int out_samples = av_rescale_rnd(in_samples,
-                                                    16000,
+                                                    dest_sample_rate,
                                                     codecContext->sample_rate,
                                                     AV_ROUND_DOWN);
                     
@@ -666,7 +676,7 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
         }
 
         //resample
-        if(*sampling_rate != MODEL_SAMPLE_RATE){
+        if(*sampling_rate != dest_sample_rate){
             WavResample(*sampling_rate, speech_data, speech_len);
         }
 
@@ -752,7 +762,7 @@ bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
         }
         
         //resample
-        if(*sampling_rate != MODEL_SAMPLE_RATE){
+        if(*sampling_rate != dest_sample_rate){
             WavResample(*sampling_rate, speech_data, speech_len);
         }
 
@@ -795,7 +805,7 @@ bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
         }
         
         //resample
-        if(*sampling_rate != MODEL_SAMPLE_RATE){
+        if(*sampling_rate != dest_sample_rate){
             WavResample(*sampling_rate, speech_data, speech_len);
         }
 
@@ -840,7 +850,7 @@ bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_r
         }
         
         //resample
-        if(*sampling_rate != MODEL_SAMPLE_RATE){
+        if(*sampling_rate != dest_sample_rate){
             WavResample(*sampling_rate, speech_data, speech_len);
         }
 
@@ -898,7 +908,7 @@ bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
         }
 
         //resample
-        if(*sampling_rate != MODEL_SAMPLE_RATE){
+        if(*sampling_rate != dest_sample_rate){
             WavResample(*sampling_rate, speech_data, speech_len);
         }
 
@@ -1009,7 +1019,7 @@ int Audio::Fetch(float *&dout, int &len, int &flag, float &start_time)
         AudioFrame *frame = frame_queue.front();
         frame_queue.pop();
 
-        start_time = (float)(frame->GetStart())/MODEL_SAMPLE_RATE;
+        start_time = (float)(frame->GetStart())/ dest_sample_rate;
         dout = speech_data + frame->GetStart();
         len = frame->GetLen();
         delete frame;
@@ -1248,7 +1258,7 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP
     }
 
     // erase all_samples
-    int vector_cache = MODEL_SAMPLE_RATE*2;
+    int vector_cache = dest_sample_rate*2;
     if(speech_offline_start == -1){
         if(all_samples.size() > vector_cache){
             int erase_num = all_samples.size() - vector_cache;

+ 4 - 1
runtime/onnxruntime/src/fsmn-vad-online.cpp

@@ -187,8 +187,11 @@ void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
     vad_max_len_ = vad_max_len;
     vad_speech_noise_thres_ = vad_speech_noise_thres;
 
+    frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
+    frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
+
     // 2pass
-    audio_handle = make_unique<Audio>(1);
+    audio_handle = make_unique<Audio>(vad_sample_rate,1);
 }
 
 FsmnVadOnline::~FsmnVadOnline() {

+ 2 - 0
runtime/onnxruntime/src/fsmn-vad-online.h

@@ -21,6 +21,8 @@ public:
     std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
     void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
     void Reset();
+    int GetVadSampleRate() { return vad_sample_rate_; };
+
     // 2pass
     std::unique_ptr<Audio> audio_handle = nullptr;
 

+ 2 - 0
runtime/onnxruntime/src/fsmn-vad.h

@@ -28,6 +28,8 @@ public:
         std::vector<std::vector<float>> *in_cache,
         bool is_final);
     void Reset();
+
+    int GetVadSampleRate() { return vad_sample_rate_; };
     
     std::shared_ptr<Ort::Session> vad_session_ = nullptr;
     Ort::Env env_;

+ 6 - 6
runtime/onnxruntime/src/funasrruntime.cpp

@@ -57,7 +57,7 @@
 		if (!recog_obj)
 			return nullptr;
 
-		funasr::Audio audio(1);
+		funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
 		if(wav_format == "pcm" || wav_format == "PCM"){
 			if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
 				return nullptr;
@@ -93,7 +93,7 @@
 		if (!recog_obj)
 			return nullptr;
 
-		funasr::Audio audio(1);
+		funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
 		if(funasr::is_target_file(sz_filename, "wav")){
 			int32_t sampling_rate_ = -1;
 			if(!audio.LoadWav(sz_filename, &sampling_rate_))
@@ -134,7 +134,7 @@
 		if (!vad_obj)
 			return nullptr;
 
-		funasr::Audio audio(1);
+		funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
 		if(wav_format == "pcm" || wav_format == "PCM"){
 			if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
 				return nullptr;
@@ -162,7 +162,7 @@
 		if (!vad_obj)
 			return nullptr;
 
-		funasr::Audio audio(1);
+		funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
 		if(funasr::is_target_file(sz_filename, "wav")){
 			int32_t sampling_rate_ = -1;
 			if(!audio.LoadWav(sz_filename, &sampling_rate_))
@@ -222,7 +222,7 @@
 		if (!offline_stream)
 			return nullptr;
 
-		funasr::Audio audio(1);
+		funasr::Audio audio(offline_stream->asr_handle->GetAsrSampleRate(),1);
 		try{
 			if(wav_format == "pcm" || wav_format == "PCM"){
 				if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
@@ -314,7 +314,7 @@
 		if (!offline_stream)
 			return nullptr;
 		
-		funasr::Audio audio(1);
+		funasr::Audio audio((offline_stream->asr_handle)->GetAsrSampleRate(),1);
 		try{
 			if(funasr::is_target_file(sz_filename, "wav")){
 				int32_t sampling_rate_ = -1;

+ 6 - 2
runtime/onnxruntime/src/paraformer-online.cpp

@@ -61,7 +61,11 @@ void ParaformerOnline::InitOnline(
     for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
         fsmn_init_cache_.emplace_back(0);
     }
-    chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
+    chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
+
+    frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
+    frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
+
 }
 
 void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
@@ -489,7 +493,7 @@ string ParaformerOnline::Forward(float* din, int len, bool input_finished, const
         if(is_first_chunk){
             is_first_chunk = false;
         }
-        ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
+        ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
         if(wav_feats.size() == 0){
             return result;
         }

+ 3 - 0
runtime/onnxruntime/src/paraformer-online.h

@@ -111,6 +111,9 @@ namespace funasr {
         string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
         string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
         string Rescoring();
+
+        int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };
+
         // 2pass
         std::string online_res;
         int chunk_len;

+ 9 - 3
runtime/onnxruntime/src/paraformer.cpp

@@ -19,10 +19,11 @@ Paraformer::Paraformer()
 
 // offline
 void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+    LoadConfigFromYaml(am_config.c_str());
     // knf options
     fbank_opts_.frame_opts.dither = 0;
     fbank_opts_.mel_opts.num_bins = n_mels;
-    fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
+    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
     fbank_opts_.frame_opts.window_type = window_type;
     fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
     fbank_opts_.frame_opts.frame_length_ms = frame_length;
@@ -65,7 +66,6 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
     vocab = new Vocab(am_config.c_str());
-    LoadConfigFromYaml(am_config.c_str());
 	phone_set_ = new PhoneSet(am_config.c_str());
     LoadCmvn(am_cmvn.c_str());
 }
@@ -77,7 +77,7 @@ void Paraformer::InitAsr(const std::string &en_model, const std::string &de_mode
     // knf options
     fbank_opts_.frame_opts.dither = 0;
     fbank_opts_.mel_opts.num_bins = n_mels;
-    fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
+    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
     fbank_opts_.frame_opts.window_type = window_type;
     fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
     fbank_opts_.frame_opts.frame_length_ms = frame_length;
@@ -216,6 +216,9 @@ void Paraformer::LoadConfigFromYaml(const char* filename){
     }
 
     try{
+        YAML::Node frontend_conf = config["frontend_conf"];
+        this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
         YAML::Node lang_conf = config["lang"];
         if (lang_conf.IsDefined()){
             language = lang_conf.as<string>();
@@ -258,6 +261,9 @@ void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
         this->cif_threshold = predictor_conf["threshold"].as<double>();
         this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
 
+        this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
+
     }catch(exception const &e){
         LOG(ERROR) << "Error when load argument from vad config YAML.";
         exit(-1);

+ 2 - 3
runtime/onnxruntime/src/paraformer.h

@@ -57,7 +57,7 @@ namespace funasr {
 
         string Rescoring();
         string GetLang(){return language;};
-		
+        int GetAsrSampleRate() { return asr_sample_rate; };
         void StartUtterance();
         void EndUtterance();
         void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -107,8 +107,7 @@ namespace funasr {
         int fsmn_dims = 512;
         float cif_threshold = 1.0;
         float tail_alphas = 0.45;
-
-
+        int asr_sample_rate = MODEL_SAMPLE_RATE;
     };
 
 } // namespace funasr