cif_predictor.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch

from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import make_pad_mask


class mae_loss(torch.nn.Module):
    """L1 (mean absolute error) loss between the target token count and the predicted token count."""

    def __init__(self, normalize_length=False):
        super(mae_loss, self).__init__()
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction='sum')

    def forward(self, token_length, pre_token_length):
        loss_token_normalizer = token_length.size(0)
        if self.normalize_length:
            loss_token_normalizer = token_length.sum().type(torch.float32)
        loss = self.criterion(token_length, pre_token_length)
        loss = loss / loss_token_normalizer
        return loss


def cif(hidden, alphas, threshold):
    """Continuous integrate-and-fire: accumulate per-frame weights and emit one
    acoustic embedding each time the accumulated weight crosses the threshold."""
    batch_size, len_time, hidden_size = hidden.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(
            fire_place,
            integrate - torch.ones([batch_size], device=hidden.device),
            integrate,
        )
        cur = torch.where(fire_place, distribution_completion, alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(
            fire_place[:, None].repeat(1, hidden_size),
            remainds[:, None] * hidden[:, t, :],
            frame,
        )

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    list_ls = []
    len_labels = torch.round(alphas.sum(-1)).int()
    max_label_len = len_labels.max()
    for b in range(batch_size):
        fire = fires[b, :]
        l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
        pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
        list_ls.append(torch.cat([l, pad_l], 0))
    return torch.stack(list_ls, 0), fires
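

# Illustrative usage sketch (not part of the original FunASR source): a minimal,
# hedged example of calling cif() on toy tensors. The batch size, frame count,
# hidden size, and constant 0.45 weights below are arbitrary assumptions, chosen
# so the cumulative weight crosses the default 1.0 threshold about four times.
def _cif_usage_example():
    hidden = torch.randn(2, 10, 4)        # (batch, time, hidden) encoder-like features
    alphas = torch.full((2, 10), 0.45)    # per-frame weights; sums to 4.5 per utterance
    acoustic_embeds, fires = cif(hidden, alphas, threshold=1.0)
    # acoustic_embeds: (batch, fired_tokens, hidden); fires: (batch, time) running integrate values
    return acoustic_embeds, fires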


def cif_wo_hidden(alphas, threshold):
    """CIF firing positions only (no frame integration); used for timestamp estimation."""
    batch_size, len_time = alphas.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=alphas.device)
    # intermediate vars along time
    list_fires = []

    for t in range(len_time):
        alpha = alphas[:, t]

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(
            fire_place,
            integrate - torch.ones([batch_size], device=alphas.device) * threshold,
            integrate,
        )

    fires = torch.stack(list_fires, 1)
    return fires


@tables.register("predictor_classes", "CifPredictorV3")
class CifPredictorV3(torch.nn.Module):
    def __init__(self,
                 idim,
                 l_order,
                 r_order,
                 threshold=1.0,
                 dropout=0.1,
                 smooth_factor=1.0,
                 noise_threshold=0,
                 tail_threshold=0.0,
                 tf2torch_tensor_name_prefix_torch="predictor",
                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                 smooth_factor2=1.0,
                 noise_threshold2=0,
                 upsample_times=5,
                 upsample_type="cnn",
                 use_cif1_cnn=True,
                 tail_mask=True,
                 ):
        super(CifPredictorV3, self).__init__()

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        self.upsample_times = upsample_times
        self.upsample_type = upsample_type
        self.use_cif1_cnn = use_cif1_cnn
        if self.upsample_type == 'cnn':
            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.cif_output2 = torch.nn.Linear(idim, 1)
        elif self.upsample_type == 'cnn_blstm':
            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.blstm = torch.nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
            self.cif_output2 = torch.nn.Linear(idim * 2, 1)
        elif self.upsample_type == 'cnn_attn':
            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
            from funasr.models.transformer.attention import MultiHeadedAttention
            from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward

            positionwise_layer_args = (
                idim,
                idim * 2,
                0.1,
            )
            self.self_attn = TransformerEncoderLayer(
                idim,
                MultiHeadedAttention(4, idim, 0.1),
                PositionwiseFeedForward(*positionwise_layer_args),
                0.1,
                True,   # normalize_before
                False,  # concat_after
            )
            self.cif_output2 = torch.nn.Linear(idim, 1)
        self.smooth_factor2 = smooth_factor2
        self.noise_threshold2 = noise_threshold2

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))

        # alphas2 is an extra head for timestamp prediction
        if not self.use_cif1_cnn:
            _output = context
        else:
            _output = output
        if self.upsample_type == 'cnn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
        elif self.upsample_type == 'cnn_blstm':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, (_, _) = self.blstm(output2)
        elif self.upsample_type == 'cnn_attn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, _ = self.self_attn(output2, mask)
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
        # repeat the mask in the T dimension to match the upsampled length
        if mask is not None:
            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
            mask2 = mask2.unsqueeze(-1)
            alphas2 = alphas2 * mask2
        alphas2 = alphas2.squeeze(-1)
        token_num2 = alphas2.sum(-1)

        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak, token_num2

    def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
        h = hidden
        b = hidden.shape[0]
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))

        # alphas2 is an extra head for timestamp prediction
        if not self.use_cif1_cnn:
            _output = context
        else:
            _output = output
        if self.upsample_type == 'cnn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
        elif self.upsample_type == 'cnn_blstm':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, (_, _) = self.blstm(output2)
        elif self.upsample_type == 'cnn_attn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, _ = self.self_attn(output2, mask)
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
        # repeat the mask in the T dimension to match the upsampled length
        if mask is not None:
            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
            mask2 = mask2.unsqueeze(-1)
            alphas2 = alphas2 * mask2
        alphas2 = alphas2.squeeze(-1)
        _token_num = alphas2.sum(-1)
        if token_num is not None:
            alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
        # re-downsample
        ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
        ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
        # upsampled alphas and cif_peak
        us_alphas = alphas2
        us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor

    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
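

# Hedged smoke test (illustrative only; not part of the original file). The feature
# dimension, context orders, sequence lengths, and tail_threshold value below are
# arbitrary assumptions for demonstration, not settings taken from a released config.
if __name__ == "__main__":
    torch.manual_seed(0)
    predictor = CifPredictorV3(idim=8, l_order=1, r_order=1, tail_threshold=0.45)
    hidden = torch.randn(2, 20, 8)                                     # (batch, frames, idim)
    lengths = torch.tensor([20, 15])
    mask = (~make_pad_mask(lengths, maxlen=20))[:, None, :].float()    # (batch, 1, frames)
    acoustic_embeds, token_num, alphas, cif_peak, token_num2 = predictor(hidden, mask=mask)
    print(acoustic_embeds.shape, token_num, token_num2)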