|
|
@@ -1,14 +1,10 @@
|
|
|
-from itertools import zip_longest
|
|
|
-
|
|
|
import torch
|
|
|
-import copy
|
|
|
import codecs
|
|
|
import logging
|
|
|
-import edit_distance
|
|
|
import argparse
|
|
|
-import pdb
|
|
|
import numpy as np
|
|
|
-from typing import Any, List, Tuple, Union
|
|
|
+import edit_distance
|
|
|
+from itertools import zip_longest
|
|
|
|
|
|
|
|
|
def ts_prediction_lfr6_standard(us_alphas,
|
|
|
@@ -36,7 +32,14 @@ def ts_prediction_lfr6_standard(us_alphas,
|
|
|
# so treat the frames between two peaks as the duration of the former token
|
|
|
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
|
|
num_peak = len(fire_place)
|
|
|
- assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
|
|
+ if num_peak != len(char_list) + 1:
|
|
|
+ logging.warning("length mismatch, result might be incorrect.")
|
|
|
+ logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
|
|
|
+ if num_peak > len(char_list) + 1:
|
|
|
+ fire_place = fire_place[:len(char_list) - 1]
|
|
|
+ elif num_peak < len(char_list) + 1:
|
|
|
+ char_list = char_list[:num_peak + 1]
|
|
|
+ # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
|
|
# begin silence
|
|
|
if fire_place[0] > START_END_THRESHOLD:
|
|
|
# char_list.insert(0, '<sil>')
|