|
|
@@ -1,3 +1,4 @@
|
|
|
+from pydoc import TextRepr
|
|
|
from scipy.fftpack import shift
|
|
|
import torch
|
|
|
import copy
|
|
|
@@ -5,6 +6,7 @@ import codecs
|
|
|
import logging
|
|
|
import edit_distance
|
|
|
import argparse
|
|
|
+import pdb
|
|
|
import numpy as np
|
|
|
from typing import Any, List, Tuple, Union
|
|
|
|
|
|
@@ -13,7 +15,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
|
|
us_peaks,
|
|
|
char_list,
|
|
|
vad_offset=0.0,
|
|
|
- force_time_shift=-1.5
|
|
|
+ force_time_shift=-1.5,
|
|
|
+ sil_in_str=True
|
|
|
):
|
|
|
if not len(char_list):
|
|
|
return []
|
|
|
@@ -66,6 +69,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
|
|
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
|
|
|
res_txt = ""
|
|
|
for char, timestamp in zip(new_char_list, timestamp_list):
|
|
|
+ #if char != '<sil>':
|
|
|
+ if not sil_in_str and char == '<sil>': continue
|
|
|
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
|
|
|
res = []
|
|
|
for char, timestamp in zip(new_char_list, timestamp_list):
|
|
|
@@ -233,13 +238,54 @@ class AverageShiftCalculator():
|
|
|
return self._accumlated_shift / self._accumlated_tokens
|
|
|
|
|
|
|
|
|
-SUPPORTED_MODES = ['cal_aas']
|
|
|
+def convert_external_alphas(alphas_file, text_file, output_file):
|
|
|
+ from funasr.models.predictor.cif import cif_wo_hidden
|
|
|
+ with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
|
|
|
+ for line1, line2 in zip(f1.readlines(), f2.readlines()):
|
|
|
+ line1 = line1.rstrip()
|
|
|
+ line2 = line2.rstrip()
|
|
|
+ assert line1.split()[0] == line2.split()[0]
|
|
|
+ uttid = line1.split()[0]
|
|
|
+ alphas = [float(i) for i in line1.split()[1:]]
|
|
|
+ new_alphas = np.array(remove_chunk_padding(alphas))
|
|
|
+ new_alphas[-1] += 1e-4
|
|
|
+ text = line2.split()[1:]
|
|
|
+ if len(text) + 1 != int(new_alphas.sum()):
|
|
|
+ # force resize
|
|
|
+ new_alphas *= (len(text) + 1) / int(new_alphas.sum())
|
|
|
+ peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
|
|
|
+ if " " in text:
|
|
|
+ text = text.split()
|
|
|
+ else:
|
|
|
+ text = [i for i in text]
|
|
|
+ res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
|
|
|
+ force_time_shift=-7.0,
|
|
|
+ sil_in_str=False)
|
|
|
+ f3.write("{} {}\n".format(uttid, res_str))
|
|
|
+
|
|
|
+
|
|
|
+def remove_chunk_padding(alphas):
|
|
|
+ # remove the padding part in alphas if using chunk paraformer for GPU
|
|
|
+ START_ZERO = 45
|
|
|
+ MID_ZERO = 75
|
|
|
+ REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
|
|
|
+ alphas = alphas[START_ZERO:] # remove the padding at beginning
|
|
|
+ new_alphas = []
|
|
|
+ while True:
|
|
|
+ new_alphas = new_alphas + alphas[:REAL_FRAMES]
|
|
|
+ alphas = alphas[REAL_FRAMES+MID_ZERO:]
|
|
|
+ if len(alphas) < REAL_FRAMES: break
|
|
|
+ return new_alphas
|
|
|
+
|
|
|
+SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
|
|
|
|
|
|
|
|
|
def main(args):
|
|
|
if args.mode == 'cal_aas':
|
|
|
asc = AverageShiftCalculator()
|
|
|
asc(args.input, args.input2)
|
|
|
+ elif args.mode == 'read_ext_alphas':
|
|
|
+ convert_external_alphas(args.input, args.input2, args.output)
|
|
|
else:
|
|
|
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
|
|
|
|