#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
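"""Continuous Integrate-and-Fire (CIF) predictors.

The CIF mechanism (Dong & Xu, 2020) turns frame-level encoder outputs into
token-level acoustic embeddings: a scalar weight ``alpha`` is predicted per
frame and accumulated over time, and an embedding "fires" whenever the
accumulator crosses ``threshold``. In FunASR these predictors also estimate
the number of output tokens for non-autoregressive models such as Paraformer.
"""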
import torch
import logging
import numpy as np

from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask


@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0,
                 noise_threshold=0, tail_threshold=0.45):
        super().__init__()

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)

        output = torch.relu(output)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        if mask is not None:  # guard: mask is optional, only squeeze when given
            mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None

        token_num = alphas.sum(-1)

        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor

    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
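

# CifPredictorV2 below follows the same integrate-and-fire scheme as
# CifPredictor, with two implementation differences visible in the code: its
# 1-D convolution is a full (non-depthwise) conv whose output feeds a ReLU
# directly, without the residual add and dropout of CifPredictor's forward,
# and forward_chunk() adds a streaming path that carries leftover alphas and
# a partially integrated frame across chunks in `cache`.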
@tables.register("predictor_classes", "CifPredictorV2")
class CifPredictorV2(torch.nn.Module):
    def __init__(self,
                 idim,
                 l_order,
                 r_order,
                 threshold=1.0,
                 dropout=0.1,
                 smooth_factor=1.0,
                 noise_threshold=0,
                 tail_threshold=0.0,
                 tf2torch_tensor_name_prefix_torch="predictor",
                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                 tail_mask=True,
                 ):
        super(CifPredictorV2, self).__init__()

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tail_mask = tail_mask

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)

        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        if mask is not None:  # guard: mask is optional, only squeeze when given
            mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length.squeeze(-1)
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None

        token_num = alphas.sum(-1)

        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            if self.tail_mask:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
            else:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_chunk(self, hidden, cache=None, **kwargs):
        is_final = kwargs.get("is_final", False)
        batch_size, len_time, hidden_size = hidden.shape
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)

        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        alphas = alphas.squeeze(-1)

        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []

        if cache is not None and "chunk_size" in cache:
            alphas[:, :cache["chunk_size"][0]] = 0.0
            if not is_final:
                alphas[:, sum(cache["chunk_size"][:2]):] = 0.0

        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
            cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
            hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
            alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)

        if cache is not None and is_final:
            tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
            tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
            tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
            hidden = torch.cat((hidden, tail_hidden), dim=1)
            alphas = torch.cat((alphas, tail_alphas), dim=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = torch.zeros((hidden_size), device=hidden.device)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.threshold:
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    frames += (self.threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.threshold
                    frames = integrate * hidden[b][t]
            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)
            token_length.append(torch.tensor(len(list_frame), device=alphas.device))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        cache["cif_alphas"] = torch.stack(cache_alphas, dim=0)
        cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], dim=0)
        cache["cif_hidden"] = torch.stack(cache_hiddens, dim=0)
        cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], dim=0)

        max_token_len = max(token_length)
        if max_token_len == 0:
            return hidden, torch.stack(token_length, 0), None, None
        list_ls = []
        for b in range(batch_size):
            pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_frames[b] = torch.stack(list_frames[b])
                list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))
        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            if b > 1:
                alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
            else:
                alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor

    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (256,256,3),(3,256,256)
            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.cif_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1,256),(1,256,1)
            "{}.cif_output.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1,),(1,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), \
                    "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

        return var_dict_torch_update


class mae_loss(torch.nn.Module):
    def __init__(self, normalize_length=False):
        super(mae_loss, self).__init__()
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction='sum')

    def forward(self, token_length, pre_token_length):
        loss_token_normalizer = token_length.size(0)
        if self.normalize_length:
            loss_token_normalizer = token_length.sum().type(torch.float32)
        loss = self.criterion(token_length, pre_token_length)
        loss = loss / loss_token_normalizer
        return loss
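

# Typical pairing during training (an illustrative sketch, not FunASR's exact
# training loop): the predictor's `token_num` output is regressed toward the
# ground-truth label lengths, e.g.
#
#   criterion = mae_loss(normalize_length=False)
#   loss_pre = criterion(target_lengths.float(), token_num)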


def cif(hidden, alphas, threshold):
    batch_size, len_time, hidden_size = hidden.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=hidden.device),
                                integrate)
        cur = torch.where(fire_place,
                          distribution_completion,
                          alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
                            remainds[:, None] * hidden[:, t, :],
                            frame)

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    list_ls = []
    len_labels = torch.round(alphas.sum(-1)).int()
    max_label_len = len_labels.max()
    for b in range(batch_size):
        fire = fires[b, :]
        l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
        pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
        list_ls.append(torch.cat([l, pad_l], 0))
    return torch.stack(list_ls, 0), fires
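

# Worked example of the dynamics above (threshold = 1.0): for one utterance
# with alphas = [0.4, 0.5, 0.3, ...], the accumulator reaches 0.4, then 0.9,
# then 1.2 at the third frame, so a token fires there; that frame's weight is
# split into 0.1 (completing the current token) and 0.2 (carried over as the
# start of the next token), which is exactly the `cur` / `remainds` split.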


def cif_wo_hidden(alphas, threshold):
    batch_size, len_time = alphas.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=alphas.device)
    # intermediate vars along time
    list_fires = []

    for t in range(len_time):
        alpha = alphas[:, t]

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=alphas.device) * threshold,
                                integrate)

    fires = torch.stack(list_fires, 1)
    return fires
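

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the shapes and hyper-parameters
    # below are assumptions for demonstration, not values mandated by FunASR.
    torch.manual_seed(0)
    predictor = CifPredictorV2(idim=320, l_order=1, r_order=1, tail_threshold=0.45)
    hidden = torch.randn(2, 50, 320)  # (batch, frames, dim): mock encoder output
    mask = torch.ones(2, 1, 50)       # (batch, 1, frames): non-padding mask
    acoustic_embeds, token_num, alphas, cif_peak = predictor(hidden, mask=mask)
    print(acoustic_embeds.shape, token_num, cif_peak.shape)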