|
|
@@ -6,10 +6,9 @@ from typing import Any, List, Tuple, Union
|
|
|
|
|
|
|
|
|
def ts_prediction_lfr6_standard(us_alphas,
|
|
|
- us_cif_peak,
|
|
|
+ us_peaks,
|
|
|
char_list,
|
|
|
vad_offset=0.0,
|
|
|
- end_time=None,
|
|
|
force_time_shift=-1.5
|
|
|
):
|
|
|
if not len(char_list):
|
|
|
@@ -18,17 +17,17 @@ def ts_prediction_lfr6_standard(us_alphas,
|
|
|
MAX_TOKEN_DURATION = 12
|
|
|
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
|
|
|
if len(us_alphas.shape) == 2:
|
|
|
- alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
|
|
|
+ _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
|
|
|
else:
|
|
|
- alphas, cif_peak = us_alphas, us_cif_peak
|
|
|
- num_frames = cif_peak.shape[0]
|
|
|
+ _, peaks = us_alphas, us_peaks
|
|
|
+ num_frames = peaks.shape[0]
|
|
|
if char_list[-1] == '</s>':
|
|
|
char_list = char_list[:-1]
|
|
|
timestamp_list = []
|
|
|
new_char_list = []
|
|
|
# for bicif model trained with large data, cif2 actually fires when a character starts
|
|
|
# so treat the frames between two peaks as the duration of the former token
|
|
|
- fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
|
|
+ fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
|
|
num_peak = len(fire_place)
|
|
|
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
|
|
# begin silence
|