Browse Source

remove assert in ts_prediction_lfr6_standard

shixian.shi 2 years ago
parent
commit
1054daf44a
1 changed files with 10 additions and 7 deletions
  1. 10 7
      funasr/utils/timestamp_tools.py

+ 10 - 7
funasr/utils/timestamp_tools.py

@@ -1,14 +1,10 @@
-from itertools import zip_longest
-
 import torch
-import copy
 import codecs
 import logging
-import edit_distance
 import argparse
-import pdb
 import numpy as np
-from typing import Any, List, Tuple, Union
+import edit_distance
+from itertools import zip_longest
 
 
 def ts_prediction_lfr6_standard(us_alphas, 
@@ -36,7 +32,14 @@ def ts_prediction_lfr6_standard(us_alphas,
     # so treat the frames between two peaks as the duration of the former token
     fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
     num_peak = len(fire_place)
-    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
+    if num_peak != len(char_list) + 1:
+        logging.warning("length mismatch, result might be incorrect.")
+        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
+    if num_peak > len(char_list) + 1:
+        fire_place = fire_place[:len(char_list) - 1]
+    elif num_peak < len(char_list) + 1:
+        char_list = char_list[:num_peak + 1]
+    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
     # begin silence
     if fire_place[0] > START_END_THRESHOLD:
         # char_list.insert(0, '<sil>')