timestamp_tools.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import torch
  2. import copy
  3. import logging
  4. import numpy as np
  5. from typing import Any, List, Tuple, Union
  6. def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
  7. if not len(char_list):
  8. return []
  9. START_END_THRESHOLD = 5
  10. TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
  11. if len(us_alphas.shape) == 3:
  12. alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
  13. else:
  14. alphas, cif_peak = us_alphas, us_cif_peak
  15. num_frames = cif_peak.shape[0]
  16. if char_list[-1] == '</s>':
  17. char_list = char_list[:-1]
  18. # char_list = [i for i in text]
  19. timestamp_list = []
  20. # for bicif model trained with large data, cif2 actually fires when a character starts
  21. # so treat the frames between two peaks as the duration of the former token
  22. fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5
  23. num_peak = len(fire_place)
  24. assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
  25. # begin silence
  26. if fire_place[0] > START_END_THRESHOLD:
  27. char_list.insert(0, '<sil>')
  28. timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
  29. # tokens timestamp
  30. for i in range(len(fire_place)-1):
  31. # the peak is always a little ahead of the start time
  32. # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE])
  33. timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE])
  34. # cut the duration to token and sil of the 0-weight frames last long
  35. # tail token and end silence
  36. if num_frames - fire_place[-1] > START_END_THRESHOLD:
  37. _end = (num_frames + fire_place[-1]) / 2
  38. timestamp_list[-1][1] = _end*TIME_RATE
  39. timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
  40. char_list.append("<sil>")
  41. else:
  42. timestamp_list[-1][1] = num_frames*TIME_RATE
  43. if begin_time: # add offset time in model with vad
  44. for i in range(len(timestamp_list)):
  45. timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
  46. timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
  47. res_txt = ""
  48. for char, timestamp in zip(char_list, timestamp_list):
  49. res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
  50. res = []
  51. for char, timestamp in zip(char_list, timestamp_list):
  52. if char != '<sil>':
  53. res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
  54. return res
  55. def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
  56. res = []
  57. if text_postprocessed is None:
  58. return res
  59. if time_stamp_postprocessed is None:
  60. return res
  61. if len(time_stamp_postprocessed) == 0:
  62. return res
  63. if len(text_postprocessed) == 0:
  64. return res
  65. if punc_id_list is None or len(punc_id_list) == 0:
  66. res.append({
  67. 'text': text_postprocessed.split(),
  68. "start": time_stamp_postprocessed[0][0],
  69. "end": time_stamp_postprocessed[-1][1]
  70. })
  71. return res
  72. if len(punc_id_list) != len(time_stamp_postprocessed):
  73. res.append({
  74. 'text': text_postprocessed.split(),
  75. "start": time_stamp_postprocessed[0][0],
  76. "end": time_stamp_postprocessed[-1][1]
  77. })
  78. return res
  79. sentence_text = ''
  80. sentence_start = time_stamp_postprocessed[0][0]
  81. texts = text_postprocessed.split()
  82. for i in range(len(punc_id_list)):
  83. sentence_text += texts[i]
  84. if punc_id_list[i] == 2:
  85. sentence_text += ','
  86. res.append({
  87. 'text': sentence_text,
  88. "start": sentence_start,
  89. "end": time_stamp_postprocessed[i][1]
  90. })
  91. sentence_text = ''
  92. sentence_start = time_stamp_postprocessed[i][1]
  93. elif punc_id_list[i] == 3:
  94. sentence_text += '.'
  95. res.append({
  96. 'text': sentence_text,
  97. "start": sentence_start,
  98. "end": time_stamp_postprocessed[i][1]
  99. })
  100. sentence_text = ''
  101. sentence_start = time_stamp_postprocessed[i][1]
  102. return res