# timestamp_tools.py
  1. import torch
  2. import codecs
  3. import logging
  4. import argparse
  5. import numpy as np
  6. # import edit_distance
  7. from itertools import zip_longest
  8. def cif_wo_hidden(alphas, threshold):
  9. batch_size, len_time = alphas.size()
  10. # loop varss
  11. integrate = torch.zeros([batch_size], device=alphas.device)
  12. # intermediate vars along time
  13. list_fires = []
  14. for t in range(len_time):
  15. alpha = alphas[:, t]
  16. integrate += alpha
  17. list_fires.append(integrate)
  18. fire_place = integrate >= threshold
  19. integrate = torch.where(fire_place,
  20. integrate - torch.ones([batch_size], device=alphas.device)*threshold,
  21. integrate)
  22. fires = torch.stack(list_fires, 1)
  23. return fires
  24. def ts_prediction_lfr6_standard(us_alphas,
  25. us_peaks,
  26. char_list,
  27. vad_offset=0.0,
  28. force_time_shift=-1.5,
  29. sil_in_str=True
  30. ):
  31. if not len(char_list):
  32. return "", []
  33. START_END_THRESHOLD = 5
  34. MAX_TOKEN_DURATION = 12
  35. TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
  36. if len(us_alphas.shape) == 2:
  37. alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
  38. else:
  39. alphas, peaks = us_alphas, us_peaks
  40. if char_list[-1] == '</s>':
  41. char_list = char_list[:-1]
  42. fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
  43. if len(fire_place) != len(char_list) + 1:
  44. alphas /= (alphas.sum() / (len(char_list) + 1))
  45. alphas = alphas.unsqueeze(0)
  46. peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
  47. fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
  48. num_frames = peaks.shape[0]
  49. timestamp_list = []
  50. new_char_list = []
  51. # for bicif model trained with large data, cif2 actually fires when a character starts
  52. # so treat the frames between two peaks as the duration of the former token
  53. fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
  54. # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
  55. # begin silence
  56. if fire_place[0] > START_END_THRESHOLD:
  57. # char_list.insert(0, '<sil>')
  58. timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
  59. new_char_list.append('<sil>')
  60. # tokens timestamp
  61. for i in range(len(fire_place)-1):
  62. new_char_list.append(char_list[i])
  63. if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION:
  64. timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
  65. else:
  66. # cut the duration to token and sil of the 0-weight frames last long
  67. _split = fire_place[i] + MAX_TOKEN_DURATION
  68. timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
  69. timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
  70. new_char_list.append('<sil>')
  71. # tail token and end silence
  72. # new_char_list.append(char_list[-1])
  73. if num_frames - fire_place[-1] > START_END_THRESHOLD:
  74. _end = (num_frames + fire_place[-1]) * 0.5
  75. # _end = fire_place[-1]
  76. timestamp_list[-1][1] = _end*TIME_RATE
  77. timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
  78. new_char_list.append("<sil>")
  79. else:
  80. timestamp_list[-1][1] = num_frames*TIME_RATE
  81. if vad_offset: # add offset time in model with vad
  82. for i in range(len(timestamp_list)):
  83. timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
  84. timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
  85. res_txt = ""
  86. for char, timestamp in zip(new_char_list, timestamp_list):
  87. #if char != '<sil>':
  88. if not sil_in_str and char == '<sil>': continue
  89. res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
  90. res = []
  91. for char, timestamp in zip(new_char_list, timestamp_list):
  92. if char != '<sil>':
  93. res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
  94. return res_txt, res
  95. def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed):
  96. punc_list = [',', '。', '?', '、']
  97. res = []
  98. if text_postprocessed is None:
  99. return res
  100. if timestamp_postprocessed is None:
  101. return res
  102. if len(timestamp_postprocessed) == 0:
  103. return res
  104. if len(text_postprocessed) == 0:
  105. return res
  106. if punc_id_list is None or len(punc_id_list) == 0:
  107. res.append({
  108. 'text': text_postprocessed.split(),
  109. "start": timestamp_postprocessed[0][0],
  110. "end": timestamp_postprocessed[-1][1],
  111. "timestamp": timestamp_postprocessed,
  112. })
  113. return res
  114. if len(punc_id_list) != len(timestamp_postprocessed):
  115. logging.warning("length mismatch between punc and timestamp")
  116. sentence_text = ""
  117. sentence_text_seg = ""
  118. ts_list = []
  119. sentence_start = timestamp_postprocessed[0][0]
  120. sentence_end = timestamp_postprocessed[0][1]
  121. texts = text_postprocessed.split()
  122. punc_stamp_text_list = list(zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None))
  123. for punc_stamp_text in punc_stamp_text_list:
  124. punc_id, timestamp, text = punc_stamp_text
  125. # sentence_text += text if text is not None else ''
  126. if text is not None:
  127. if 'a' <= text[0] <= 'z' or 'A' <= text[0] <= 'Z':
  128. sentence_text += ' ' + text
  129. elif len(sentence_text) and ('a' <= sentence_text[-1] <= 'z' or 'A' <= sentence_text[-1] <= 'Z'):
  130. sentence_text += ' ' + text
  131. else:
  132. sentence_text += text
  133. sentence_text_seg += text + ' '
  134. ts_list.append(timestamp)
  135. punc_id = int(punc_id) if punc_id is not None else 1
  136. sentence_end = timestamp[1] if timestamp is not None else sentence_end
  137. if punc_id > 1:
  138. sentence_text += punc_list[punc_id - 2]
  139. res.append({
  140. 'text': sentence_text,
  141. "start": sentence_start,
  142. "end": sentence_end,
  143. "timestamp": ts_list
  144. })
  145. sentence_text = ''
  146. sentence_text_seg = ''
  147. ts_list = []
  148. sentence_start = sentence_end
  149. return res