# cif.py
  1. import torch
  2. from torch import nn
  3. import logging
  4. import numpy as np
  5. from funasr.modules.nets_utils import make_pad_mask
  6. from funasr.modules.streaming_utils.utils import sequence_mask
  7. class CifPredictor(nn.Module):
  8. def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
  9. super(CifPredictor, self).__init__()
  10. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  11. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
  12. self.cif_output = nn.Linear(idim, 1)
  13. self.dropout = torch.nn.Dropout(p=dropout)
  14. self.threshold = threshold
  15. self.smooth_factor = smooth_factor
  16. self.noise_threshold = noise_threshold
  17. self.tail_threshold = tail_threshold
  18. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  19. target_label_length=None):
  20. h = hidden
  21. context = h.transpose(1, 2)
  22. queries = self.pad(context)
  23. memory = self.cif_conv1d(queries)
  24. output = memory + context
  25. output = self.dropout(output)
  26. output = output.transpose(1, 2)
  27. output = torch.relu(output)
  28. output = self.cif_output(output)
  29. alphas = torch.sigmoid(output)
  30. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  31. if mask is not None:
  32. mask = mask.transpose(-1, -2).float()
  33. alphas = alphas * mask
  34. if mask_chunk_predictor is not None:
  35. alphas = alphas * mask_chunk_predictor
  36. alphas = alphas.squeeze(-1)
  37. mask = mask.squeeze(-1)
  38. if target_label_length is not None:
  39. target_length = target_label_length
  40. elif target_label is not None:
  41. target_length = (target_label != ignore_id).float().sum(-1)
  42. else:
  43. target_length = None
  44. token_num = alphas.sum(-1)
  45. if target_length is not None:
  46. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  47. elif self.tail_threshold > 0.0:
  48. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  49. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  50. if target_length is None and self.tail_threshold > 0.0:
  51. token_num_int = torch.max(token_num).type(torch.int32).item()
  52. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  53. return acoustic_embeds, token_num, alphas, cif_peak
  54. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  55. b, t, d = hidden.size()
  56. tail_threshold = self.tail_threshold
  57. if mask is not None:
  58. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  59. ones_t = torch.ones_like(zeros_t)
  60. mask_1 = torch.cat([mask, zeros_t], dim=1)
  61. mask_2 = torch.cat([ones_t, mask], dim=1)
  62. mask = mask_2 - mask_1
  63. tail_threshold = mask * tail_threshold
  64. alphas = torch.cat([alphas, zeros_t], dim=1)
  65. alphas = torch.add(alphas, tail_threshold)
  66. else:
  67. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  68. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  69. alphas = torch.cat([alphas, tail_threshold], dim=1)
  70. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  71. hidden = torch.cat([hidden, zeros], dim=1)
  72. token_num = alphas.sum(dim=-1)
  73. token_num_floor = torch.floor(token_num)
  74. return hidden, alphas, token_num_floor
  75. def gen_frame_alignments(self,
  76. alphas: torch.Tensor = None,
  77. encoder_sequence_length: torch.Tensor = None):
  78. batch_size, maximum_length = alphas.size()
  79. int_type = torch.int32
  80. is_training = self.training
  81. if is_training:
  82. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  83. else:
  84. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  85. max_token_num = torch.max(token_num).item()
  86. alphas_cumsum = torch.cumsum(alphas, dim=1)
  87. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  88. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  89. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  90. index = torch.cumsum(index, dim=1)
  91. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  92. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  93. index_div_bool_zeros = index_div.eq(0)
  94. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  95. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  96. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  97. index_div_bool_zeros_count *= token_num_mask
  98. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  99. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  100. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  101. ones = torch.cumsum(ones, dim=2)
  102. cond = index_div_bool_zeros_count_tile == ones
  103. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  104. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  105. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  106. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  107. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  108. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  109. int_type).to(encoder_sequence_length.device)
  110. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  111. predictor_alignments = index_div_bool_zeros_count_tile_out
  112. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  113. return predictor_alignments.detach(), predictor_alignments_length.detach()
  114. class CifPredictorV2(nn.Module):
  115. def __init__(self,
  116. idim,
  117. l_order,
  118. r_order,
  119. threshold=1.0,
  120. dropout=0.1,
  121. smooth_factor=1.0,
  122. noise_threshold=0,
  123. tail_threshold=0.0,
  124. tf2torch_tensor_name_prefix_torch="predictor",
  125. tf2torch_tensor_name_prefix_tf="seq2seq/cif",
  126. tail_mask=True,
  127. ):
  128. super(CifPredictorV2, self).__init__()
  129. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  130. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
  131. self.cif_output = nn.Linear(idim, 1)
  132. self.dropout = torch.nn.Dropout(p=dropout)
  133. self.threshold = threshold
  134. self.smooth_factor = smooth_factor
  135. self.noise_threshold = noise_threshold
  136. self.tail_threshold = tail_threshold
  137. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  138. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  139. self.tail_mask = tail_mask
  140. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  141. target_label_length=None):
  142. h = hidden
  143. context = h.transpose(1, 2)
  144. queries = self.pad(context)
  145. output = torch.relu(self.cif_conv1d(queries))
  146. output = output.transpose(1, 2)
  147. output = self.cif_output(output)
  148. alphas = torch.sigmoid(output)
  149. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  150. if mask is not None:
  151. mask = mask.transpose(-1, -2).float()
  152. alphas = alphas * mask
  153. if mask_chunk_predictor is not None:
  154. alphas = alphas * mask_chunk_predictor
  155. alphas = alphas.squeeze(-1)
  156. mask = mask.squeeze(-1)
  157. if target_label_length is not None:
  158. target_length = target_label_length
  159. elif target_label is not None:
  160. target_length = (target_label != ignore_id).float().sum(-1)
  161. else:
  162. target_length = None
  163. token_num = alphas.sum(-1)
  164. if target_length is not None:
  165. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  166. elif self.tail_threshold > 0.0:
  167. if self.tail_mask:
  168. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  169. else:
  170. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
  171. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  172. if target_length is None and self.tail_threshold > 0.0:
  173. token_num_int = torch.max(token_num).type(torch.int32).item()
  174. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  175. return acoustic_embeds, token_num, alphas, cif_peak
  176. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  177. b, t, d = hidden.size()
  178. tail_threshold = self.tail_threshold
  179. if mask is not None:
  180. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  181. ones_t = torch.ones_like(zeros_t)
  182. mask_1 = torch.cat([mask, zeros_t], dim=1)
  183. mask_2 = torch.cat([ones_t, mask], dim=1)
  184. mask = mask_2 - mask_1
  185. tail_threshold = mask * tail_threshold
  186. alphas = torch.cat([alphas, zeros_t], dim=1)
  187. alphas = torch.add(alphas, tail_threshold)
  188. else:
  189. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  190. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  191. if b > 1:
  192. alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
  193. else:
  194. alphas = torch.cat([alphas, tail_threshold], dim=1)
  195. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  196. hidden = torch.cat([hidden, zeros], dim=1)
  197. token_num = alphas.sum(dim=-1)
  198. token_num_floor = torch.floor(token_num)
  199. return hidden, alphas, token_num_floor
  200. def gen_frame_alignments(self,
  201. alphas: torch.Tensor = None,
  202. encoder_sequence_length: torch.Tensor = None):
  203. batch_size, maximum_length = alphas.size()
  204. int_type = torch.int32
  205. is_training = self.training
  206. if is_training:
  207. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  208. else:
  209. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  210. max_token_num = torch.max(token_num).item()
  211. alphas_cumsum = torch.cumsum(alphas, dim=1)
  212. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  213. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  214. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  215. index = torch.cumsum(index, dim=1)
  216. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  217. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  218. index_div_bool_zeros = index_div.eq(0)
  219. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  220. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  221. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  222. index_div_bool_zeros_count *= token_num_mask
  223. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  224. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  225. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  226. ones = torch.cumsum(ones, dim=2)
  227. cond = index_div_bool_zeros_count_tile == ones
  228. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  229. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  230. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  231. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  232. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  233. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  234. int_type).to(encoder_sequence_length.device)
  235. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  236. predictor_alignments = index_div_bool_zeros_count_tile_out
  237. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  238. return predictor_alignments.detach(), predictor_alignments_length.detach()
  239. def gen_tf2torch_map_dict(self):
  240. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  241. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  242. map_dict_local = {
  243. ## predictor
  244. "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
  245. {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
  246. "squeeze": None,
  247. "transpose": (2, 1, 0),
  248. }, # (256,256,3),(3,256,256)
  249. "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
  250. {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
  251. "squeeze": None,
  252. "transpose": None,
  253. }, # (256,),(256,)
  254. "{}.cif_output.weight".format(tensor_name_prefix_torch):
  255. {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
  256. "squeeze": 0,
  257. "transpose": (1, 0),
  258. }, # (1,256),(1,256,1)
  259. "{}.cif_output.bias".format(tensor_name_prefix_torch):
  260. {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
  261. "squeeze": None,
  262. "transpose": None,
  263. }, # (1,),(1,)
  264. }
  265. return map_dict_local
  266. def convert_tf2torch(self,
  267. var_dict_tf,
  268. var_dict_torch,
  269. ):
  270. map_dict = self.gen_tf2torch_map_dict()
  271. var_dict_torch_update = dict()
  272. for name in sorted(var_dict_torch.keys(), reverse=False):
  273. names = name.split('.')
  274. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  275. name_tf = map_dict[name]["name"]
  276. data_tf = var_dict_tf[name_tf]
  277. if map_dict[name]["squeeze"] is not None:
  278. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  279. if map_dict[name]["transpose"] is not None:
  280. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  281. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  282. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  283. var_dict_torch[
  284. name].size(),
  285. data_tf.size())
  286. var_dict_torch_update[name] = data_tf
  287. logging.info(
  288. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  289. var_dict_tf[name_tf].shape))
  290. return var_dict_torch_update
  291. class mae_loss(nn.Module):
  292. def __init__(self, normalize_length=False):
  293. super(mae_loss, self).__init__()
  294. self.normalize_length = normalize_length
  295. self.criterion = torch.nn.L1Loss(reduction='sum')
  296. def forward(self, token_length, pre_token_length):
  297. loss_token_normalizer = token_length.size(0)
  298. if self.normalize_length:
  299. loss_token_normalizer = token_length.sum().type(torch.float32)
  300. loss = self.criterion(token_length, pre_token_length)
  301. loss = loss / loss_token_normalizer
  302. return loss
  303. def cif(hidden, alphas, threshold):
  304. batch_size, len_time, hidden_size = hidden.size()
  305. # loop varss
  306. integrate = torch.zeros([batch_size], device=hidden.device)
  307. frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
  308. # intermediate vars along time
  309. list_fires = []
  310. list_frames = []
  311. for t in range(len_time):
  312. alpha = alphas[:, t]
  313. distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
  314. integrate += alpha
  315. list_fires.append(integrate)
  316. fire_place = integrate >= threshold
  317. integrate = torch.where(fire_place,
  318. integrate - torch.ones([batch_size], device=hidden.device),
  319. integrate)
  320. cur = torch.where(fire_place,
  321. distribution_completion,
  322. alpha)
  323. remainds = alpha - cur
  324. frame += cur[:, None] * hidden[:, t, :]
  325. list_frames.append(frame)
  326. frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
  327. remainds[:, None] * hidden[:, t, :],
  328. frame)
  329. fires = torch.stack(list_fires, 1)
  330. frames = torch.stack(list_frames, 1)
  331. list_ls = []
  332. len_labels = torch.round(alphas.sum(-1)).int()
  333. max_label_len = len_labels.max()
  334. for b in range(batch_size):
  335. fire = fires[b, :]
  336. l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
  337. pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
  338. list_ls.append(torch.cat([l, pad_l], 0))
  339. return torch.stack(list_ls, 0), fires
  340. def cif_wo_hidden(alphas, threshold):
  341. batch_size, len_time = alphas.size()
  342. # loop varss
  343. integrate = torch.zeros([batch_size], device=alphas.device)
  344. # intermediate vars along time
  345. list_fires = []
  346. for t in range(len_time):
  347. alpha = alphas[:, t]
  348. integrate += alpha
  349. list_fires.append(integrate)
  350. fire_place = integrate >= threshold
  351. integrate = torch.where(fire_place,
  352. integrate - torch.ones([batch_size], device=alphas.device),
  353. integrate)
  354. fires = torch.stack(list_fires, 1)
  355. return fires
  356. class CifPredictorV3(nn.Module):
  357. def __init__(self,
  358. idim,
  359. l_order,
  360. r_order,
  361. threshold=1.0,
  362. dropout=0.1,
  363. smooth_factor=1.0,
  364. noise_threshold=0,
  365. tail_threshold=0.0,
  366. tf2torch_tensor_name_prefix_torch="predictor",
  367. tf2torch_tensor_name_prefix_tf="seq2seq/cif",
  368. smooth_factor2=1.0,
  369. noise_threshold2=0,
  370. upsample_times=5,
  371. upsample_type="cnn",
  372. use_cif1_cnn=True,
  373. tail_mask=True,
  374. ):
  375. super(CifPredictorV3, self).__init__()
  376. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  377. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
  378. self.cif_output = nn.Linear(idim, 1)
  379. self.dropout = torch.nn.Dropout(p=dropout)
  380. self.threshold = threshold
  381. self.smooth_factor = smooth_factor
  382. self.noise_threshold = noise_threshold
  383. self.tail_threshold = tail_threshold
  384. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  385. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  386. self.upsample_times = upsample_times
  387. self.upsample_type = upsample_type
  388. self.use_cif1_cnn = use_cif1_cnn
  389. if self.upsample_type == 'cnn':
  390. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  391. self.cif_output2 = nn.Linear(idim, 1)
  392. elif self.upsample_type == 'cnn_blstm':
  393. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  394. self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
  395. self.cif_output2 = nn.Linear(idim*2, 1)
  396. elif self.upsample_type == 'cnn_attn':
  397. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  398. from funasr.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
  399. from funasr.modules.attention import MultiHeadedAttention
  400. from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
  401. positionwise_layer_args = (
  402. idim,
  403. idim*2,
  404. 0.1,
  405. )
  406. self.self_attn = TransformerEncoderLayer(
  407. idim,
  408. MultiHeadedAttention(
  409. 4, idim, 0.1
  410. ),
  411. PositionwiseFeedForward(*positionwise_layer_args),
  412. 0.1,
  413. True, #normalize_before,
  414. False, #concat_after,
  415. )
  416. self.cif_output2 = nn.Linear(idim, 1)
  417. self.smooth_factor2 = smooth_factor2
  418. self.noise_threshold2 = noise_threshold2
  419. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  420. target_label_length=None):
  421. h = hidden
  422. context = h.transpose(1, 2)
  423. queries = self.pad(context)
  424. output = torch.relu(self.cif_conv1d(queries))
  425. # alphas2 is an extra head for timestamp prediction
  426. if not self.use_cif1_cnn:
  427. _output = context
  428. else:
  429. _output = output
  430. if self.upsample_type == 'cnn':
  431. output2 = self.upsample_cnn(_output)
  432. output2 = output2.transpose(1,2)
  433. elif self.upsample_type == 'cnn_blstm':
  434. output2 = self.upsample_cnn(_output)
  435. output2 = output2.transpose(1,2)
  436. output2, (_, _) = self.blstm(output2)
  437. elif self.upsample_type == 'cnn_attn':
  438. output2 = self.upsample_cnn(_output)
  439. output2 = output2.transpose(1,2)
  440. output2, _ = self.self_attn(output2, mask)
  441. # import pdb; pdb.set_trace()
  442. alphas2 = torch.sigmoid(self.cif_output2(output2))
  443. alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
  444. # repeat the mask in T demension to match the upsampled length
  445. if mask is not None:
  446. mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
  447. mask2 = mask2.unsqueeze(-1)
  448. alphas2 = alphas2 * mask2
  449. alphas2 = alphas2.squeeze(-1)
  450. token_num2 = alphas2.sum(-1)
  451. output = output.transpose(1, 2)
  452. output = self.cif_output(output)
  453. alphas = torch.sigmoid(output)
  454. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  455. if mask is not None:
  456. mask = mask.transpose(-1, -2).float()
  457. alphas = alphas * mask
  458. if mask_chunk_predictor is not None:
  459. alphas = alphas * mask_chunk_predictor
  460. alphas = alphas.squeeze(-1)
  461. mask = mask.squeeze(-1)
  462. if target_label_length is not None:
  463. target_length = target_label_length
  464. elif target_label is not None:
  465. target_length = (target_label != ignore_id).float().sum(-1)
  466. else:
  467. target_length = None
  468. token_num = alphas.sum(-1)
  469. if target_length is not None:
  470. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  471. elif self.tail_threshold > 0.0:
  472. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  473. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  474. if target_length is None and self.tail_threshold > 0.0:
  475. token_num_int = torch.max(token_num).type(torch.int32).item()
  476. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  477. return acoustic_embeds, token_num, alphas, cif_peak, token_num2
  478. def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
  479. h = hidden
  480. b = hidden.shape[0]
  481. context = h.transpose(1, 2)
  482. queries = self.pad(context)
  483. output = torch.relu(self.cif_conv1d(queries))
  484. # alphas2 is an extra head for timestamp prediction
  485. if not self.use_cif1_cnn:
  486. _output = context
  487. else:
  488. _output = output
  489. if self.upsample_type == 'cnn':
  490. output2 = self.upsample_cnn(_output)
  491. output2 = output2.transpose(1,2)
  492. elif self.upsample_type == 'cnn_blstm':
  493. output2 = self.upsample_cnn(_output)
  494. output2 = output2.transpose(1,2)
  495. output2, (_, _) = self.blstm(output2)
  496. elif self.upsample_type == 'cnn_attn':
  497. output2 = self.upsample_cnn(_output)
  498. output2 = output2.transpose(1,2)
  499. output2, _ = self.self_attn(output2, mask)
  500. alphas2 = torch.sigmoid(self.cif_output2(output2))
  501. alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
  502. # repeat the mask in T demension to match the upsampled length
  503. if mask is not None:
  504. mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
  505. mask2 = mask2.unsqueeze(-1)
  506. alphas2 = alphas2 * mask2
  507. alphas2 = alphas2.squeeze(-1)
  508. _token_num = alphas2.sum(-1)
  509. if token_num is not None:
  510. alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
  511. # re-downsample
  512. ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
  513. ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
  514. # upsampled alphas and cif_peak
  515. us_alphas = alphas2
  516. us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
  517. return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
  518. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  519. b, t, d = hidden.size()
  520. tail_threshold = self.tail_threshold
  521. if mask is not None:
  522. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  523. ones_t = torch.ones_like(zeros_t)
  524. mask_1 = torch.cat([mask, zeros_t], dim=1)
  525. mask_2 = torch.cat([ones_t, mask], dim=1)
  526. mask = mask_2 - mask_1
  527. tail_threshold = mask * tail_threshold
  528. alphas = torch.cat([alphas, zeros_t], dim=1)
  529. alphas = torch.add(alphas, tail_threshold)
  530. else:
  531. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  532. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  533. alphas = torch.cat([alphas, tail_threshold], dim=1)
  534. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  535. hidden = torch.cat([hidden, zeros], dim=1)
  536. token_num = alphas.sum(dim=-1)
  537. token_num_floor = torch.floor(token_num)
  538. return hidden, alphas, token_num_floor
  539. def gen_frame_alignments(self,
  540. alphas: torch.Tensor = None,
  541. encoder_sequence_length: torch.Tensor = None):
  542. batch_size, maximum_length = alphas.size()
  543. int_type = torch.int32
  544. is_training = self.training
  545. if is_training:
  546. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  547. else:
  548. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  549. max_token_num = torch.max(token_num).item()
  550. alphas_cumsum = torch.cumsum(alphas, dim=1)
  551. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  552. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  553. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  554. index = torch.cumsum(index, dim=1)
  555. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  556. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  557. index_div_bool_zeros = index_div.eq(0)
  558. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  559. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  560. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  561. index_div_bool_zeros_count *= token_num_mask
  562. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  563. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  564. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  565. ones = torch.cumsum(ones, dim=2)
  566. cond = index_div_bool_zeros_count_tile == ones
  567. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  568. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  569. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  570. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  571. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  572. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  573. int_type).to(encoder_sequence_length.device)
  574. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  575. predictor_alignments = index_div_bool_zeros_count_tile_out
  576. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  577. return predictor_alignments.detach(), predictor_alignments_length.detach()