|
|
@@ -46,10 +46,11 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
|
|
|
GetInputName(m_session_.get(), strName,1);
|
|
|
m_strInputNames.push_back(strName);
|
|
|
|
|
|
- GetOutputName(m_session_.get(), strName);
|
|
|
- m_strOutputNames.push_back(strName);
|
|
|
- GetOutputName(m_session_.get(), strName,1);
|
|
|
- m_strOutputNames.push_back(strName);
|
|
|
+ size_t numOutputNodes = m_session_->GetOutputCount();
|
|
|
+ for(int index=0; index<numOutputNodes; index++){
|
|
|
+ GetOutputName(m_session_.get(), strName, index);
|
|
|
+ m_strOutputNames.push_back(strName);
|
|
|
+ }
|
|
|
|
|
|
for (auto& item : m_strInputNames)
|
|
|
m_szInputNames.push_back(item.c_str());
|
|
|
@@ -274,7 +275,7 @@ void Paraformer::LoadCmvn(const char *filename)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
|
|
|
+string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
|
|
|
{
|
|
|
vector<int> hyps;
|
|
|
int Tmax = n_len;
|
|
|
@@ -284,8 +285,229 @@ string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
|
|
|
FindMax(in + i * token_nums, token_nums, max_val, max_idx);
|
|
|
hyps.push_back(max_idx);
|
|
|
}
|
|
|
+ if(!is_stamp){
|
|
|
+ return vocab->Vector2StringV2(hyps);
|
|
|
+ }else{
|
|
|
+ std::vector<string> char_list;
|
|
|
+ std::vector<std::vector<float>> timestamp_list;
|
|
|
+ std::string res_str;
|
|
|
+ vocab->Vector2String(hyps, char_list);
|
|
|
+ std::vector<string> raw_char(char_list);
|
|
|
+ TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
|
|
|
+
|
|
|
+ return PostProcess(raw_char, timestamp_list);
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- return vocab->Vector2StringV2(hyps);
|
|
|
+string Paraformer::PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> ×tamp_list){
|
|
|
+ std::vector<std::vector<float>> timestamp_merge;
|
|
|
+ int i;
|
|
|
+ list<string> words;
|
|
|
+ int is_pre_english = false;
|
|
|
+ int pre_english_len = 0;
|
|
|
+ int is_combining = false;
|
|
|
+ string combine = "";
|
|
|
+
|
|
|
+ float begin=-1;
|
|
|
+ for (i=0; i<raw_char.size(); i++){
|
|
|
+ string word = raw_char[i];
|
|
|
+ // step1: skip special tokens (<s>, </s>, <unk>)
|
|
|
+ if (word == "<s>" || word == "</s>" || word == "<unk>")
|
|
|
+ continue;
|
|
|
+ // step2: combine sub-word pieces ("@@"-suffixed) into a full word
|
|
|
+ {
|
|
|
+ int sub_word = !(word.find("@@") == string::npos);
|
|
|
+ // process word start and middle part
|
|
|
+ if (sub_word) {
|
|
|
+ combine += word.erase(word.length() - 2);
|
|
|
+ if(!is_combining){
|
|
|
+ begin = timestamp_list[i][0];
|
|
|
+ }
|
|
|
+ is_combining = true;
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ // process word end part
|
|
|
+ else if (is_combining) {
|
|
|
+ combine += word;
|
|
|
+ is_combining = false;
|
|
|
+ word = combine;
|
|
|
+ combine = "";
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // step3: process English words — insert spaces between them (abbreviation upper-casing is currently disabled below)
|
|
|
+ {
|
|
|
+ // input word is Chinese: no extra processing needed
|
|
|
+ if (vocab->IsChinese(word)) {
|
|
|
+ words.push_back(word);
|
|
|
+ timestamp_merge.emplace_back(timestamp_list[i]);
|
|
|
+ is_pre_english = false;
|
|
|
+ }
|
|
|
+ // input word is english word
|
|
|
+ else {
|
|
|
+ // pre word is chinese
|
|
|
+ if (!is_pre_english) {
|
|
|
+ // word[0] = word[0] - 32;
|
|
|
+ words.push_back(word);
|
|
|
+ begin = (begin==-1)?timestamp_list[i][0]:begin;
|
|
|
+ std::vector<float> vec = {begin, timestamp_list[i][1]};
|
|
|
+ timestamp_merge.emplace_back(vec);
|
|
|
+ begin = -1;
|
|
|
+ pre_english_len = word.size();
|
|
|
+ }
|
|
|
+ // pre word is english word
|
|
|
+ else {
|
|
|
+ // single letter turn to upper case
|
|
|
+ // if (word.size() == 1) {
|
|
|
+ // word[0] = word[0] - 32;
|
|
|
+ // }
|
|
|
+
|
|
|
+ if (pre_english_len > 1) {
|
|
|
+ words.push_back(" ");
|
|
|
+ words.push_back(word);
|
|
|
+ begin = (begin==-1)?timestamp_list[i][0]:begin;
|
|
|
+ std::vector<float> vec = {begin, timestamp_list[i][1]};
|
|
|
+ timestamp_merge.emplace_back(vec);
|
|
|
+ begin = -1;
|
|
|
+ pre_english_len = word.size();
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ // if (word.size() > 1) {
|
|
|
+ // words.push_back(" ");
|
|
|
+ // }
|
|
|
+ words.push_back(" ");
|
|
|
+ words.push_back(word);
|
|
|
+ begin = (begin==-1)?timestamp_list[i][0]:begin;
|
|
|
+ std::vector<float> vec = {begin, timestamp_list[i][1]};
|
|
|
+ timestamp_merge.emplace_back(vec);
|
|
|
+ begin = -1;
|
|
|
+ pre_english_len = word.size();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ is_pre_english = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ string stamp_str="";
|
|
|
+ for (i=0; i<timestamp_list.size(); i++) {
|
|
|
+ stamp_str += std::to_string(timestamp_list[i][0]);
|
|
|
+ stamp_str += ", ";
|
|
|
+ stamp_str += std::to_string(timestamp_list[i][1]);
|
|
|
+ if(i!=timestamp_list.size()-1){
|
|
|
+ stamp_str += ",";
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ stringstream ss;
|
|
|
+ for (auto it = words.begin(); it != words.end(); it++) {
|
|
|
+ ss << *it;
|
|
|
+ }
|
|
|
+
|
|
|
+ return ss.str()+" | "+stamp_str;
|
|
|
+}
|
|
|
+
|
|
|
+void Paraformer::TimestampOnnx(std::vector<float>& us_alphas,
|
|
|
+ std::vector<float> us_cif_peak,
|
|
|
+ std::vector<string>& char_list,
|
|
|
+ std::string &res_str,
|
|
|
+ std::vector<std::vector<float>> ×tamp_vec,
|
|
|
+ float begin_time,
|
|
|
+ float total_offset){
|
|
|
+ if (char_list.empty()) {
|
|
|
+ return ;
|
|
|
+ }
|
|
|
+
|
|
|
+ const float START_END_THRESHOLD = 5.0;
|
|
|
+ const float MAX_TOKEN_DURATION = 30.0;
|
|
|
+ const float TIME_RATE = 10.0 * 6 / 1000 / 3;
|
|
|
+ // 3 times upsampled, cif_peak is flattened into a 1D array
|
|
|
+ std::vector<float> cif_peak = us_cif_peak;
|
|
|
+ int num_frames = cif_peak.size();
|
|
|
+ if (char_list.back() == "</s>") {
|
|
|
+ char_list.pop_back();
|
|
|
+ }
|
|
|
+
|
|
|
+ vector<vector<float>> timestamp_list;
|
|
|
+ vector<string> new_char_list;
|
|
|
+ vector<float> fire_place;
|
|
|
+ // for bicif model trained with large data, cif2 actually fires when a character starts
|
|
|
+ // so treat the frames between two peaks as the duration of the former token
|
|
|
+ for (int i = 0; i < num_frames; i++) {
|
|
|
+ if (cif_peak[i] > 1.0 - 1e-4) {
|
|
|
+ fire_place.push_back(i + total_offset);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ int num_peak = fire_place.size();
|
|
|
+ if(num_peak != (int)char_list.size() + 1){
|
|
|
+ float sum = std::accumulate(us_alphas.begin(), us_alphas.end(), 0.0f);
|
|
|
+ float scale = sum/((int)char_list.size() + 1);
|
|
|
+ cif_peak.clear();
|
|
|
+ sum = 0.0;
|
|
|
+ for(auto &alpha:us_alphas){
|
|
|
+ alpha = alpha/scale;
|
|
|
+ sum += alpha;
|
|
|
+ cif_peak.emplace_back(sum);
|
|
|
+ if(sum>=1.0 - 1e-4){
|
|
|
+ sum -=(1.0 - 1e-4);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fire_place.clear();
|
|
|
+ for (int i = 0; i < num_frames; i++) {
|
|
|
+ if (cif_peak[i] > 1.0 - 1e-4) {
|
|
|
+ fire_place.push_back(i + total_offset);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // begin silence
|
|
|
+ if (fire_place[0] > START_END_THRESHOLD) {
|
|
|
+ new_char_list.push_back("<sil>");
|
|
|
+ timestamp_list.push_back({0.0, fire_place[0] * TIME_RATE});
|
|
|
+ }
|
|
|
+
|
|
|
+ // tokens timestamp
|
|
|
+ for (int i = 0; i < num_peak - 1; i++) {
|
|
|
+ new_char_list.push_back(char_list[i]);
|
|
|
+ if (i == num_peak - 2 || MAX_TOKEN_DURATION < 0 || fire_place[i + 1] - fire_place[i] < MAX_TOKEN_DURATION) {
|
|
|
+ timestamp_list.push_back({fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE});
|
|
|
+ } else {
|
|
|
+ // overlong gap: cap the token at MAX_TOKEN_DURATION and attribute the remaining 0-weight frames to a <sil> segment
|
|
|
+ float _split = fire_place[i] + MAX_TOKEN_DURATION;
|
|
|
+ timestamp_list.push_back({fire_place[i] * TIME_RATE, _split * TIME_RATE});
|
|
|
+ timestamp_list.push_back({_split * TIME_RATE, fire_place[i + 1] * TIME_RATE});
|
|
|
+ new_char_list.push_back("<sil>");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // tail token and end silence
|
|
|
+ if (num_frames - fire_place.back() > START_END_THRESHOLD) {
|
|
|
+ float _end = (num_frames + fire_place.back()) / 2.0;
|
|
|
+ timestamp_list.back()[1] = _end * TIME_RATE;
|
|
|
+ timestamp_list.push_back({_end * TIME_RATE, num_frames * TIME_RATE});
|
|
|
+ new_char_list.push_back("<sil>");
|
|
|
+ } else {
|
|
|
+ timestamp_list.back()[1] = num_frames * TIME_RATE;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (begin_time) { // add offset time in model with vad
|
|
|
+ for (auto& timestamp : timestamp_list) {
|
|
|
+ timestamp[0] += begin_time / 1000.0;
|
|
|
+ timestamp[1] += begin_time / 1000.0;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ assert(new_char_list.size() == timestamp_list.size());
|
|
|
+
|
|
|
+ for (int i = 0; i < (int)new_char_list.size(); i++) {
|
|
|
+ res_str += new_char_list[i] + " " + to_string(timestamp_list[i][0]) + " " + to_string(timestamp_list[i][1]) + ";";
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int i = 0; i < (int)new_char_list.size(); i++) {
|
|
|
+ if(new_char_list[i] != "<sil>"){
|
|
|
+ timestamp_vec.push_back(timestamp_list[i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
|
|
|
@@ -369,7 +591,25 @@ string Paraformer::Forward(float* din, int len, bool input_finished)
|
|
|
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
|
|
|
float* floatData = outputTensor[0].GetTensorMutableData<float>();
|
|
|
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
|
|
|
- result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
|
|
|
+ // timestamp
|
|
|
+ if(outputTensor.size() == 4){
|
|
|
+ std::vector<int64_t> us_alphas_shape = outputTensor[2].GetTensorTypeAndShapeInfo().GetShape();
|
|
|
+ float* us_alphas_data = outputTensor[2].GetTensorMutableData<float>();
|
|
|
+ std::vector<float> us_alphas(us_alphas_shape[1]);
|
|
|
+ for (int i = 0; i < us_alphas_shape[1]; i++) {
|
|
|
+ us_alphas[i] = us_alphas_data[i];
|
|
|
+ }
|
|
|
+
|
|
|
+ std::vector<int64_t> us_peaks_shape = outputTensor[3].GetTensorTypeAndShapeInfo().GetShape();
|
|
|
+ float* us_peaks_data = outputTensor[3].GetTensorMutableData<float>();
|
|
|
+ std::vector<float> us_peaks(us_peaks_shape[1]);
|
|
|
+ for (int i = 0; i < us_peaks_shape[1]; i++) {
|
|
|
+ us_peaks[i] = us_peaks_data[i];
|
|
|
+ }
|
|
|
+ result = GreedySearch(floatData, *encoder_out_lens, outputShape[2], true, us_alphas, us_peaks);
|
|
|
+ }else{
|
|
|
+ result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
|
|
|
+ }
|
|
|
}
|
|
|
catch (std::exception const &e)
|
|
|
{
|