|
|
@@ -6,40 +6,58 @@
|
|
|
#include <fstream>
|
|
|
#include "precomp.h"
|
|
|
|
|
|
-void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
|
|
|
- float vad_speech_noise_thres) {
|
|
|
+void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config) {
|
|
|
session_options_.SetIntraOpNumThreads(1);
|
|
|
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
|
|
|
session_options_.DisableCpuMemArena();
|
|
|
- this->vad_sample_rate_ = vad_sample_rate;
|
|
|
- this->vad_silence_duration_=vad_silence_duration;
|
|
|
- this->vad_max_len_=vad_max_len;
|
|
|
- this->vad_speech_noise_thres_=vad_speech_noise_thres;
|
|
|
|
|
|
- ReadModel(vad_model);
|
|
|
+ ReadModel(vad_model.c_str());
|
|
|
LoadCmvn(vad_cmvn.c_str());
|
|
|
+ LoadConfigFromYaml(vad_config.c_str());
|
|
|
InitCache();
|
|
|
+}
|
|
|
+
|
|
|
+void FsmnVad::LoadConfigFromYaml(const char* filename){
|
|
|
+
|
|
|
+ YAML::Node config;
|
|
|
+ try{
|
|
|
+ config = YAML::LoadFile(filename);
|
|
|
+ }catch(exception const &e){
|
|
|
+ LOG(ERROR) << "Error loading file, yaml file error or not exist.";
|
|
|
+ exit(-1);
|
|
|
+ }
|
|
|
|
|
|
- fbank_opts.frame_opts.dither = 0;
|
|
|
- fbank_opts.mel_opts.num_bins = 80;
|
|
|
- fbank_opts.frame_opts.samp_freq = vad_sample_rate;
|
|
|
- fbank_opts.frame_opts.window_type = "hamming";
|
|
|
- fbank_opts.frame_opts.frame_shift_ms = 10;
|
|
|
- fbank_opts.frame_opts.frame_length_ms = 25;
|
|
|
- fbank_opts.energy_floor = 0;
|
|
|
- fbank_opts.mel_opts.debug_mel = false;
|
|
|
+ try{
|
|
|
+ YAML::Node frontend_conf = config["frontend_conf"];
|
|
|
+ YAML::Node post_conf = config["vad_post_conf"];
|
|
|
|
|
|
+ this->vad_sample_rate_ = frontend_conf["fs"].as<int>();
|
|
|
+ this->vad_silence_duration_ = post_conf["max_end_silence_time"].as<int>();
|
|
|
+ this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
|
|
|
+ this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
|
|
|
+
|
|
|
+ fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
|
|
|
+ fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
|
|
|
+ fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
|
|
|
+ fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
|
|
|
+ fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
|
|
|
+ fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
|
|
|
+ fbank_opts.energy_floor = 0;
|
|
|
+ fbank_opts.mel_opts.debug_mel = false;
|
|
|
+ }catch(exception const &e){
|
|
|
+ LOG(ERROR) << "Error when load argument from vad config YAML.";
|
|
|
+ exit(-1);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
-void FsmnVad::ReadModel(const std::string &vad_model) {
|
|
|
+void FsmnVad::ReadModel(const char* vad_model) {
|
|
|
try {
|
|
|
vad_session_ = std::make_shared<Ort::Session>(
|
|
|
- env_, vad_model.c_str(), session_options_);
|
|
|
+ env_, vad_model, session_options_);
|
|
|
} catch (std::exception const &e) {
|
|
|
LOG(ERROR) << "Error when load vad onnx model: " << e.what();
|
|
|
exit(0);
|
|
|
}
|
|
|
- LOG(INFO) << "vad onnx:";
|
|
|
GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_);
|
|
|
}
|
|
|
|
|
|
@@ -61,8 +79,8 @@ void FsmnVad::GetInputOutputInfo(
|
|
|
shape << j;
|
|
|
shape << " ";
|
|
|
}
|
|
|
- LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
|
|
|
- << " dims=" << shape.str();
|
|
|
+ // LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
|
|
|
+ // << " dims=" << shape.str();
|
|
|
(*in_names)[i] = name.get();
|
|
|
name.release();
|
|
|
}
|
|
|
@@ -80,8 +98,8 @@ void FsmnVad::GetInputOutputInfo(
|
|
|
shape << j;
|
|
|
shape << " ";
|
|
|
}
|
|
|
- LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
|
|
|
- << " dims=" << shape.str();
|
|
|
+ // LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
|
|
|
+ // << " dims=" << shape.str();
|
|
|
(*out_names)[i] = name.get();
|
|
|
name.release();
|
|
|
}
|
|
|
@@ -121,13 +139,12 @@ void FsmnVad::Forward(
|
|
|
// 4. Onnx infer
|
|
|
std::vector<Ort::Value> vad_ort_outputs;
|
|
|
try {
|
|
|
- VLOG(3) << "Start infer";
|
|
|
vad_ort_outputs = vad_session_->Run(
|
|
|
Ort::RunOptions{nullptr}, vad_in_names_.data(), vad_inputs.data(),
|
|
|
vad_inputs.size(), vad_out_names_.data(), vad_out_names_.size());
|
|
|
} catch (std::exception const &e) {
|
|
|
- LOG(ERROR) << e.what();
|
|
|
- return;
|
|
|
+ LOG(ERROR) << "Error when run vad onnx forword: " << (e.what());
|
|
|
+ exit(0);
|
|
|
}
|
|
|
|
|
|
// 5. Change infer result to output shapes
|
|
|
@@ -168,6 +185,10 @@ void FsmnVad::LoadCmvn(const char *filename)
|
|
|
try{
|
|
|
using namespace std;
|
|
|
ifstream cmvn_stream(filename);
|
|
|
+ if (!cmvn_stream.is_open()) {
|
|
|
+ LOG(ERROR) << "Failed to open file: " << filename;
|
|
|
+ exit(0);
|
|
|
+ }
|
|
|
string line;
|
|
|
|
|
|
while (getline(cmvn_stream, line)) {
|
|
|
@@ -203,7 +224,7 @@ void FsmnVad::LoadCmvn(const char *filename)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n) {
|
|
|
+std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
|
|
|
|
|
|
std::vector<std::vector<float>> out_feats;
|
|
|
int T = vad_feats.size();
|
|
|
@@ -250,7 +271,7 @@ FsmnVad::Infer(const std::vector<float> &waves) {
|
|
|
std::vector<std::vector<float>> vad_feats;
|
|
|
std::vector<std::vector<float>> vad_probs;
|
|
|
FbankKaldi(vad_sample_rate_, vad_feats, waves);
|
|
|
- vad_feats = LfrCmvn(vad_feats, 5, 1);
|
|
|
+ vad_feats = LfrCmvn(vad_feats);
|
|
|
Forward(vad_feats, &vad_probs);
|
|
|
|
|
|
E2EVadModel vad_scorer = E2EVadModel();
|
|
|
@@ -258,7 +279,6 @@ FsmnVad::Infer(const std::vector<float> &waves) {
|
|
|
vad_segments = vad_scorer(vad_probs, waves, true, false, vad_silence_duration_, vad_max_len_,
|
|
|
vad_speech_noise_thres_, vad_sample_rate_);
|
|
|
return vad_segments;
|
|
|
-
|
|
|
}
|
|
|
|
|
|
void FsmnVad::InitCache(){
|