cif.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724
  1. import torch
  2. from torch import nn
  3. import logging
  4. import numpy as np
  5. from funasr.modules.nets_utils import make_pad_mask
  6. from funasr.modules.streaming_utils.utils import sequence_mask
  7. class CifPredictor(nn.Module):
  8. def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
  9. super(CifPredictor, self).__init__()
  10. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  11. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
  12. self.cif_output = nn.Linear(idim, 1)
  13. self.dropout = torch.nn.Dropout(p=dropout)
  14. self.threshold = threshold
  15. self.smooth_factor = smooth_factor
  16. self.noise_threshold = noise_threshold
  17. self.tail_threshold = tail_threshold
  18. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  19. target_label_length=None):
  20. h = hidden
  21. context = h.transpose(1, 2)
  22. queries = self.pad(context)
  23. memory = self.cif_conv1d(queries)
  24. output = memory + context
  25. output = self.dropout(output)
  26. output = output.transpose(1, 2)
  27. output = torch.relu(output)
  28. output = self.cif_output(output)
  29. alphas = torch.sigmoid(output)
  30. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  31. if mask is not None:
  32. mask = mask.transpose(-1, -2).float()
  33. alphas = alphas * mask
  34. if mask_chunk_predictor is not None:
  35. alphas = alphas * mask_chunk_predictor
  36. alphas = alphas.squeeze(-1)
  37. mask = mask.squeeze(-1)
  38. if target_label_length is not None:
  39. target_length = target_label_length
  40. elif target_label is not None:
  41. target_length = (target_label != ignore_id).float().sum(-1)
  42. else:
  43. target_length = None
  44. token_num = alphas.sum(-1)
  45. if target_length is not None:
  46. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  47. elif self.tail_threshold > 0.0:
  48. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  49. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  50. if target_length is None and self.tail_threshold > 0.0:
  51. token_num_int = torch.max(token_num).type(torch.int32).item()
  52. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  53. return acoustic_embeds, token_num, alphas, cif_peak
  54. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  55. b, t, d = hidden.size()
  56. tail_threshold = self.tail_threshold
  57. if mask is not None:
  58. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  59. ones_t = torch.ones_like(zeros_t)
  60. mask_1 = torch.cat([mask, zeros_t], dim=1)
  61. mask_2 = torch.cat([ones_t, mask], dim=1)
  62. mask = mask_2 - mask_1
  63. tail_threshold = mask * tail_threshold
  64. alphas = torch.cat([alphas, zeros_t], dim=1)
  65. alphas = torch.add(alphas, tail_threshold)
  66. else:
  67. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  68. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  69. alphas = torch.cat([alphas, tail_threshold], dim=1)
  70. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  71. hidden = torch.cat([hidden, zeros], dim=1)
  72. token_num = alphas.sum(dim=-1)
  73. token_num_floor = torch.floor(token_num)
  74. return hidden, alphas, token_num_floor
  75. def gen_frame_alignments(self,
  76. alphas: torch.Tensor = None,
  77. encoder_sequence_length: torch.Tensor = None):
  78. batch_size, maximum_length = alphas.size()
  79. int_type = torch.int32
  80. is_training = self.training
  81. if is_training:
  82. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  83. else:
  84. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  85. max_token_num = torch.max(token_num).item()
  86. alphas_cumsum = torch.cumsum(alphas, dim=1)
  87. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  88. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  89. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  90. index = torch.cumsum(index, dim=1)
  91. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  92. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  93. index_div_bool_zeros = index_div.eq(0)
  94. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  95. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  96. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  97. index_div_bool_zeros_count *= token_num_mask
  98. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  99. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  100. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  101. ones = torch.cumsum(ones, dim=2)
  102. cond = index_div_bool_zeros_count_tile == ones
  103. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  104. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  105. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  106. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  107. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  108. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  109. int_type).to(encoder_sequence_length.device)
  110. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  111. predictor_alignments = index_div_bool_zeros_count_tile_out
  112. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  113. return predictor_alignments.detach(), predictor_alignments_length.detach()
  114. class CifPredictorV2(nn.Module):
  115. def __init__(self,
  116. idim,
  117. l_order,
  118. r_order,
  119. threshold=1.0,
  120. dropout=0.1,
  121. smooth_factor=1.0,
  122. noise_threshold=0,
  123. tail_threshold=0.0,
  124. tf2torch_tensor_name_prefix_torch="predictor",
  125. tf2torch_tensor_name_prefix_tf="seq2seq/cif",
  126. tail_mask=True,
  127. ):
  128. super(CifPredictorV2, self).__init__()
  129. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  130. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
  131. self.cif_output = nn.Linear(idim, 1)
  132. self.dropout = torch.nn.Dropout(p=dropout)
  133. self.threshold = threshold
  134. self.smooth_factor = smooth_factor
  135. self.noise_threshold = noise_threshold
  136. self.tail_threshold = tail_threshold
  137. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  138. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  139. self.tail_mask = tail_mask
  140. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  141. target_label_length=None):
  142. h = hidden
  143. context = h.transpose(1, 2)
  144. queries = self.pad(context)
  145. output = torch.relu(self.cif_conv1d(queries))
  146. output = output.transpose(1, 2)
  147. output = self.cif_output(output)
  148. alphas = torch.sigmoid(output)
  149. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  150. if mask is not None:
  151. mask = mask.transpose(-1, -2).float()
  152. alphas = alphas * mask
  153. if mask_chunk_predictor is not None:
  154. alphas = alphas * mask_chunk_predictor
  155. alphas = alphas.squeeze(-1)
  156. mask = mask.squeeze(-1)
  157. if target_label_length is not None:
  158. target_length = target_label_length
  159. elif target_label is not None:
  160. target_length = (target_label != ignore_id).float().sum(-1)
  161. else:
  162. target_length = None
  163. token_num = alphas.sum(-1)
  164. if target_length is not None:
  165. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  166. elif self.tail_threshold > 0.0:
  167. if self.tail_mask:
  168. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  169. else:
  170. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
  171. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  172. if target_length is None and self.tail_threshold > 0.0:
  173. token_num_int = torch.max(token_num).type(torch.int32).item()
  174. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  175. return acoustic_embeds, token_num, alphas, cif_peak
  176. def forward_chunk(self, hidden, cache=None):
  177. b, t, d = hidden.size()
  178. h = hidden
  179. context = h.transpose(1, 2)
  180. queries = self.pad(context)
  181. output = torch.relu(self.cif_conv1d(queries))
  182. output = output.transpose(1, 2)
  183. output = self.cif_output(output)
  184. alphas = torch.sigmoid(output)
  185. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  186. alphas = alphas.squeeze(-1)
  187. mask_chunk_predictor = None
  188. if cache is not None:
  189. mask_chunk_predictor = None
  190. mask_chunk_predictor = torch.zeros_like(alphas)
  191. mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0
  192. if mask_chunk_predictor is not None:
  193. alphas = alphas * mask_chunk_predictor
  194. if cache is not None:
  195. if cache["is_final"]:
  196. alphas[:, cache["stride"] + cache["pad_left"] - 1] += 0.45
  197. if cache["cif_hidden"] is not None:
  198. hidden = torch.cat((cache["cif_hidden"], hidden), 1)
  199. if cache["cif_alphas"] is not None:
  200. alphas = torch.cat((cache["cif_alphas"], alphas), -1)
  201. token_num = alphas.sum(-1)
  202. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  203. len_time = alphas.size(-1)
  204. last_fire_place = len_time - 1
  205. last_fire_remainds = 0.0
  206. pre_alphas_length = 0
  207. last_fire = False
  208. mask_chunk_peak_predictor = None
  209. if cache is not None:
  210. mask_chunk_peak_predictor = None
  211. mask_chunk_peak_predictor = torch.zeros_like(cif_peak)
  212. if cache["cif_alphas"] is not None:
  213. pre_alphas_length = cache["cif_alphas"].size(-1)
  214. mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0
  215. mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0
  216. if mask_chunk_peak_predictor is not None:
  217. cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1)
  218. for i in range(len_time):
  219. if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
  220. last_fire_place = len_time - 1 - i
  221. last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
  222. last_fire = True
  223. break
  224. if last_fire:
  225. last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
  226. cache["cif_hidden"] = hidden[:, last_fire_place:, :]
  227. cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
  228. else:
  229. cache["cif_hidden"] = hidden
  230. cache["cif_alphas"] = alphas
  231. token_num_int = token_num.floor().type(torch.int32).item()
  232. return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
  233. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  234. b, t, d = hidden.size()
  235. tail_threshold = self.tail_threshold
  236. if mask is not None:
  237. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  238. ones_t = torch.ones_like(zeros_t)
  239. mask_1 = torch.cat([mask, zeros_t], dim=1)
  240. mask_2 = torch.cat([ones_t, mask], dim=1)
  241. mask = mask_2 - mask_1
  242. tail_threshold = mask * tail_threshold
  243. alphas = torch.cat([alphas, zeros_t], dim=1)
  244. alphas = torch.add(alphas, tail_threshold)
  245. else:
  246. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  247. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  248. if b > 1:
  249. alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
  250. else:
  251. alphas = torch.cat([alphas, tail_threshold], dim=1)
  252. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  253. hidden = torch.cat([hidden, zeros], dim=1)
  254. token_num = alphas.sum(dim=-1)
  255. token_num_floor = torch.floor(token_num)
  256. return hidden, alphas, token_num_floor
  257. def gen_frame_alignments(self,
  258. alphas: torch.Tensor = None,
  259. encoder_sequence_length: torch.Tensor = None):
  260. batch_size, maximum_length = alphas.size()
  261. int_type = torch.int32
  262. is_training = self.training
  263. if is_training:
  264. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  265. else:
  266. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  267. max_token_num = torch.max(token_num).item()
  268. alphas_cumsum = torch.cumsum(alphas, dim=1)
  269. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  270. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  271. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  272. index = torch.cumsum(index, dim=1)
  273. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  274. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  275. index_div_bool_zeros = index_div.eq(0)
  276. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  277. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  278. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  279. index_div_bool_zeros_count *= token_num_mask
  280. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  281. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  282. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  283. ones = torch.cumsum(ones, dim=2)
  284. cond = index_div_bool_zeros_count_tile == ones
  285. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  286. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  287. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  288. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  289. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  290. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  291. int_type).to(encoder_sequence_length.device)
  292. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  293. predictor_alignments = index_div_bool_zeros_count_tile_out
  294. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  295. return predictor_alignments.detach(), predictor_alignments_length.detach()
  296. def gen_tf2torch_map_dict(self):
  297. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  298. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  299. map_dict_local = {
  300. ## predictor
  301. "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
  302. {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
  303. "squeeze": None,
  304. "transpose": (2, 1, 0),
  305. }, # (256,256,3),(3,256,256)
  306. "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
  307. {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
  308. "squeeze": None,
  309. "transpose": None,
  310. }, # (256,),(256,)
  311. "{}.cif_output.weight".format(tensor_name_prefix_torch):
  312. {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
  313. "squeeze": 0,
  314. "transpose": (1, 0),
  315. }, # (1,256),(1,256,1)
  316. "{}.cif_output.bias".format(tensor_name_prefix_torch):
  317. {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
  318. "squeeze": None,
  319. "transpose": None,
  320. }, # (1,),(1,)
  321. }
  322. return map_dict_local
  323. def convert_tf2torch(self,
  324. var_dict_tf,
  325. var_dict_torch,
  326. ):
  327. map_dict = self.gen_tf2torch_map_dict()
  328. var_dict_torch_update = dict()
  329. for name in sorted(var_dict_torch.keys(), reverse=False):
  330. names = name.split('.')
  331. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  332. name_tf = map_dict[name]["name"]
  333. data_tf = var_dict_tf[name_tf]
  334. if map_dict[name]["squeeze"] is not None:
  335. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  336. if map_dict[name]["transpose"] is not None:
  337. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  338. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  339. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  340. var_dict_torch[
  341. name].size(),
  342. data_tf.size())
  343. var_dict_torch_update[name] = data_tf
  344. logging.info(
  345. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  346. var_dict_tf[name_tf].shape))
  347. return var_dict_torch_update
  348. class mae_loss(nn.Module):
  349. def __init__(self, normalize_length=False):
  350. super(mae_loss, self).__init__()
  351. self.normalize_length = normalize_length
  352. self.criterion = torch.nn.L1Loss(reduction='sum')
  353. def forward(self, token_length, pre_token_length):
  354. loss_token_normalizer = token_length.size(0)
  355. if self.normalize_length:
  356. loss_token_normalizer = token_length.sum().type(torch.float32)
  357. loss = self.criterion(token_length, pre_token_length)
  358. loss = loss / loss_token_normalizer
  359. return loss
  360. def cif(hidden, alphas, threshold):
  361. batch_size, len_time, hidden_size = hidden.size()
  362. # loop varss
  363. integrate = torch.zeros([batch_size], device=hidden.device)
  364. frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
  365. # intermediate vars along time
  366. list_fires = []
  367. list_frames = []
  368. for t in range(len_time):
  369. alpha = alphas[:, t]
  370. distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
  371. integrate += alpha
  372. list_fires.append(integrate)
  373. fire_place = integrate >= threshold
  374. integrate = torch.where(fire_place,
  375. integrate - torch.ones([batch_size], device=hidden.device),
  376. integrate)
  377. cur = torch.where(fire_place,
  378. distribution_completion,
  379. alpha)
  380. remainds = alpha - cur
  381. frame += cur[:, None] * hidden[:, t, :]
  382. list_frames.append(frame)
  383. frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
  384. remainds[:, None] * hidden[:, t, :],
  385. frame)
  386. fires = torch.stack(list_fires, 1)
  387. frames = torch.stack(list_frames, 1)
  388. list_ls = []
  389. len_labels = torch.round(alphas.sum(-1)).int()
  390. max_label_len = len_labels.max()
  391. for b in range(batch_size):
  392. fire = fires[b, :]
  393. l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
  394. pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
  395. list_ls.append(torch.cat([l, pad_l], 0))
  396. return torch.stack(list_ls, 0), fires
  397. def cif_wo_hidden(alphas, threshold):
  398. batch_size, len_time = alphas.size()
  399. # loop varss
  400. integrate = torch.zeros([batch_size], device=alphas.device)
  401. # intermediate vars along time
  402. list_fires = []
  403. for t in range(len_time):
  404. alpha = alphas[:, t]
  405. integrate += alpha
  406. list_fires.append(integrate)
  407. fire_place = integrate >= threshold
  408. integrate = torch.where(fire_place,
  409. integrate - torch.ones([batch_size], device=alphas.device),
  410. integrate)
  411. fires = torch.stack(list_fires, 1)
  412. return fires
  413. class CifPredictorV3(nn.Module):
  414. def __init__(self,
  415. idim,
  416. l_order,
  417. r_order,
  418. threshold=1.0,
  419. dropout=0.1,
  420. smooth_factor=1.0,
  421. noise_threshold=0,
  422. tail_threshold=0.0,
  423. tf2torch_tensor_name_prefix_torch="predictor",
  424. tf2torch_tensor_name_prefix_tf="seq2seq/cif",
  425. smooth_factor2=1.0,
  426. noise_threshold2=0,
  427. upsample_times=5,
  428. upsample_type="cnn",
  429. use_cif1_cnn=True,
  430. tail_mask=True,
  431. ):
  432. super(CifPredictorV3, self).__init__()
  433. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  434. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
  435. self.cif_output = nn.Linear(idim, 1)
  436. self.dropout = torch.nn.Dropout(p=dropout)
  437. self.threshold = threshold
  438. self.smooth_factor = smooth_factor
  439. self.noise_threshold = noise_threshold
  440. self.tail_threshold = tail_threshold
  441. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  442. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  443. self.upsample_times = upsample_times
  444. self.upsample_type = upsample_type
  445. self.use_cif1_cnn = use_cif1_cnn
  446. if self.upsample_type == 'cnn':
  447. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  448. self.cif_output2 = nn.Linear(idim, 1)
  449. elif self.upsample_type == 'cnn_blstm':
  450. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  451. self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
  452. self.cif_output2 = nn.Linear(idim*2, 1)
  453. elif self.upsample_type == 'cnn_attn':
  454. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  455. from funasr.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
  456. from funasr.modules.attention import MultiHeadedAttention
  457. from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
  458. positionwise_layer_args = (
  459. idim,
  460. idim*2,
  461. 0.1,
  462. )
  463. self.self_attn = TransformerEncoderLayer(
  464. idim,
  465. MultiHeadedAttention(
  466. 4, idim, 0.1
  467. ),
  468. PositionwiseFeedForward(*positionwise_layer_args),
  469. 0.1,
  470. True, #normalize_before,
  471. False, #concat_after,
  472. )
  473. self.cif_output2 = nn.Linear(idim, 1)
  474. self.smooth_factor2 = smooth_factor2
  475. self.noise_threshold2 = noise_threshold2
  476. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  477. target_label_length=None):
  478. h = hidden
  479. context = h.transpose(1, 2)
  480. queries = self.pad(context)
  481. output = torch.relu(self.cif_conv1d(queries))
  482. # alphas2 is an extra head for timestamp prediction
  483. if not self.use_cif1_cnn:
  484. _output = context
  485. else:
  486. _output = output
  487. if self.upsample_type == 'cnn':
  488. output2 = self.upsample_cnn(_output)
  489. output2 = output2.transpose(1,2)
  490. elif self.upsample_type == 'cnn_blstm':
  491. output2 = self.upsample_cnn(_output)
  492. output2 = output2.transpose(1,2)
  493. output2, (_, _) = self.blstm(output2)
  494. elif self.upsample_type == 'cnn_attn':
  495. output2 = self.upsample_cnn(_output)
  496. output2 = output2.transpose(1,2)
  497. output2, _ = self.self_attn(output2, mask)
  498. # import pdb; pdb.set_trace()
  499. alphas2 = torch.sigmoid(self.cif_output2(output2))
  500. alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
  501. # repeat the mask in T demension to match the upsampled length
  502. if mask is not None:
  503. mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
  504. mask2 = mask2.unsqueeze(-1)
  505. alphas2 = alphas2 * mask2
  506. alphas2 = alphas2.squeeze(-1)
  507. token_num2 = alphas2.sum(-1)
  508. output = output.transpose(1, 2)
  509. output = self.cif_output(output)
  510. alphas = torch.sigmoid(output)
  511. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  512. if mask is not None:
  513. mask = mask.transpose(-1, -2).float()
  514. alphas = alphas * mask
  515. if mask_chunk_predictor is not None:
  516. alphas = alphas * mask_chunk_predictor
  517. alphas = alphas.squeeze(-1)
  518. mask = mask.squeeze(-1)
  519. if target_label_length is not None:
  520. target_length = target_label_length
  521. elif target_label is not None:
  522. target_length = (target_label != ignore_id).float().sum(-1)
  523. else:
  524. target_length = None
  525. token_num = alphas.sum(-1)
  526. if target_length is not None:
  527. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  528. elif self.tail_threshold > 0.0:
  529. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  530. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  531. if target_length is None and self.tail_threshold > 0.0:
  532. token_num_int = torch.max(token_num).type(torch.int32).item()
  533. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  534. return acoustic_embeds, token_num, alphas, cif_peak, token_num2
  535. def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
  536. h = hidden
  537. b = hidden.shape[0]
  538. context = h.transpose(1, 2)
  539. queries = self.pad(context)
  540. output = torch.relu(self.cif_conv1d(queries))
  541. # alphas2 is an extra head for timestamp prediction
  542. if not self.use_cif1_cnn:
  543. _output = context
  544. else:
  545. _output = output
  546. if self.upsample_type == 'cnn':
  547. output2 = self.upsample_cnn(_output)
  548. output2 = output2.transpose(1,2)
  549. elif self.upsample_type == 'cnn_blstm':
  550. output2 = self.upsample_cnn(_output)
  551. output2 = output2.transpose(1,2)
  552. output2, (_, _) = self.blstm(output2)
  553. elif self.upsample_type == 'cnn_attn':
  554. output2 = self.upsample_cnn(_output)
  555. output2 = output2.transpose(1,2)
  556. output2, _ = self.self_attn(output2, mask)
  557. alphas2 = torch.sigmoid(self.cif_output2(output2))
  558. alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
  559. # repeat the mask in T demension to match the upsampled length
  560. if mask is not None:
  561. mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
  562. mask2 = mask2.unsqueeze(-1)
  563. alphas2 = alphas2 * mask2
  564. alphas2 = alphas2.squeeze(-1)
  565. _token_num = alphas2.sum(-1)
  566. if token_num is not None:
  567. alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
  568. # re-downsample
  569. ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
  570. ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
  571. # upsampled alphas and cif_peak
  572. us_alphas = alphas2
  573. us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
  574. return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
  575. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  576. b, t, d = hidden.size()
  577. tail_threshold = self.tail_threshold
  578. if mask is not None:
  579. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  580. ones_t = torch.ones_like(zeros_t)
  581. mask_1 = torch.cat([mask, zeros_t], dim=1)
  582. mask_2 = torch.cat([ones_t, mask], dim=1)
  583. mask = mask_2 - mask_1
  584. tail_threshold = mask * tail_threshold
  585. alphas = torch.cat([alphas, zeros_t], dim=1)
  586. alphas = torch.add(alphas, tail_threshold)
  587. else:
  588. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  589. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  590. alphas = torch.cat([alphas, tail_threshold], dim=1)
  591. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  592. hidden = torch.cat([hidden, zeros], dim=1)
  593. token_num = alphas.sum(dim=-1)
  594. token_num_floor = torch.floor(token_num)
  595. return hidden, alphas, token_num_floor
  596. def gen_frame_alignments(self,
  597. alphas: torch.Tensor = None,
  598. encoder_sequence_length: torch.Tensor = None):
  599. batch_size, maximum_length = alphas.size()
  600. int_type = torch.int32
  601. is_training = self.training
  602. if is_training:
  603. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  604. else:
  605. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  606. max_token_num = torch.max(token_num).item()
  607. alphas_cumsum = torch.cumsum(alphas, dim=1)
  608. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  609. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  610. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  611. index = torch.cumsum(index, dim=1)
  612. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  613. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  614. index_div_bool_zeros = index_div.eq(0)
  615. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  616. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  617. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  618. index_div_bool_zeros_count *= token_num_mask
  619. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  620. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  621. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  622. ones = torch.cumsum(ones, dim=2)
  623. cond = index_div_bool_zeros_count_tile == ones
  624. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  625. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  626. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  627. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  628. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  629. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  630. int_type).to(encoder_sequence_length.device)
  631. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  632. predictor_alignments = index_div_bool_zeros_count_tile_out
  633. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  634. return predictor_alignments.detach(), predictor_alignments_length.detach()